File ‹mat_alg.ML›
(* Render a list of terms as a single comma-separated string, pretty-printing
   each term in the given context. *)
fun string_of_terms ctxt ts =
  Pretty.string_of (Pretty.block (Pretty.commas (map (Syntax.pretty_term ctxt) ts)))
(* Emit "s <term>" on the tracing channel, with t printed in context ctxt. *)
fun trace_t ctxt s t =
  let
    val term_str = Syntax.string_of_term ctxt t
  in
    tracing (s ^ " " ^ term_str)
  end
(* Emit "s [hyp1, ..., hypk] ==> prop" on the tracing channel, showing both the
   hypotheses and the proposition of theorem th. *)
fun trace_fullthm ctxt s th =
  let
    val hyps_str = string_of_terms ctxt (Thm.hyps_of th)
    val prop_str = Syntax.string_of_term ctxt (Thm.prop_of th)
  in
    tracing (s ^ " [" ^ hyps_str ^ "] ==> " ^ prop_str)
  end
(* Abbreviation for HOLogic.natT; used below as the type of the dimension
   arguments of carrier_mat. *)
val natT = HOLogic.natT
(* True iff t is the constant `times` applied to exactly two arguments. *)
fun is_times (Const (@{const_name times}, _) $ _ $ _) = true
  | is_times _ = false
(* True iff t is the constant `plus` applied to exactly two arguments. *)
fun is_plus (Const (@{const_name plus}, _) $ _ $ _) = true
  | is_plus _ = false
(* True iff t is the constant `minus` applied to exactly two arguments. *)
fun is_minus (Const (@{const_name minus}, _) $ _ $ _) = true
  | is_minus _ = false
(* True iff t is the constant `uminus` (unary negation) applied to one argument. *)
fun is_uminus (Const (@{const_name uminus}, _) $ _) = true
  | is_uminus _ = false
(* Split f $ a $ b into (a, b), whatever the head f is.
   Raises Fail "dest_binop" if t has fewer than two arguments. *)
fun dest_binop (_ $ a $ b) = (a, b)
  | dest_binop _ = raise Fail "dest_binop"
(* Return the last argument x of an application f $ x.
   Raises Fail "dest_arg" if t is not an application. *)
fun dest_arg (_ $ x) = x
  | dest_arg _ = raise Fail "dest_arg"
(* Return the second-to-last argument a of f $ a $ b (e.g. the lhs of an equation).
   Raises Fail "dest_arg1" if t has fewer than two arguments. *)
fun dest_arg1 (_ $ arg1 $ _) = arg1
  | dest_arg1 _ = raise Fail "dest_arg1"
(* True iff the type of t is an instance of the type constructor "Matrix.mat". *)
fun is_mat_type t =
  (case fastype_of t of
    Type ("Matrix.mat", _) => true
  | _ => false)
(* True iff t is `smult_mat` applied to two arguments (scalar, matrix). *)
fun is_smult_mat (Const (@{const_name smult_mat}, _) $ _ $ _) = true
  | is_smult_mat _ = false
(* True iff t is `mat_adjoint` applied to one argument. *)
fun is_adjoint (Const (@{const_name mat_adjoint}, _) $ _) = true
  | is_adjoint _ = false
(* True iff t is `one_mat` applied to one argument (the identity-matrix constructor). *)
fun is_id_mat (Const (@{const_name one_mat}, _) $ _) = true
  | is_id_mat _ = false
(* True iff t is `zero_mat` applied to two arguments (the zero-matrix constructor). *)
fun is_zero_mat (Const (@{const_name zero_mat}, _) $ _ $ _) = true
  | is_zero_mat _ = false
(* Flatten the left spine of a product: ((a * b) * c) yields [a, b, c].
   Only left operands are descended into; a right operand that is itself a
   product is kept as one element. A non-product t yields [t]. *)
fun strip_times t =
  let
    fun collect u acc =
      if is_times u then collect (dest_arg1 u) (dest_arg u :: acc)
      else u :: acc
  in
    collect t []
  end
(* Build the term `carrier_mat n n`, typed as a set of t's type.
   t is used only to determine that element type. *)
fun carrier_mat n t =
  let
    val setT = HOLogic.mk_setT (fastype_of t)
    val constT = natT --> natT --> setT
  in
    Const (@{const_name carrier_mat}, constT) $ n $ n
  end
(* Build the membership term `t ∈ carrier_mat n n`. *)
fun mk_mem_carrier n t =
  let
    val carrier = carrier_mat n t
  in
    HOLogic.mk_mem (t, carrier)
  end
(* Assume `t ∈ carrier_mat n n` as a hypothesis: the returned theorem
   proves the membership but carries it in its hyps. *)
fun assume_carrier ctxt n t =
  mk_mem_carrier n t
  |> HOLogic.mk_Trueprop
  |> Thm.cterm_of ctxt
  |> Thm.assume
(* Prove `t ∈ carrier_mat n n` by structural recursion on t.
   Products, sums, differences, negations, adjoints and scalar multiples are
   decomposed, and the operands' carrier facts are combined with the matching
   carrier lemma. Any other subterm is assumed via assume_carrier, so the result
   carries such leaf memberships as hypotheses. For binary operators the left
   operand's fact is computed before the right one's. *)
fun prod_in_carrier ctxt n t =
  let
    val recur = prod_in_carrier ctxt n
    (* Combine carrier facts for both operands of a binary operator using rule. *)
    fun binary rule =
      let
        val (a, b) = dest_binop t
      in
        [recur a, recur b] MRS rule
      end
    (* Lift the carrier fact of the last argument through rule. *)
    fun unary rule = recur (dest_arg t) RS rule
  in
    if is_times t then binary @{thm mult_carrier_mat}
    else if is_plus t then binary @{thm add_carrier_mat'}
    else if is_uminus t then unary @{thm uminus_carrier_mat}
    else if is_minus t then binary @{thm minus_carrier_mat'}
    else if is_adjoint t then unary @{thm adjoint_dim}
    else if is_smult_mat t then unary @{thm smult_carrier_mat}
    else assume_carrier ctxt n t
  end
(* Flip an object-level equation by resolving it with HOL.sym. *)
fun obj_sym th = op RS (th, @{thm HOL.sym})
(* Turn an object-level equation into a meta-equation via HOL.eq_reflection. *)
fun to_meta_eq th = op RS (th, @{thm HOL.eq_reflection})
(* Turn a meta-equation back into an object-level equation via HOL.meta_eq_to_obj_eq. *)
fun to_obj_eq th = op RS (th, @{thm HOL.meta_eq_to_obj_eq})
(* Conversion: rewrite ct with the (possibly conditional) object-level equation th.

   Steps:
   1. Turn th into a meta-equation.
   2. Match its lhs against ct and instantiate it.
   3. Discharge every premise of the instantiated rule with prod_in_carrier,
      applied to the premise's element term. Premises are presumably of the
      form `Trueprop (X ∈ carrier_mat ...)`, and X is taken out positionally.

   If THM (a failed resolution) or Pattern.MATCH (the lhs does not match) is
   raised anywhere above, the original th is traced with the tag "here" and
   Fail "THM" / Fail "MATCH" is raised. *)
fun rewr_cv ctxt n th ct =
  let
    val eq = to_meta_eq th
    val lhs_pat = Thm.cterm_of ctxt (dest_arg1 (Thm.concl_of eq))
    val eq_inst = Thm.instantiate (Thm.match (lhs_pat, ct)) eq
    fun carrier_fact prem = prod_in_carrier ctxt n (dest_arg1 (dest_arg prem))
  in
    map carrier_fact (Thm.prems_of eq_inst) MRS eq_inst
  end
  handle THM _ => (trace_fullthm ctxt "here" th; raise Fail "THM")
       | Pattern.MATCH => (trace_fullthm ctxt "here" th; raise Fail "MATCH")
(* Normalize a product a * b (ct must be an application with two arguments).
   Cases are tried in this order:
     - a is a scalar multiple: rewrite with mult_smult_assoc_mat, then
       renormalize the last argument of the result;
     - b is a scalar multiple: rewrite with mult_smult_distrib, then
       renormalize the last argument of the result;
     - b is a product: rewrite with assoc_mult_mat read right-to-left, then
       renormalize the new left operand (yielding a left-nested product);
     - a, or else b, is an identity matrix: drop it with
       left_mult_one_mat / right_mult_one_mat.
   Otherwise ct is left unchanged. The exact statements of these lemmas are
   not visible here; the descriptions follow from how their results are
   re-traversed. *)
fun assoc_times_norm ctxt n ct =
  let
    val (a, b) = dest_binop (Thm.term_of ct)
    val rewr = rewr_cv ctxt n
    val recur = assoc_times_norm ctxt n
    val cv =
      if is_smult_mat a then
        Conv.every_conv [rewr @{thm mult_smult_assoc_mat}, Conv.arg_conv recur]
      else if is_smult_mat b then
        Conv.every_conv [rewr @{thm mult_smult_distrib}, Conv.arg_conv recur]
      else if is_times b then
        Conv.every_conv [rewr (obj_sym @{thm assoc_mult_mat}), Conv.arg1_conv recur]
      else if is_id_mat a then
        rewr @{thm left_mult_one_mat}
      else if is_id_mat b then
        rewr @{thm right_mult_one_mat}
      else
        Conv.all_conv
  in
    cv ct
  end
(* Move the last summand of a sum a + b into place according to Term_Ord.
   ct must be a binary application; dest_binop is applied before the type check.
     - Non-matrix sums are left unchanged.
     - If a is itself a sum x + y and y > b, rewrite with swap_plus_mat and
       recurse into the new left operand.
     - If a is not a sum and a > b, swap with comm_add_mat.
   This presumably inserts b into an already sorted left-nested sum
   (insertion-sort step) — confirm against callers. *)
fun assoc_plus_one_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val (a, b) = dest_binop t
    fun out_of_order pair = (Term_Ord.term_ord pair = GREATER)
  in
    (case (is_mat_type t, is_plus a) of
      (false, _) => Conv.all_conv ct
    | (true, true) =>
        if out_of_order (dest_arg a, b) then
          Conv.every_conv [
            rewr_cv ctxt n @{thm swap_plus_mat},
            Conv.arg1_conv (assoc_plus_one_norm ctxt n)] ct
        else
          Conv.all_conv ct
    | (true, false) =>
        if out_of_order (a, b) then rewr_cv ctxt n @{thm comm_add_mat} ct
        else Conv.all_conv ct)
  end
(* Normalize a matrix sum a + b (ct must be a binary application):
     - non-matrix sums are left unchanged;
     - if b is a sum, reassociate to the left (assoc_add_mat read right-to-left),
       normalize the new left operand recursively, then place the last summand
       with assoc_plus_one_norm;
     - a zero-matrix summand on either side is dropped
       (left_add_zero_mat / right_add_zero_mat);
     - otherwise the last summand is placed with assoc_plus_one_norm. *)
fun assoc_plus_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val (a, b) = dest_binop t
    val rewr = rewr_cv ctxt n
    val cv =
      if not (is_mat_type t) then
        Conv.all_conv
      else if is_plus b then
        Conv.every_conv [
          rewr (obj_sym @{thm assoc_add_mat}),
          Conv.arg1_conv (assoc_plus_norm ctxt n),
          assoc_plus_one_norm ctxt n]
      else if is_zero_mat a then
        rewr @{thm left_add_zero_mat}
      else if is_zero_mat b then
        rewr @{thm right_add_zero_mat}
      else
        assoc_plus_one_norm ctxt n
  in
    cv ct
  end
(* Distribute a scalar multiple over a sum.
   If the last argument of ct is a sum, rewrite with add_smult_distrib_left_mat
   and then recurse into the left operand of the result. Sums are kept
   left-nested, so only that side can still be a scalar multiple of a sum.
   Otherwise ct is unchanged. *)
fun smult_plus_norm ctxt n ct =
  if not (is_plus (dest_arg (Thm.term_of ct))) then
    Conv.all_conv ct
  else
    Conv.every_conv [
      rewr_cv ctxt n @{thm add_smult_distrib_left_mat},
      Conv.arg1_conv (smult_plus_norm ctxt n)] ct
(* Normalize P * m, where P may be a (left-nested) sum and m is a single term.
   If P is a sum, distribute with add_mult_distrib_mat, then:
     - normalize the left product recursively;
     - normalize the right product with assoc_times_norm;
     - reassociate/sort the resulting sum with assoc_plus_norm.
   Otherwise the product is normalized directly with assoc_times_norm. *)
fun norm_mult_poly_monomial ctxt n ct =
  if not (is_plus (dest_arg1 (Thm.term_of ct))) then
    assoc_times_norm ctxt n ct
  else
    Conv.every_conv [
      rewr_cv ctxt n @{thm add_mult_distrib_mat},
      Conv.arg1_conv (norm_mult_poly_monomial ctxt n),
      Conv.arg_conv (assoc_times_norm ctxt n),
      assoc_plus_norm ctxt n] ct
(* Normalize a product P * Q of two matrix expressions into a sum of normalized
   monomials.
     - If Q is a sum, distribute with mult_add_distrib_mat. Then normalize the
       left product recursively, normalize the right product with
       norm_mult_poly_monomial, and reassociate/sort the sum with assoc_plus_norm.
     - Otherwise hand off to norm_mult_poly_monomial, which handles a sum in P.

   Products whose type is not a matrix type (is_mat_type) are returned unchanged.
   Such products can reach this function from assoc_norm: for example, once
   norm_trace_plus has split a trace, assoc_norm may see
   `(trace A + trace B) * c`. The *_mat rules used further down presumably only
   match matrix-typed terms. Without this guard, Thm.match fails inside rewr_cv
   and the whole tactic aborts with Fail "MATCH". This mirrors the is_mat_type
   guard in assoc_plus_norm. *)
fun norm_mult_polynomials ctxt n ct =
  let
    val t = Thm.term_of ct
  in
    if not (is_mat_type t) then
      Conv.all_conv ct
    else if is_plus (dest_arg t) then
      Conv.every_conv [
        rewr_cv ctxt n @{thm mult_add_distrib_mat},
        Conv.arg1_conv (norm_mult_polynomials ctxt n),
        Conv.arg_conv (norm_mult_poly_monomial ctxt n),
        assoc_plus_norm ctxt n] ct
    else
      norm_mult_poly_monomial ctxt n ct
  end
(* True iff t is the constant `trace` applied to one argument. *)
fun is_trace (Const (@{const_name trace}, _) $ _) = true
  | is_trace _ = false
(* Canonicalize the cyclic order of factors inside a trace.
   ct is `trace P`. P is flattened with strip_times into [f1, ..., fk].

   While some earlier factor is strictly greater (Term_Ord) than the last
   factor fk:
     - rewrite with trace_comm,
     - renormalize the product with assoc_times_norm,
     - try again.

   Stops once fk is >= every earlier factor; a single factor is never rotated.
   Judging by this loop, trace_comm presumably moves the last factor to the
   front — confirm against the lemma. *)
fun norm_trace_times ctxt n ct =
  let
    val factors = strip_times (dest_arg (Thm.term_of ct))
    val (init, final) = split_last factors
    fun beats_final u = (Term_Ord.term_ord (final, u) = LESS)
  in
    if not (exists beats_final init) then
      Conv.all_conv ct
    else
      Conv.every_conv [
        rewr_cv ctxt n @{thm trace_comm},
        Conv.arg_conv (assoc_times_norm ctxt n),
        norm_trace_times ctxt n] ct
  end
(* Normalize `trace P`.
   If P is a sum, split the trace with trace_add_linear. Then normalize the
   left part recursively and the right part with norm_trace_times.
   Otherwise apply norm_trace_times directly. *)
fun norm_trace_plus ctxt n ct =
  if not (is_plus (dest_arg (Thm.term_of ct))) then
    norm_trace_times ctxt n ct
  else
    Conv.every_conv [
      rewr_cv ctxt n @{thm trace_add_linear},
      Conv.arg1_conv (norm_trace_plus ctxt n),
      Conv.arg_conv (norm_trace_times ctxt n)] ct
(* Recursively normalize a matrix expression, dispatching on its head:
     - A * B     : normalize both factors, then norm_mult_polynomials;
     - A + B     : normalize both summands, then assoc_plus_norm;
     - k ⋅m A    : normalize A, then smult_plus_norm;
     - A - B     : rewrite with minus_add_uminus_mat and renormalize;
     - - A       : rewrite with uminus_mat and renormalize;
     - adjoint X : if X is a product, use adjoint_mult and renormalize;
                   if X is an adjoint, use adjoint_adjoint and renormalize;
                   otherwise normalize X and re-check (see below);
     - trace A   : normalize A, then norm_trace_plus.
   Any other term is left unchanged.

   Fix: previously `adjoint X` with X neither a product nor an adjoint was left
   completely unchanged, so e.g. adjoint (A + B) and adjoint (B + A) never
   reached the same normal form. X is now normalized like every other
   argument. Normalizing X can expose a product or an adjoint head (e.g. a zero
   summand dropped by assoc_plus_norm), so the adjoint rules are re-tried on
   the result. *)
fun assoc_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val recur = assoc_norm ctxt n
    (* Conversion for `adjoint X`, chosen by the head of X; NONE if no adjoint
       rule applies to that head. *)
    fun adjoint_step arg =
      if is_times arg then
        SOME (Conv.every_conv [rewr_cv ctxt n @{thm adjoint_mult}, recur])
      else if is_adjoint arg then
        SOME (Conv.every_conv [
          Conv.rewr_conv (to_meta_eq @{thm adjoint_adjoint}), recur])
      else
        NONE
    (* Re-try the adjoint rules after the adjoint's argument was normalized. *)
    fun recheck_adjoint ct' =
      (case adjoint_step (dest_arg (Thm.term_of ct')) of
        SOME cv => cv ct'
      | NONE => Conv.all_conv ct')
  in
    if is_times t then
      Conv.every_conv [
        Conv.binop_conv recur,
        norm_mult_polynomials ctxt n] ct
    else if is_plus t then
      Conv.every_conv [
        Conv.binop_conv recur,
        assoc_plus_norm ctxt n] ct
    else if is_smult_mat t then
      Conv.every_conv [
        Conv.arg_conv recur,
        smult_plus_norm ctxt n] ct
    else if is_minus t then
      Conv.every_conv [
        rewr_cv ctxt n @{thm minus_add_uminus_mat},
        recur] ct
    else if is_uminus t then
      Conv.every_conv [
        rewr_cv ctxt n @{thm uminus_mat},
        recur] ct
    else if is_adjoint t then
      (case adjoint_step (dest_arg t) of
        SOME cv => cv ct
      | NONE => Conv.every_conv [Conv.arg_conv recur, recheck_adjoint] ct)
    else if is_trace t then
      Conv.every_conv [
        Conv.arg_conv recur,
        norm_trace_plus ctxt n] ct
    else
      Conv.all_conv ct
  end
(* Prove the equation t (split into lhs and rhs by dest_binop).
   Both sides are normalized with assoc_norm, and the normal forms are
   compared up to alpha-equivalence. On success, returns the object-level
   equation lhs = rhs; it may carry hypotheses introduced by assume_carrier.
   On failure, traces both normal forms and raises Fail. *)
fun prove_by_assoc_norm ctxt n t =
  let
    val _ = trace_t ctxt "To show equation:" t
    val (lhs, rhs) = dest_binop t
    val normalize = assoc_norm ctxt n o Thm.cterm_of ctxt
    val lhs_eq = normalize lhs
    val rhs_eq = normalize rhs
    val lhs_nf = Thm.rhs_of lhs_eq
    val rhs_nf = Thm.rhs_of rhs_eq
  in
    if lhs_nf aconvc rhs_nf then
      to_obj_eq (Thm.transitive lhs_eq (Thm.symmetric rhs_eq))
    else
      (trace_t ctxt "Left side is:" (Thm.term_of lhs_nf);
       trace_t ctxt "Right side is:" (Thm.term_of rhs_nf);
       raise Fail "Normalization are not equal.")
  end
(* Tactic: solve the first subgoal `P1 ==> ... ==> Trueprop (l = r)` by normalization.

   Steps:
   1. Read n (a string) as a term in ctxt; this happens even if there are no
      subgoals.
   2. Prove l = r with prove_by_assoc_norm.
   3. Discharge the subgoal's premises with implies_intr.
   4. Resolve the result against the proof state.
   5. Hypotheses not matched by the premises (e.g. assumed carrier facts) become
      new leading premises of the resulting state.

   Returns Seq.empty when there are no subgoals. Subgoals with outer ⋀-bound
   parameters are presumably not supported (dest_Trueprop would raise). *)
fun prove_by_assoc_norm_tac n ctxt state =
  let
    val dim = Syntax.read_term ctxt n
  in
    if Thm.nprems_of state = 0 then
      Seq.empty
    else
      let
        val goal = hd (Drule.cprems_of state)
        val goal_prems = Drule.strip_imp_prems goal
        val goal_concl = HOLogic.dest_Trueprop (Thm.term_of (Drule.strip_imp_concl goal))
        val proved = fold Thm.implies_intr (rev goal_prems) (prove_by_assoc_norm ctxt dim goal_concl)
        val leftover = Thm.chyps_of proved
      in
        Thm.implies_elim state proved
        |> fold Thm.implies_intr leftover
        |> Seq.single
      end
  end
(* Proof-method parser. Takes one term argument, kept as a string and read later
   by prove_by_assoc_norm_tac as the matrix dimension, and wraps that tactic as
   a simple method. *)
val mat_assoc_method : (Proof.context -> Method.method) context_parser =
  let
    fun mk_method dim_str ctxt = SIMPLE_METHOD (prove_by_assoc_norm_tac dim_str ctxt)
  in
    Scan.lift Parse.term >> mk_method
  end