(*  File: ‹Tools/Activation_Functions.ML›  *)
structure Activation_Term:ACTIVATION_TERM = struct
(* Which variant of the phi table is being generated: a scalar
   (real ⇒ real) activation, a list-based multi-class activation, or a
   Matrix.vec-based multi-class activation. *)
datatype mode = Single | MultiList | MultiMatrix
open TensorFlow_Type
(* Map an activation tag (from TensorFlow_Type) to the term of the matching
   constructor of Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.
   Sigmoid⇩t⇩a⇩y⇩l⇩o⇩r is applied to a hard-coded Taylor degree of 2.
   Raises an error for Softmax / Softmax_taylor, which are inherently
   multi-class and have no single-class constructor.
   Fix: corrected the "fuction" typo in both error messages. *)
fun term_of_activation_single Linear = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.Identity›}
| term_of_activation_single Softsign = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.SoftSign›}
| term_of_activation_single Sign = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.Sign›}
| term_of_activation_single BinaryStep = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.BinaryStep›}
| term_of_activation_single Sigmoid = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.Sigmoid›}
| term_of_activation_single Swish = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.Swish›}
| term_of_activation_single Tanh = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.Tanh›}
| term_of_activation_single Relu = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.ReLU›}
| term_of_activation_single Gelu = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.GeLUapprox›}
| term_of_activation_single GRelu = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.GReLU›}
| term_of_activation_single Softplus = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.Softplus›}
| term_of_activation_single Elu = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.ELU›}
| term_of_activation_single Selu = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.SELU›}
| term_of_activation_single Exponential = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.Exp›}
| term_of_activation_single Hard_sigmoid = @{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.HardSigmoid›}
| term_of_activation_single Sigmoid_taylor = (@{Const ‹Activation_Functions.activation⇩s⇩i⇩n⇩g⇩l⇩e.Sigmoid⇩t⇩a⇩y⇩l⇩o⇩r›}$HOLogic.mk_nat 2)
| term_of_activation_single Softmax = error "Activation function 'softmax' is not a single class activation function."
| term_of_activation_single Softmax_taylor = error "Activation function 'softmax_taylor' is not a single class activation function."
(* Map an activation tag to the term of the matching constructor of
   Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.  The two Taylor variants are
   applied to a hard-coded degree of 2.  All tags (including Softmax and
   Softmax_taylor) are valid in the multi-class setting. *)
fun term_of_activation_multi act =
  case act of
    Linear => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mIdentity›}
  | Softsign => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSoftSign›}
  | Sign => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSign›}
  | BinaryStep => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mBinaryStep›}
  | Sigmoid => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSigmoid›}
  | Swish => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSwish›}
  | Tanh => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mTanh›}
  | Relu => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mReLU›}
  | Gelu => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mGeLUapprox›}
  | GRelu => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mGReLU›}
  | Softplus => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSoftplus›}
  | Elu => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mELU›}
  | Selu => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSELU›}
  | Exponential => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mExp›}
  | Hard_sigmoid => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mHardSigmoid›}
  | Softmax => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSoftmax›}
  | Sigmoid_taylor => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSigmoid⇩t⇩a⇩y⇩l⇩o⇩r›} $ HOLogic.mk_nat 2
  | Softmax_taylor => @{Const ‹Activation_Functions.activation⇩m⇩u⇩l⇩t⇩i.mSoftmax⇩t⇩a⇩y⇩l⇩o⇩r›} $ HOLogic.mk_nat 2
(* Map an activation tag to the term of the real-valued function that the
   corresponding activation⇩s⇩i⇩n⇩g⇩l⇩e constructor denotes (used on the
   right-hand side of the generated phi-table equations).
   NOTE(review): unlike term_of_activation_single, the Sigmoid_taylor case
   here is NOT applied to a degree argument — confirm this is intended at
   the call sites (the list/matrix variants apply degree 2).
   Raises an error for Softmax / Softmax_taylor, which are multi-class only.
   Fix: corrected the "fuction" typo in both error messages. *)
fun term_of_activation_eqn_single Linear = @{const ‹Activation_Functions.identity›(‹real›)}
| term_of_activation_eqn_single Softsign = @{const ‹Activation_Functions.softsign›(‹real›)}
| term_of_activation_eqn_single Sign = @{const ‹Activation_Functions.sign›(‹real›)}
| term_of_activation_eqn_single BinaryStep = @{const ‹Activation_Functions.binary_step›(‹real›)}
| term_of_activation_eqn_single Sigmoid = @{const ‹Activation_Functions.sigmoid›}
| term_of_activation_eqn_single Swish = @{const ‹Activation_Functions.swish›}
| term_of_activation_eqn_single Tanh = @{const ‹Transcendental.tanh›(‹real›)}
| term_of_activation_eqn_single Relu = @{const ‹Activation_Functions.relu›(‹real›)}
| term_of_activation_eqn_single Gelu = @{const ‹Activation_Functions.gelu_approx›}
| term_of_activation_eqn_single GRelu = @{const ‹Activation_Functions.generalized_relu›(‹real›)}
| term_of_activation_eqn_single Softplus = @{const ‹Activation_Functions.softplus›(‹real›)}
| term_of_activation_eqn_single Elu = @{const ‹Activation_Functions.elu›(‹real›)}
| term_of_activation_eqn_single Selu = @{const ‹Activation_Functions.selu›(‹real›)}
| term_of_activation_eqn_single Exponential = @{const ‹Transcendental.exp›(‹real›)}
| term_of_activation_eqn_single Hard_sigmoid = @{const ‹Activation_Functions.hard_sigmoid›(‹real›)}
| term_of_activation_eqn_single Sigmoid_taylor = @{const ‹Activation_Functions.sigmoid⇩t⇩a⇩y⇩l⇩o⇩r›(‹real›)}
| term_of_activation_eqn_single Softmax = error "Activation function 'softmax' is not a single class activation function."
| term_of_activation_eqn_single Softmax_taylor = error "Activation function 'softmax_taylor' is not a single class activation function."
(* Map an activation tag to the term of the ‹real list ⇒ real list› function
   it denotes: pointwise activations are lifted with ‹map›; the two softmax
   variants are already list-valued and are used directly.  Both Taylor
   variants are applied to a hard-coded degree of 2. *)
fun term_of_activation_eqn_multi_list act =
  let
    (* Lift a scalar activation term pointwise over a real list. *)
    fun lift t = @{const ‹map›(real,real)} $ t
  in
    case act of
      Linear => lift @{const ‹Activation_Functions.identity›(‹real›)}
    | Softsign => lift @{const ‹Activation_Functions.softsign›(‹real›)}
    | Sign => lift @{const ‹Activation_Functions.sign›(‹real›)}
    | BinaryStep => lift @{const ‹Activation_Functions.binary_step›(‹real›)}
    | Sigmoid => lift @{const ‹Activation_Functions.sigmoid›}
    | Swish => lift @{const ‹Activation_Functions.swish›}
    | Tanh => lift @{const ‹Transcendental.tanh›(‹real›)}
    | Relu => lift @{const ‹Activation_Functions.relu›(‹real›)}
    | Gelu => lift @{const ‹Activation_Functions.gelu_approx›}
    | GRelu => lift @{const ‹Activation_Functions.generalized_relu›(‹real›)}
    | Softplus => lift @{const ‹Activation_Functions.softplus›(‹real›)}
    | Elu => lift @{const ‹Activation_Functions.elu›(‹real›)}
    | Selu => lift @{const ‹Activation_Functions.selu›(‹real›)}
    | Exponential => lift @{const ‹Transcendental.exp›(‹real›)}
    | Hard_sigmoid => lift @{const ‹Activation_Functions.hard_sigmoid›(‹real›)}
    | Sigmoid_taylor => lift (@{const ‹Activation_Functions.sigmoid⇩t⇩a⇩y⇩l⇩o⇩r›(‹real›)} $ HOLogic.mk_nat 2)
    | Softmax => @{const ‹Activation_Functions.softmax›(‹real›)}
    | Softmax_taylor => @{const ‹Activation_Functions.softmax⇩t⇩a⇩y⇩l⇩o⇩r›(‹real›)} $ HOLogic.mk_nat 2
  end
(* Map an activation tag to the term of the ‹real Matrix.vec ⇒ real Matrix.vec›
   function it denotes: pointwise activations are lifted with ‹map_vec›; the
   two softmax variants use their dedicated vector constants (msoftmax /
   msoftmax⇩t⇩a⇩y⇩l⇩o⇩r).  Both Taylor variants use a hard-coded degree of 2. *)
fun term_of_activation_eqn_multi_matrix act =
  let
    (* Lift a scalar activation term pointwise over a real vector. *)
    fun lift t = @{const ‹map_vec›(real,real)} $ t
  in
    case act of
      Linear => lift @{const ‹Activation_Functions.identity›(‹real›)}
    | Softsign => lift @{const ‹Activation_Functions.softsign›(‹real›)}
    | Sign => lift @{const ‹Activation_Functions.sign›(‹real›)}
    | BinaryStep => lift @{const ‹Activation_Functions.binary_step›(‹real›)}
    | Sigmoid => lift @{const ‹Activation_Functions.sigmoid›}
    | Swish => lift @{const ‹Activation_Functions.swish›}
    | Tanh => lift @{const ‹Transcendental.tanh›(‹real›)}
    | Relu => lift @{const ‹Activation_Functions.relu›(‹real›)}
    | Gelu => lift @{const ‹Activation_Functions.gelu_approx›}
    | GRelu => lift @{const ‹Activation_Functions.generalized_relu›(‹real›)}
    | Softplus => lift @{const ‹Activation_Functions.softplus›(‹real›)}
    | Elu => lift @{const ‹Activation_Functions.elu›(‹real›)}
    | Selu => lift @{const ‹Activation_Functions.selu›(‹real›)}
    | Exponential => lift @{const ‹Transcendental.exp›(‹real›)}
    | Hard_sigmoid => lift @{const ‹Activation_Functions.hard_sigmoid›(‹real›)}
    | Sigmoid_taylor => lift (@{const ‹Activation_Functions.sigmoid⇩t⇩a⇩y⇩l⇩o⇩r›(‹real›)} $ HOLogic.mk_nat 2)
    | Softmax => @{const ‹Activation_Functions.msoftmax›(‹real›)}
    | Softmax_taylor => @{const ‹Activation_Functions.msoftmax⇩t⇩a⇩y⇩l⇩o⇩r›(‹real›)} $ HOLogic.mk_nat 2
  end
(* Register a function definition named by [binding] in the local theory,
   built from the given list of defining equations (as Trueprop terms).
   Uses the function package's ‹fun› entry point with sequential matching
   and domintros enabled. *)
fun add_function binding eqs lthy =
  let
    val config =
      Function_Common.FunctionConfig
        {sequential = true, default = NONE, domintros = true, partials = false}
    (* Wrap a bare equation term in the spec shape expected by add_fun. *)
    fun mk_spec eq = (((Binding.empty, []), eq), [], [])
  in
    Function_Fun.add_fun [(binding, NONE, NoSyn)] (map mk_spec eqs) config lthy
  end
(* Build the proposition ‹Trueprop (lhs = rhs)› from a pair of terms. *)
fun mk_Trueprop_eq (lhs, rhs) = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
(* Remove duplicates from a list, keeping the LAST occurrence of each
   element (an element is dropped iff it reappears later in the list). *)
fun remdups [] = []
  | remdups (x :: xs) =
      if List.exists (fn y => y = x) xs then remdups xs else x :: remdups xs
(* Prove and note the lemma  ‹ran φ_<defN> = {…activation terms…}›  for the
   phi table generated by def_phi_tab, naming it "<defN>.φ_ran".
   mode selects the single / list / matrix instantiation; phitab lists the
   activation tags present in the table (duplicates are removed first, so
   the set literal on the right-hand side has distinct elements).
   NOTE(review): relies on the "elims"/"simps" facts that the function
   package generated for φ_<defN>, and on nn_tactics.prove_simple — both
   defined outside this file. *)
fun lemma_phi_ran mode defN phitab lthy =
let
(* Look up facts generated for the φ function, e.g. ("φ_<defN>", "elims"),
   via their fully qualified name in the current local theory. *)
fun get_local_thms n0 n1 = Proof_Context.get_thms lthy (Local_Theory.full_name lthy (Binding.qualify false defN (Binding.qualify_name false (Binding.name (n0)) n1) ))
(* Minimal simpset: HOL_basic_ss extended with exactly the given theorems. *)
fun mk_ss ctx thms = put_simpset HOL_basic_ss ctx addsimps thms
val phitab = remdups phitab
(* Fully qualified constant name of the φ table defined by def_phi_tab. *)
val Phi = Local_Theory.full_name lthy (Binding.qualify_name false (Binding.name defN) ("φ_"^defN))
(* The φ constant at the mode-specific table type. *)
val phi_const = case mode of
Single => Const(Phi, @{typ ‹activation⇩s⇩i⇩n⇩g⇩l⇩e ⇒ (real ⇒ real) option›})
| MultiList => Const(Phi, @{typ ‹activation⇩m⇩u⇩l⇩t⇩i ⇒ (real list ⇒ real list) option›})
| MultiMatrix => Const(Phi, @{typ ‹activation⇩m⇩u⇩l⇩t⇩i ⇒ (real Matrix.vec ⇒ real Matrix.vec) option›})
(* Map.ran instantiated at the matching key/value types. *)
val ran_const = case mode of
Single => @{const "ran" (‹activation⇩s⇩i⇩n⇩g⇩l⇩e›, ‹real ⇒ real›)}
| MultiList => @{const "ran" (‹activation⇩m⇩u⇩l⇩t⇩i›, ‹real list ⇒ real list›)}
| MultiMatrix => @{const "ran" (‹activation⇩m⇩u⇩l⇩t⇩i›, ‹real Matrix.vec ⇒ real Matrix.vec›)}
val lhs = ran_const$phi_const
(* Explicit set of the activation-function terms stored in the table. *)
val rhs = case mode of
Single => HOLogic.mk_set @{typ "real ⇒ real"} (map term_of_activation_eqn_single phitab)
| MultiList => HOLogic.mk_set @{typ "real list ⇒ real list"} (map term_of_activation_eqn_multi_list phitab)
| MultiMatrix => HOLogic.mk_set @{typ "real Matrix.vec⇒ real Matrix.vec"} (map term_of_activation_eqn_multi_matrix phitab)
in
(* Proof script: unfold ran, case-split via the generated elim rules,
   clean up again, then discharge remaining goals with meson using the
   generated simp rules.  Order of the tactic steps is significant. *)
nn_tactics.prove_simple (Binding.qualify_name false (Binding.name defN) "φ_ran")
(mk_Trueprop_eq (lhs,rhs))
(fn ctx => auto_tac (mk_ss ctx [@{thm "ran_def"},@{thm "ranI"}])
THEN eresolve_tac ctx (get_local_thms ("φ_"^defN) "elims") 1
THEN auto_tac (mk_ss ctx [@{thm "ran_def"},@{thm "ranI"}])
THEN ALLGOALS (Meson.meson_tac ctx ([@{thm "bot.extremum"}, @{thm "insert_subsetI"}, @{thm "ranI"}]@(get_local_thms ("φ_"^defN) "simps")) )
) lthy
end
(* Define the partial lookup table  φ_<defN> :: activation ⇒ (fn) option
   mapping each activation constructor in phitab to ‹Some› of its function
   term, with a catch-all clause returning ‹None›, and then prove the
   accompanying ran-lemma via lemma_phi_ran.  mode selects the single /
   list / matrix instantiation.  Duplicate tags are removed so the function
   package does not receive overlapping equations twice.
   NOTE(review): the binding here uses ‹Binding.qualify_name true› whereas
   lemma_phi_ran uses ‹false› when reconstructing the same name — confirm
   the two resolve to the same qualified constant. *)
fun def_phi_tab mode defN phitab lthy =
let
val phitab = remdups phitab
val Phi = "φ_"^defN
(* The (not yet defined) table as a Free variable, at the mode's type;
   the function package turns it into a constant. *)
fun phi_tab_term Phi mode = case mode of
Single => Free(Phi, @{typ ‹activation⇩s⇩i⇩n⇩g⇩l⇩e ⇒ (real ⇒ real) option›})
| MultiList => Free(Phi, @{typ ‹activation⇩m⇩u⇩l⇩t⇩i ⇒ (real list ⇒ real list) option›})
| MultiMatrix => Free(Phi, @{typ ‹activation⇩m⇩u⇩l⇩t⇩i ⇒ (real Matrix.vec ⇒ real Matrix.vec) option›})
(* One defining equation  φ (Constructor) = Some (function term)  per tag. *)
fun mk_term phi =
let
val lhs = case mode of
Single => (phi_tab_term Phi mode)$(term_of_activation_single phi)
| MultiList => (phi_tab_term Phi mode)$(term_of_activation_multi phi)
| MultiMatrix => (phi_tab_term Phi mode)$(term_of_activation_multi phi)
val rhs = case mode of
Single => @{const ‹Some›(‹real ⇒real›)}$(term_of_activation_eqn_single phi)
| MultiList => @{const ‹Some›(‹real list ⇒real list›)}$(term_of_activation_eqn_multi_list phi)
| MultiMatrix => @{const ‹Some›(‹real Matrix.vec ⇒real Matrix.vec›)}$(term_of_activation_eqn_multi_matrix phi)
in
mk_Trueprop_eq (lhs,rhs)
end
(* Final sequential clause  φ x = None  for all remaining constructors. *)
val catch_all_lhs = case mode of
Single => (Free(Phi, @{typ ‹activation⇩s⇩i⇩n⇩g⇩l⇩e ⇒ (real ⇒ real) option›}))$(@{term ‹(x::activation⇩s⇩i⇩n⇩g⇩l⇩e)›})
| MultiList => (Free(Phi, @{typ ‹activation⇩m⇩u⇩l⇩t⇩i ⇒ (real list ⇒ real list) option›}))$(@{term ‹(x::activation⇩m⇩u⇩l⇩t⇩i)›})
| MultiMatrix => (Free(Phi, @{typ ‹activation⇩m⇩u⇩l⇩t⇩i ⇒ (real Matrix.vec ⇒ real Matrix.vec) option›}))$(@{term ‹(x::activation⇩m⇩u⇩l⇩t⇩i)›})
val catch_all_rhs = case mode of
Single => @{term "None::(real ⇒ real) option"}
| MultiList => @{term "None::(real list ⇒ real list) option"}
| MultiMatrix => @{term "None::(real Matrix.vec ⇒ real Matrix.vec) option"}
val eqs = (map mk_term phitab)@[mk_Trueprop_eq(catch_all_lhs, catch_all_rhs)]
in
(* Two nested local-theory blocks, in order: first the function definition
   (so its simps/elims exist), then the ran-lemma that uses them. *)
(snd o Local_Theory.begin_nested) lthy
|> add_function (Binding.qualify_name true (Binding.name defN) ("φ_"^defN)) eqs
|> Local_Theory.end_nested
|> (snd o Local_Theory.begin_nested)
|> lemma_phi_ran mode defN phitab
|> Local_Theory.end_nested
end
end