File ‹Tools/Activation_Functions.ML›

(***********************************************************************************
 * Copyright (c) 2021-2023 University of Exeter, UK
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 ***********************************************************************************)

(* Conversion of TensorFlow_Type activation-function tags into Isabelle/HOL terms,
   plus machinery for defining a per-network lookup table φ_<name> (activation tag ⇀
   semantic function) and proving a characterisation of its range.

   NOTE(review): several antiquotations below appear to have lost symbols in
   transit (e.g. "@{ConstActivation_Functions...}" and double spaces inside
   @{typ ...} where "⇒" presumably stood) — verify against the repository original. *)
structure Activation_Term:ACTIVATION_TERM = struct
  (* Three representations of a layer's activation:
     Single      — scalar function  real ⇒ real
     MultiList   — vector function on real list
     MultiMatrix — vector function on real Matrix.vec *)
  datatype mode = Single | MultiList | MultiMatrix
  open TensorFlow_Type

  (* Tag ⇒ constructor term of the HOL datatype 'activationsingle'.
     Sigmoid_taylor is instantiated with Taylor degree 2 (hard-coded here).
     Softmax variants are rejected: they act on whole vectors, not scalars.
     NOTE(review): the error messages misspell "function" as "fuction". *)
  fun term_of_activation_single Linear      = @{ConstActivation_Functions.activationsingle.Identity}
    | term_of_activation_single Softsign   = @{ConstActivation_Functions.activationsingle.SoftSign}
    | term_of_activation_single Sign   = @{ConstActivation_Functions.activationsingle.Sign}
    | term_of_activation_single BinaryStep   = @{ConstActivation_Functions.activationsingle.BinaryStep}
    | term_of_activation_single Sigmoid = @{ConstActivation_Functions.activationsingle.Sigmoid}
    | term_of_activation_single Swish = @{ConstActivation_Functions.activationsingle.Swish}
    | term_of_activation_single Tanh = @{ConstActivation_Functions.activationsingle.Tanh}
    | term_of_activation_single Relu = @{ConstActivation_Functions.activationsingle.ReLU}
    | term_of_activation_single Gelu = @{ConstActivation_Functions.activationsingle.GeLUapprox}
    | term_of_activation_single GRelu = @{ConstActivation_Functions.activationsingle.GReLU}
    | term_of_activation_single Softplus = @{ConstActivation_Functions.activationsingle.Softplus}
    | term_of_activation_single Elu = @{ConstActivation_Functions.activationsingle.ELU}
    | term_of_activation_single Selu = @{ConstActivation_Functions.activationsingle.SELU}
    | term_of_activation_single Exponential = @{ConstActivation_Functions.activationsingle.Exp}
    | term_of_activation_single Hard_sigmoid = @{ConstActivation_Functions.activationsingle.HardSigmoid}
    | term_of_activation_single Sigmoid_taylor = (@{ConstActivation_Functions.activationsingle.Sigmoidtaylor}$HOLogic.mk_nat 2)
    | term_of_activation_single Softmax = error "Activation fuction 'softmax' is not a single class activation function."
    | term_of_activation_single Softmax_taylor = error "Activation fuction 'softmax_taylor' is not a single class activation function."
  
  (* Tag ⇒ constructor term of the HOL datatype 'activationmulti' (the m* variants).
     Unlike the single case, Softmax and Softmax_taylor are supported here;
     both *_taylor constructors get Taylor degree 2. *)
  fun term_of_activation_multi Linear      = @{ConstActivation_Functions.activationmulti.mIdentity}
    | term_of_activation_multi Softsign   = @{ConstActivation_Functions.activationmulti.mSoftSign}
    | term_of_activation_multi Sign   = @{ConstActivation_Functions.activationmulti.mSign}
    | term_of_activation_multi BinaryStep   = @{ConstActivation_Functions.activationmulti.mBinaryStep}
    | term_of_activation_multi Sigmoid = @{ConstActivation_Functions.activationmulti.mSigmoid}
    | term_of_activation_multi Swish = @{ConstActivation_Functions.activationmulti.mSwish}
    | term_of_activation_multi Tanh = @{ConstActivation_Functions.activationmulti.mTanh}
    | term_of_activation_multi Relu = @{ConstActivation_Functions.activationmulti.mReLU}
    | term_of_activation_multi Gelu = @{ConstActivation_Functions.activationmulti.mGeLUapprox}
    | term_of_activation_multi GRelu = @{ConstActivation_Functions.activationmulti.mGReLU}
    | term_of_activation_multi Softplus = @{ConstActivation_Functions.activationmulti.mSoftplus}
    | term_of_activation_multi Elu = @{ConstActivation_Functions.activationmulti.mELU}
    | term_of_activation_multi Selu = @{ConstActivation_Functions.activationmulti.mSELU}
    | term_of_activation_multi Exponential = @{ConstActivation_Functions.activationmulti.mExp}
    | term_of_activation_multi Hard_sigmoid = @{ConstActivation_Functions.activationmulti.mHardSigmoid}
    | term_of_activation_multi Softmax = @{ConstActivation_Functions.activationmulti.mSoftmax}
    | term_of_activation_multi Sigmoid_taylor = (@{ConstActivation_Functions.activationmulti.mSigmoidtaylor}$HOLogic.mk_nat 2)
    | term_of_activation_multi Softmax_taylor = @{ConstActivation_Functions.activationmulti.mSoftmaxtaylor}
                                                 $(HOLogic.mk_nat 2)


  (* Tag ⇒ the *semantic* function term (real ⇒ real) that the activationsingle
     constructor denotes, used on the right-hand side of the φ-table equations.
     Softmax variants are rejected as in term_of_activation_single.
     NOTE(review): same "fuction" misspelling in the error messages. *)
  fun term_of_activation_eqn_single Linear      = @{constActivation_Functions.identity(real)}
    | term_of_activation_eqn_single Softsign   = @{constActivation_Functions.softsign(real)}
    | term_of_activation_eqn_single Sign   = @{constActivation_Functions.sign(real)}
    | term_of_activation_eqn_single BinaryStep   = @{constActivation_Functions.binary_step(real)}
    | term_of_activation_eqn_single Sigmoid = @{constActivation_Functions.sigmoid}
    | term_of_activation_eqn_single Swish = @{constActivation_Functions.swish}
    | term_of_activation_eqn_single Tanh = @{constTranscendental.tanh(real)}
    | term_of_activation_eqn_single Relu = @{constActivation_Functions.relu(real)}
    | term_of_activation_eqn_single Gelu = @{constActivation_Functions.gelu_approx}
    | term_of_activation_eqn_single GRelu = @{constActivation_Functions.generalized_relu(real)}
    | term_of_activation_eqn_single Softplus = @{constActivation_Functions.softplus(real)}
    | term_of_activation_eqn_single Elu = @{constActivation_Functions.elu(real)}
    | term_of_activation_eqn_single Selu = @{constActivation_Functions.selu(real)}
    | term_of_activation_eqn_single Exponential = @{constTranscendental.exp(real)}
    | term_of_activation_eqn_single Hard_sigmoid = @{constActivation_Functions.hard_sigmoid(real)}
    | term_of_activation_eqn_single Sigmoid_taylor = @{constActivation_Functions.sigmoidtaylor(real)}
    | term_of_activation_eqn_single Softmax = error "Activation fuction 'softmax' is not a single class activation function."
    | term_of_activation_eqn_single Softmax_taylor = error "Activation fuction 'softmax_taylor' is not a single class activation function."

  (* Tag ⇒ semantic function on real lists: element-wise activations are lifted
     with 'map'; Softmax/Softmax_taylor are genuinely vector-valued and used
     directly (the *_taylor variants again with degree 2). *)
  fun term_of_activation_eqn_multi_list Linear      = (@{constmap(real,real)})$(@{constActivation_Functions.identity(real)})
    | term_of_activation_eqn_multi_list Softsign   = (@{constmap(real,real)})$(@{constActivation_Functions.softsign(real)})
    | term_of_activation_eqn_multi_list Sign   = (@{constmap(real,real)})$(@{constActivation_Functions.sign(real)})
    | term_of_activation_eqn_multi_list BinaryStep   = (@{constmap(real,real)})$(@{constActivation_Functions.binary_step(real)})
    | term_of_activation_eqn_multi_list Sigmoid = (@{constmap(real,real)})$(@{constActivation_Functions.sigmoid})
    | term_of_activation_eqn_multi_list Swish = (@{constmap(real,real)})$(@{constActivation_Functions.swish})
    | term_of_activation_eqn_multi_list Tanh = (@{constmap(real,real)})$(@{constTranscendental.tanh(real)})
    | term_of_activation_eqn_multi_list Relu = (@{constmap(real,real)})$(@{constActivation_Functions.relu(real)})
    | term_of_activation_eqn_multi_list Gelu = (@{constmap(real,real)})$(@{constActivation_Functions.gelu_approx})
    | term_of_activation_eqn_multi_list GRelu = (@{constmap(real,real)})$(@{constActivation_Functions.generalized_relu(real)})
    | term_of_activation_eqn_multi_list Softplus = (@{constmap(real,real)})$(@{constActivation_Functions.softplus(real)})
    | term_of_activation_eqn_multi_list Elu = (@{constmap(real,real)})$(@{constActivation_Functions.elu(real)})
    | term_of_activation_eqn_multi_list Selu = (@{constmap(real,real)})$(@{constActivation_Functions.selu(real)})
    | term_of_activation_eqn_multi_list Exponential = (@{constmap(real,real)})$(@{constTranscendental.exp(real)})
    | term_of_activation_eqn_multi_list Hard_sigmoid = (@{constmap(real,real)})$(@{constActivation_Functions.hard_sigmoid(real)})
    | term_of_activation_eqn_multi_list Sigmoid_taylor = (@{constmap(real,real)})$(@{constActivation_Functions.sigmoidtaylor(real)}$(HOLogic.mk_nat 2) )
    | term_of_activation_eqn_multi_list Softmax = (@{constActivation_Functions.softmax(real)})
    | term_of_activation_eqn_multi_list Softmax_taylor = (@{constActivation_Functions.softmaxtaylor(real)}$HOLogic.mk_nat 2)

  (* Same as term_of_activation_eqn_multi_list but on 'real Matrix.vec':
     element-wise lifting via 'map_vec', and the m* softmax constants. *)
  fun term_of_activation_eqn_multi_matrix Linear      = (@{constmap_vec(real,real)})$(@{constActivation_Functions.identity(real)})
    | term_of_activation_eqn_multi_matrix Softsign   = (@{constmap_vec(real,real)})$(@{constActivation_Functions.softsign(real)})
    | term_of_activation_eqn_multi_matrix Sign   = (@{constmap_vec(real,real)})$(@{constActivation_Functions.sign(real)})
    | term_of_activation_eqn_multi_matrix BinaryStep   = (@{constmap_vec(real,real)})$(@{constActivation_Functions.binary_step(real)})
    | term_of_activation_eqn_multi_matrix Sigmoid = (@{constmap_vec(real,real)})$(@{constActivation_Functions.sigmoid})
    | term_of_activation_eqn_multi_matrix Swish = (@{constmap_vec(real,real)})$(@{constActivation_Functions.swish})
    | term_of_activation_eqn_multi_matrix Tanh = (@{constmap_vec(real,real)})$(@{constTranscendental.tanh(real)})
    | term_of_activation_eqn_multi_matrix Relu = (@{constmap_vec(real,real)})$(@{constActivation_Functions.relu(real)})
    | term_of_activation_eqn_multi_matrix Gelu = (@{constmap_vec(real,real)})$(@{constActivation_Functions.gelu_approx})
    | term_of_activation_eqn_multi_matrix GRelu = (@{constmap_vec(real,real)})$(@{constActivation_Functions.generalized_relu(real)})
    | term_of_activation_eqn_multi_matrix Softplus = (@{constmap_vec(real,real)})$(@{constActivation_Functions.softplus(real)})
    | term_of_activation_eqn_multi_matrix Elu = (@{constmap_vec(real,real)})$(@{constActivation_Functions.elu(real)})
    | term_of_activation_eqn_multi_matrix Selu = (@{constmap_vec(real,real)})$(@{constActivation_Functions.selu(real)})
    | term_of_activation_eqn_multi_matrix Exponential = (@{constmap_vec(real,real)})$(@{constTranscendental.exp(real)})
    | term_of_activation_eqn_multi_matrix Hard_sigmoid = (@{constmap_vec(real,real)})$(@{constActivation_Functions.hard_sigmoid(real)})
    | term_of_activation_eqn_multi_matrix Sigmoid_taylor = (@{constmap_vec(real,real)})$(@{constActivation_Functions.sigmoidtaylor(real)}$(HOLogic.mk_nat 2) )
    | term_of_activation_eqn_multi_matrix Softmax = (@{constActivation_Functions.msoftmax(real)})
    | term_of_activation_eqn_multi_matrix Softmax_taylor = (@{constActivation_Functions.msoftmaxtaylor(real)}$HOLogic.mk_nat 2)


  (* Register a function named 'binding' with defining equations 'eqs' in the
     local theory 'lthy', via the HOL function package (Function_Fun.add_fun).
     Config: sequential pattern matching (so the catch-all equation below is
     legal), domintros generated, no partial functions. Returns the extended
     local theory. *)
  fun add_function binding eqs lthy =
    let
  val cfg =
    Function_Common.FunctionConfig { sequential=true, default=NONE,
      domintros=true, partials=false}
      (* each equation is wrapped into the (attribs, eq, ...) triple shape
         expected by Function_Fun.add_fun *)
      val eq_defs = map (fn eq => (((Binding.empty, []), eq), [], [])) eqs
      val ctx =
        Function_Fun.add_fun [(binding, NONE, NoSyn)]
          eq_defs
          cfg  lthy;
    in
       ctx      
    end

   (* Abbreviation: build the proposition "Trueprop (lhs = rhs)". *)
   val mk_Trueprop_eq = HOLogic.mk_Trueprop o HOLogic.mk_eq
   (* Remove duplicates from a list; note this keeps the LAST occurrence of
      each element (an element is dropped when it still occurs later on). *)
   fun remdups (x::xs) = if NONE <> (List.find (fn x' => x'=x) xs) then remdups xs else x::(remdups xs)
     | remdups [] = []

      (* Prove the lemma <defN>.φ_ran : "ran φ_<defN> = {f1, ..., fn}", where the
         right-hand side enumerates the semantic activation functions of 'phitab'
         (de-duplicated), at the type dictated by 'mode'. The proof unfolds
         ran_def/ranI with auto, eliminates via the generated φ_<defN>.elims,
         and closes remaining goals by meson with the φ_<defN>.simps. *)
      fun lemma_phi_ran mode defN phitab lthy = 
      let 
         (* fetch facts "<defN>.<n0>.<n1>" (e.g. the elims/simps of φ_<defN>)
            from the current local theory *)
         fun get_local_thms n0 n1 = Proof_Context.get_thms lthy (Local_Theory.full_name lthy (Binding.qualify false defN (Binding.qualify_name false (Binding.name (n0))  n1) ))
        (* simpset: HOL_basic_ss extended with the given theorems *)
        fun mk_ss ctx thms = put_simpset HOL_basic_ss ctx addsimps thms
        val phitab = remdups phitab
        (* fully-qualified name of the table constant φ_<defN> *)
        val Phi = Local_Theory.full_name lthy (Binding.qualify_name false (Binding.name defN) ("φ_"^defN))
        val phi_const = case mode of 
                        Single => Const(Phi, @{typ activationsingle  (real  real) option})
                      | MultiList => Const(Phi, @{typ activationmulti  (real list  real list) option})
                      | MultiMatrix => Const(Phi, @{typ activationmulti  (real Matrix.vec  real Matrix.vec) option})
         val ran_const = case mode of 
                        Single => @{const "ran" (activationsingle, real  real)}
                      | MultiList => @{const "ran" (activationmulti, real list  real list)}
                      | MultiMatrix => @{const "ran" (activationmulti, real Matrix.vec  real Matrix.vec)}
         val lhs = ran_const$phi_const 
         (* rhs: the finite set of semantic functions actually in the table *)
         val rhs =  case mode of 
                           Single => HOLogic.mk_set @{typ "real  real"} (map term_of_activation_eqn_single phitab)
                         | MultiList  => HOLogic.mk_set @{typ "real list  real list"} (map term_of_activation_eqn_multi_list phitab)
                         | MultiMatrix  => HOLogic.mk_set @{typ "real Matrix.vec real Matrix.vec"} (map term_of_activation_eqn_multi_matrix phitab)
      in
         nn_tactics.prove_simple (Binding.qualify_name false (Binding.name defN) "φ_ran")
                      (mk_Trueprop_eq (lhs,rhs))
    
                    (fn ctx =>  auto_tac (mk_ss ctx [@{thm "ran_def"},@{thm "ranI"}]) 
                           THEN eresolve_tac ctx (get_local_thms ("φ_"^defN) "elims") 1  
                           THEN auto_tac (mk_ss ctx [@{thm "ran_def"},@{thm "ranI"}])
                           THEN ALLGOALS (Meson.meson_tac ctx ([@{thm "bot.extremum"}, @{thm "insert_subsetI"}, @{thm "ranI"}]@(get_local_thms ("φ_"^defN) "simps")) )
                     )  lthy 
(*  
                   (fn s => Skip_Proof.cheat_tac lthy 1) lthy 
 *)     end 

  (* Define the lookup table φ_<defN> for the activation tags in 'phitab'
     (duplicates removed): one equation per tag mapping its constructor term to
     SOME of its semantic function, plus a final catch-all equation to NONE
     (legal because add_function uses sequential matching). The definition and
     the subsequent φ_ran lemma are each wrapped in a nested local-theory
     block. Returns the extended local theory. *)
  fun def_phi_tab mode defN phitab lthy = 
      let
        val phitab = remdups phitab
        val Phi = "φ_"^defN 
        (* free variable standing for the yet-undefined table constant *)
        fun phi_tab_term Phi mode = case mode of 
                        Single => Free(Phi, @{typ activationsingle  (real  real) option})
                      | MultiList => Free(Phi, @{typ activationmulti  (real list  real list) option})
                      | MultiMatrix => Free(Phi, @{typ activationmulti  (real Matrix.vec  real Matrix.vec) option})

        (* one defining equation: φ_<defN> <constructor> = Some <semantic fn> *)
        fun mk_term phi = 
            let 
              val lhs =  case mode of 
                           Single => (phi_tab_term Phi mode)$(term_of_activation_single phi)
                         | MultiList => (phi_tab_term Phi mode)$(term_of_activation_multi phi)
                         | MultiMatrix => (phi_tab_term Phi mode)$(term_of_activation_multi phi)
              val rhs =  case mode of 
                           Single => @{constSome(real real)}$(term_of_activation_eqn_single phi)
                         | MultiList  => @{constSome(real list real list)}$(term_of_activation_eqn_multi_list phi)
                         | MultiMatrix  => @{constSome(real Matrix.vec real Matrix.vec)}$(term_of_activation_eqn_multi_matrix phi)
            in
               mk_Trueprop_eq (lhs,rhs)
            end 
        (* catch-all: φ_<defN> x = None for every tag not listed above *)
        val catch_all_lhs = case mode of 
                              Single =>  (Free(Phi, @{typ activationsingle  (real  real) option}))$(@{term (x::activationsingle)})
                            | MultiList =>   (Free(Phi, @{typ activationmulti  (real list  real list) option}))$(@{term (x::activationmulti)})
                            | MultiMatrix =>   (Free(Phi, @{typ activationmulti  (real Matrix.vec  real Matrix.vec) option}))$(@{term (x::activationmulti)})
        val catch_all_rhs = case mode of 
                              Single => @{term "None::(real  real) option"}
                            | MultiList   =>  @{term "None::(real list  real list) option"}
                            | MultiMatrix =>  @{term "None::(real Matrix.vec  real Matrix.vec) option"}
                                      
                                      
        val eqs =  (map mk_term phitab)@[mk_Trueprop_eq(catch_all_lhs, catch_all_rhs)]
      in
        (snd o Local_Theory.begin_nested) lthy
        |> add_function (Binding.qualify_name true (Binding.name defN) ("φ_"^defN)) eqs 
        |> Local_Theory.end_nested
        |> (snd o Local_Theory.begin_nested)
        |> lemma_phi_ran mode defN phitab 
        |> Local_Theory.end_nested

      end 


end