From 01f521fa2ce8a0ffd3d503649c57dae64f120cf8 Mon Sep 17 00:00:00 2001 From: mkaratarakis Date: Mon, 28 Jul 2025 13:45:58 +0200 Subject: [PATCH 01/15] port --- .../SpinGlasses/HopfieldNetwork/Asym.lean | 442 ++++++++ .../BoltzmannMachine/Core.lean | 203 ++++ .../BoltzmannMachine/Markov.lean | 173 ++++ .../SpinGlasses/HopfieldNetwork/Core.lean | 936 +++++++++++++++++ .../SpinGlasses/HopfieldNetwork/Markov.lean | 316 ++++++ .../HopfieldNetwork/NNStochastic.lean | 16 + .../HopfieldNetwork/NeuralNetwork.lean | 115 +++ .../HopfieldNetwork/Stochastic.lean | 940 ++++++++++++++++++ .../HopfieldNetwork/StochasticAux.lean | 471 +++++++++ .../SpinGlasses/HopfieldNetwork/aux.lean | 282 ++++++ .../SpinGlasses/HopfieldNetwork/test.lean | 260 +++++ 11 files changed, 4154 insertions(+) create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Asym.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Core.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Markov.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Markov.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NNStochastic.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/aux.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/test.lean diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Asym.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Asym.lean new file mode 100644 index 
000000000..4f71fbd4c --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Asym.lean @@ -0,0 +1,442 @@ +/- +Copyright (c) 2025 Matteo Cipollina. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Matteo Cipollina +-/ + +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core +import Mathlib.LinearAlgebra.Matrix.PosDef + +/-! +# Asymmetric Hopfield Networks + +This module provides an implementation of asymmetric Hopfield networks, which are neural networks +with potentially non-symmetric weight matrices. Unlike standard Hopfield networks that require +symmetric weights, asymmetric networks can converge under certain conditions. + +## Main Components + +* `AsymmetricHopfieldNetwork`: A neural network structure with asymmetric weights +* `updateStateAsym`: A function for updating states in the asymmetric case +* `localFieldAsym`: A function computing local fields in asymmetric networks +* `energyAsym`: A modified energy function for asymmetric networks + +## Mathematical Details + +The implementation is based on the decomposition of asymmetric weight matrices into: +* An antisymmetric component A (where A_{ij} = -A_{ji}) +* A positive definite symmetric component S +* A non-negative diagonal constraint + +This decomposition allows us to analyze the dynamics and convergence properties +of Hopfield networks with asymmetric weights. +-/ + +open Finset Matrix NeuralNetwork State + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [StarRing R] [DecidableEq U] [Fintype U] [Nonempty U] + +/-- A matrix `A : Matrix n n α` is "antisymmetric" if `Aᵀ = -A`. 
-/ +def Matrix.IsAntisymm [Neg α] (A : Matrix n n α) : Prop := Aᵀ = -A + +@[simp] +lemma Matrix.IsAntisymm.ext_iff [Neg α] {A : Matrix n n α} : + A.IsAntisymm ↔ ∀ i j, A j i = -A i j := by + simp [Matrix.IsAntisymm, Matrix.ext_iff] + exact Eq.congr_right rfl + +/-- +`AsymmetricHopfieldNetwork` defines a Hopfield network with asymmetric weights. +Unlike standard Hopfield networks where weights must be symmetric, +this variant allows for asymmetric weights that can be decomposed into +an antisymmetric part and a positive definite symmetric part. + +The network has: +- Fully connected neurons (each neuron is connected to all others) +- No self-connections (zero diagonal in weight matrix) +- Asymmetric weights (w_ij can differ from w_ji) +- Weights that can be decomposed into antisymmetric and positive definite components +-/ +abbrev AsymmetricHopfieldNetwork (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] + [Nonempty U] [Fintype U] [StarRing R] : NeuralNetwork R U where + /- The adjacency relation between neurons `u` and `v`, defined as `u ≠ v`. -/ + Adj u v := u ≠ v + /- The set of input neurons, defined as the universal set. -/ + Ui := Set.univ + /- The set of output neurons, defined as the universal set. -/ + Uo := Set.univ + /- A proof that the intersection of the input and output sets is empty. -/ + hhio := Set.empty_inter (Set.univ ∪ Set.univ) + /- The set of hidden neurons, defined as the empty set. -/ + Uh := ∅ + /- A proof that all neurons are in the universal set. -/ + hU := by simp only [Set.mem_univ, Set.union_self, Set.union_empty] + /- A proof that the input set is not equal to the empty set. -/ + hUi := Ne.symm Set.empty_ne_univ + /- A proof that the output set is not equal to the empty set. 
-/ + hUo := Ne.symm Set.empty_ne_univ + /- The weights can be decomposed into antisymmetric and positive definite parts -/ + pw := fun w => + ∃ (A S : Matrix U U R), + A.IsAntisymm ∧ -- A is antisymmetric + Matrix.PosDef S ∧ -- S is positive definite + w = A + S ∧ -- Decomposition of the weight matrix + (∀ i, w i i ≥ 0) -- Non-negative diagonal + /- κ₁ is 0 for every neuron. -/ + κ1 _ := 0 + /- κ₂ is 1 for every neuron. -/ + κ2 _ := 1 + /- The network function for neuron `u`, given weights `w` and predecessor states `pred`. -/ + fnet u w pred _ := HNfnet u w pred + /- The activation function for neuron `u`, given input and threshold `θ`. -/ + fact u (_current_act_val : R) (net_input_val : R) (θ_vec : Vector R 1) := HNfact (θ_vec.get 0) net_input_val /- The output function, given the activation state `act`. -/ + fout _ act := HNfout act + /- A predicate that the activation state `act` is either 1 or -1. -/ + pact act := act = 1 ∨ act = -1 + /- A proof that the activation state of neuron `u` + is determined by the threshold `θ` and the network function. -/ + hpact w _ _ _ θ act _ u := + ite_eq_or_eq ((θ u).get 0 ≤ HNfnet u (w u) fun v => HNfout (act v)) 1 (-1) + +/-- +Extracts the antisymmetric and symmetric components from parameters that +satisfy the asymmetric condition. + +Parameters: +- wθ: The Hopfield network parameters +- hw': Proof that the parameters satisfy the asymmetric condition + +Returns: +- A pair (A, S) where A is antisymmetric, S is positive definite, and w = A + S +-/ +noncomputable def getAsymmetricDecomposition (wθ : Params (AsymmetricHopfieldNetwork R U)) + (_ : ∃ (A S : Matrix U U R), A.IsAntisymm ∧ Matrix.PosDef S ∧ wθ.w = A + S ∧ (∀ i, wθ.w i i ≥ 0)) : + Matrix U U R × Matrix U U R := + let w := wθ.w + let one_half : R := (1 : R) / (2 : R) + let A_matrix := one_half • (w - wᵀ) + let S_matrix := one_half • (w + wᵀ) + (A_matrix, S_matrix) + +/-- +Computes the local field for a neuron in an asymmetric Hopfield network. 
+ +Parameters: +- wθ: The parameters for the asymmetric Hopfield network +- s: The current state of the network +- i: The index of the neuron + +Returns: +- The local field (weighted sum of inputs minus threshold) +-/ +def localFieldAsym [Nonempty U] (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (i : U) : R := + (∑ j ∈ Finset.univ, wθ.w i j * s.act j) - (wθ.θ i).get 0 + +/-- +Updates a neuron state according to asymmetric update rules. + +Parameters: +- wθ: The parameters for the asymmetric Hopfield network +- s: The current state of the network +- i: The index of the neuron to update + +Returns: +- A new state with the selected neuron updated +-/ +def updateStateAsym (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (i : U) : State (AsymmetricHopfieldNetwork R U) := + { act := fun j => if j = i then + let lf := localFieldAsym wθ s i + if lf ≤ 0 then -1 else 1 + else s.act j, + hp := fun j => by + by_cases h : j = i + · subst h + by_cases h_lf : localFieldAsym wθ s j ≤ 0 + · simp [h_lf] + · simp [h_lf] + · simp [h] + exact s.hp j + } + +/-- +Returns the updated activation value of a neuron after applying the asymmetric update rule. + +Parameters: +- wθ: The parameters for the asymmetric Hopfield network +- s: The current state of the network +- i: The index of the neuron + +Returns: +- The activation value after update +-/ +def updatedActValue (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (i : U) : R := + (updateStateAsym wθ s i).act i + +/-- +Creates a sequence of state updates for an asymmetric Hopfield network. 
+ +Parameters: +- wθ: The parameters for the asymmetric Hopfield network +- s: The initial state +- useq: A sequence of neurons to update + +Returns: +- A function mapping each time step to the corresponding state +-/ +def seqStatesAsym (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (useq : ℕ → U) : ℕ → State (AsymmetricHopfieldNetwork R U) + | 0 => s + | n+1 => updateStateAsym wθ (seqStatesAsym wθ s useq n) (useq n) + +/-- +Checks if a state is stable in an asymmetric Hopfield network. +A state is stable if no single-neuron update changes the state. + +Parameters: +- wθ: The parameters for the asymmetric Hopfield network +- s: The state to check + +Returns: +- True if the state is stable, false otherwise +-/ +def isStableAsym (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) : Prop := + ∀ i, updateStateAsym wθ s i = s + +/-- +Get the energy of an asymmetric Hopfield network for a specific state. +This is a modified energy function designed for asymmetric networks. + +Parameters: +- wθ: The parameters for the asymmetric Hopfield network +- s: The current state of the network + +Returns: +- The energy value of the network +-/ +def energyAsym (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) : R := + let w := wθ.w; + let θ := fun i => (wθ.θ i).get 0; + -- For asymmetric networks, we use a modified energy function + -1/2 * ∑ i ∈ Finset.univ, ∑ j ∈ Finset.univ, w i j * s.act i * s.act j - + ∑ i ∈ Finset.univ, θ i * s.act i + +/-- +Potential function for asymmetric Hopfield networks at time step k. +This function is used to analyze the dynamics of the network during updates. 
+ +Parameters: +- wθ: The parameters for the asymmetric Hopfield network +- s: The current state of the network +- k: The current time step +- useq: A sequence of neurons to update + +Returns: +- The potential value of the network +-/ +def potentialFunction (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (k : ℕ) (useq : ℕ → U) : R := + ∑ i ∈ Finset.univ, + ∑ j ∈ Finset.univ, + wθ.w i j * + (if i = useq (k % Fintype.card U) then + (if localFieldAsym wθ s i > 0 then 1 else -1) + else s.act i) * + s.act j + +-- Lemma: potentialFunction is bounded +@[simp] +lemma potential_function_bounded (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (k : ℕ) (useq : ℕ → U) : + ∃ (lowerBound upperBound : R), lowerBound ≤ potentialFunction wθ s k useq ∧ potentialFunction wθ s k useq ≤ upperBound := by + let maxSum : R := ∑ i ∈ Finset.univ, ∑ j ∈ Finset.univ, |wθ.w i j| + use -maxSum, maxSum + constructor + · -- Show that for each term, the product is at least -|weights_ij| + unfold potentialFunction + + have hbound : ∀ (i j : U), + wθ.w i j * + (if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i) * + s.act j ≥ -|wθ.w i j| := by + intro i j + have h1 : |(if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i)| = 1 := by + split_ifs with h h_field + · simp + · simp + · cases s.hp i with + | inl h_one => simp [h_one] + | inr h_neg_one => simp [h_neg_one] + have h2 : |s.act j| = 1 := by + cases s.hp j with + | inl h_one => simp [h_one] + | inr h_neg_one => simp [h_neg_one] + + calc + wθ.w i j * (if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i) * s.act j + ≥ -|wθ.w i j * (if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i) * s.act j| := neg_abs_le _ + _ = -|wθ.w i j * (if i = useq (k % Fintype.card U) then (if 
localFieldAsym wθ s i > 0 then 1 else -1) else s.act i) * s.act j| := by + simp only [gt_iff_lt, mul_ite, mul_one, mul_neg, ite_mul, neg_mul] + _ = -|wθ.w i j| * |(if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i)| * |s.act j| := by + rw [abs_mul, abs_mul]; simp_all only [gt_iff_lt, mul_one] + _ = -|wθ.w i j| * 1 * 1 := by rw [h1, h2] + _ = -|wθ.w i j| := by ring_nf + + -- Use the bound to establish the inequality with the sum + calc + potentialFunction wθ s k useq + = ∑ i ∈ Finset.univ, ∑ j ∈ Finset.univ, + wθ.w i j * + (if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i) * + s.act j := by rfl + _ ≥ ∑ i ∈ Finset.univ, ∑ j ∈ Finset.univ, -|wθ.w i j| := by + apply Finset.sum_le_sum + intro i _hi + apply Finset.sum_le_sum + intro j _hj + exact hbound i j + _ = -(∑ i ∈ Finset.univ, ∑ j ∈ Finset.univ, |wθ.w i j|) := by + simp [Finset.sum_neg_distrib] + _ = -maxSum := by rfl + + · apply Finset.sum_le_sum + intro i _ + apply Finset.sum_le_sum + intro j _ + have h1 : |(if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i)| = 1 := by + split_ifs with h h_field + · exact abs_one + · simp_all only [Finset.mem_univ, gt_iff_lt, not_lt, abs_neg, abs_one] + cases s.hp i with + | inl h_one => simp [h_one] + | inr h_neg_one => simp [h_neg_one] + have h2: |s.act j| = 1 := by + cases s.hp j with + | inl h_one => simp [h_one] + | inr h_neg_one => simp [h_neg_one] + calc + wθ.w i j * (if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i) * s.act j + ≤ |wθ.w i j * (if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i) * s.act j| := le_abs_self _ + _ = |wθ.w i j| * |(if i = useq (k % Fintype.card U) then (if localFieldAsym wθ s i > 0 then 1 else -1) else s.act i)| * |s.act j| := by + rw [abs_mul, abs_mul] + _ = |wθ.w i j| * 1 * 1 := by rw [h1, h2] + _ = |wθ.w i j| 
:= by ring + +-- Lemma for updatedActValue in terms of localFieldAsym +@[simp] +lemma updatedActValue_eq (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (i : U) : + updatedActValue wθ s i = if localFieldAsym wθ s i > 0 then 1 else -1 := by + unfold updatedActValue updateStateAsym + simp only + by_cases h : localFieldAsym wθ s i > 0 + · -- case: localFieldAsym wθ s i > 0 + simp [h] + · -- case: localFieldAsym wθ s i ≤ 0 + simp [h] + +-- Helper Lemma: localFieldAsym after an update +@[simp] +lemma localFieldAsym_update (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (i : U) : + localFieldAsym wθ (updateStateAsym wθ s i) i = + (∑ j ∈ Finset.univ, wθ.w i j * (if j = i then (updatedActValue wθ s i) else s.act j)) - (wθ.θ i).get 0 := by + unfold localFieldAsym + -- The key is to understand how updateStateAsym affects the sum + have h_update : ∀ j : U, (updateStateAsym wθ s i).act j = + if j = i then updatedActValue wθ s i else s.act j := by + intro j + unfold updatedActValue updateStateAsym + by_cases h_j_eq_i : j = i + · subst h_j_eq_i + simp + · simp [h_j_eq_i] + rw [Finset.sum_congr rfl] + intro j _ + rw [h_update] + +-- Helper Lemma: Expressing s'_j in terms of s_j and updatedActValue +@[simp] +lemma s'_eq_s_update (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (i j : U) : + let s' := updateStateAsym wθ s i + s'.act j = (if j = i then (updatedActValue wθ s i) else s.act j) + := by + by_cases h_j_eq_i : j = i + · subst h_j_eq_i + unfold updatedActValue updateStateAsym + simp + · unfold updateStateAsym + simp [h_j_eq_i] + +/-- +Helper function to create parameters for an asymmetric Hopfield network. 
+ +Parameters: +- w: The weight matrix +- θ: The threshold function +- hw: A proof that weights respect the adjacency relation +- hasym: A proof that the weights can be decomposed appropriately + +Returns: +- Parameters for an asymmetric Hopfield network +-/ +def mkAsymmetricParams (w : Matrix U U R) (θ : U → Vector R 1) [Nonempty U] + (hw : ∀ u v, ¬ (AsymmetricHopfieldNetwork R U).Adj u v → w u v = 0) + (hasym : ∃ (A S : Matrix U U R), + A.IsAntisymm ∧ -- A is antisymmetric + (Matrix.PosDef S) ∧ -- S is positive definite + w = A + S ∧ -- Decomposition of the weight matrix + (∀ i, w i i ≥ 0)) : Params (AsymmetricHopfieldNetwork R U) where + w := w + hw := hw + hw' := hasym -- The asymmetric decomposition property + θ := θ + σ := fun _ => Vector.emptyWithCapacity 0 + +/-- +Attempts to find a stable state of an asymmetric Hopfield network by running a fixed number of updates. + +Unlike symmetric Hopfield networks, asymmetric networks do NOT generally guarantee convergence to fixed points. +Convergence properties depend on the specific structure of the weight matrix decomposition: + +1. When the symmetric positive definite component (S) dominates the antisymmetric component (A), + the network is more likely to converge to fixed points +2. When the antisymmetric component dominates, the network may exhibit limit cycles or chaotic behavior +3. For balanced cases, behavior depends on initial conditions and update sequence + +This function runs for N iterations and returns the resulting state, which may or may not be stable. 
+ +Parameters: +- wθ: The parameters for the asymmetric Hopfield network +- s: The initial state +- useq: A sequence of neurons to update +- hf: A proof that the update sequence is fair +- N: The maximum number of iterations (defaults to 10 times the network size - + for different AHNs, different N values might be needed,) + +Returns: +- The state after N iterations, with no guarantee of stability +-/ +def iterateAsym (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) [Nonempty U] (useq : ℕ → U) (_hf : fair useq) + (N : ℕ := Fintype.card U * 10) : State (AsymmetricHopfieldNetwork R U) := + seqStatesAsym wθ s useq N + +/-- +Checks whether the network has stabilized after N iterations by comparing states +across multiple time points. +-/ +def verifyStabilityAsym (wθ : Params (AsymmetricHopfieldNetwork R U)) + (s : State (AsymmetricHopfieldNetwork R U)) (useq : ℕ → U) (N : ℕ) : Bool := + -- Get the state after N iterations + let stateN := seqStatesAsym wθ s useq N + -- Get the state after N+1 iterations + let stateN1 := seqStatesAsym wθ s useq (N+1) + -- Check if all neurons have the same activation in both states + decide (∀ i : U, stateN.act i = stateN1.act i) diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Core.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Core.lean new file mode 100644 index 000000000..00b5207ba --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Core.lean @@ -0,0 +1,203 @@ +/- +Copyright (c) 2025 Matteo Cipollina. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Matteo Cipollina +-/ + +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core +import Mathlib.Algebra.Lie.OfAssociative +import Mathlib.Data.Real.StarOrdered +import Mathlib.Order.CompletePartialOrder +import Mathlib.Probability.Distributions.Uniform + +/-! 
+# Boltzmann Machine Neural Network + +This file defines Boltzmann Machines (BMs), a type of stochastic +recurrent neural network with symmetric connectivity between units and no self-connections. +This module extends the deterministic Hopfield network with stochastic dynamics: a fully‐connected, +symmetric weight network without self‐loops and binary activations `{1, ‐1}`. +We implement the Boltzmann Machine inside our `NeuralNetwork` framework, providing: + +• `BoltzmannMachine` : the network instance +• `ParamsBM`, `StateBM` : parameter and state types +• `energy` / `localField` / `probNeuronIsOne` : key statistics +• `gibbsUpdateSingleNeuron` / `gibbsSamplingStep` : Gibbs sampler + +## Mathematics + +Boltzmann Machines have binary neurons (±1) with probability of activation determined by: +- Energy function: $E(s) = -\frac{1}{2}\sum_{u,v, u \neq v} w_{u,v}s_u s_v - \sum_u \theta_u s_u$ +- Probability distribution: $P(s) \propto \exp(-E(s)/T)$ where $T$ is the temperature parameter +- Local field for neuron $u$: $L_u(s) = \sum_{v \neq u} w_{u,v}s_v + \theta_u$ +- Probability of neuron $u$ being 1: $P(s_u = 1) = \frac{1}{1 + \exp(-2L_u(s)/T)}$ + +The network uses Gibbs sampling to generate samples from the underlying probability distribution. +-/ +open Finset Matrix NeuralNetwork State ENNReal Real PMF + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] + +omit [IsStrictOrderedRing R] in +lemma BM_pact_of_HNfact (θ val : R) : + (fun act : R => act = 1 ∨ act = -1) (HNfact θ val) := by + unfold HNfact + split_ifs + · left; rfl + · right; rfl + +variable [Coe R ℝ] + + +/-- +`BoltzmannMachine` defines a Boltzmann Machine neural network. +It extends `HopfieldNetwork` with specific properties: +- Neurons are fully connected (except to themselves). +- All neurons are both input and output neurons. +- Weights are symmetric, and self-weights are zero. +- Activation is binary (1 or -1). 
+-/ +abbrev BoltzmannMachine (R U : Type) [Field R] [LinearOrder R] + [IsStrictOrderedRing R] [DecidableEq U] [Nonempty U] [Fintype U] : NeuralNetwork R U := +{ (HopfieldNetwork R U) with + Adj := fun u v => u ≠ v, + Ui := Set.univ, Uo := Set.univ, Uh := ∅, + hU := by simp only [Set.univ, Set.union_self, Set.union_empty], + hUi := Ne.symm Set.empty_ne_univ, hUo := Ne.symm Set.empty_ne_univ, + hhio := Set.empty_inter (Set.univ ∪ Set.univ), + pw := fun w => w.IsSymm ∧ ∀ u, w u u = 0, + κ1 := fun _ => 0, κ2 := fun _ => 1, + fnet := fun u w_u pred _ => HNfnet u w_u pred, + fact := fun u (_current_act_val : R) (net_input_val : R) (θ_vec : Vector R 1) => + HNfact (θ_vec.get 0) net_input_val, fout := fun _ act => act, + pact := fun act => act = (1 : R) ∨ act = (-1 : R), -- This is the pact for BoltzmannMachine + hpact := fun _ _ _ _ _ _ _ _ => BM_pact_of_HNfact _ _ +} + +/-- +Parameters for a Boltzmann Machine. +-/ +structure ParamsBM (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] where + /-- The weight matrix of the Boltzmann Machine. -/ + w : Matrix U U R + /-- Proof that the weight matrix satisfies the properties required by `BoltzmannMachine.pw`. -/ + hpw : (BoltzmannMachine R U).pw w + /-- A function mapping each neuron to its threshold value. -/ + θ : U → R + /-- The temperature parameter of the Boltzmann Machine. -/ + T : R + /-- Proof that the temperature `T` is positive. -/ + hT_pos : T > 0 + +/-- +`StateBM` is an alias for the state of a `BoltzmannMachine`. +It represents the activation values of all neurons in the network. +-/ +abbrev StateBM (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] := + (BoltzmannMachine R U).State + +namespace BoltzmannMachine + +/-- +Calculates the energy of a state `s` in a Boltzmann Machine. 
+ +The energy is defined as: +$$E(s) = -\frac{1}{2}\sum_{u,v, u \neq v} w_{u,v}s_u s_v - \sum_u \theta_u s_u$$ +where $w$ is the weight matrix, $s$ is the state (activations), and $\theta$ are the thresholds. +-/ +def energy (p : ParamsBM R U) (s : StateBM R U) : R := + - (1 / (2 : R)) * (∑ u, ∑ v ∈ Finset.univ.filter (fun v' => v' ≠ u), p.w u v * s.act u * s.act v) + - (∑ u, p.θ u * s.act u) + +/-- +Calculates the local field for a neuron `u` in a Boltzmann Machine. + +The local field is the sum of weighted inputs from other neurons plus the neuron's own threshold: +$$L_u(s) = \sum_{v \neq u} w_{u,v}s_v + \theta_u$$ +where $w$ is the weight matrix, $s$ is the state (activations), +and $\theta_u$ is the threshold for neuron $u$. +-/ +def localField (p : ParamsBM R U) (s : StateBM R U) (u : U) : R := + (∑ v ∈ Finset.univ.filter (fun v' => v' ≠ u), p.w u v * s.act v) + p.θ u + +/-- +Calculates the probability of neuron `u` being in state `1` (activated). + +The probability is given by the sigmoid function of the local field: +$$P(s_u = 1) = \frac{1}{1 + \exp(-2L_u(s)/T)}$$ +where $L_u(s)$ is the local field for neuron $u$ and $T$ is the temperature. +-/ +noncomputable def probNeuronIsOne (p : ParamsBM R U) + (s : StateBM R U) (u : U) : ℝ := + let L_u : ℝ := ↑(localField p s u) + let T_R : ℝ := ↑(p.T) + (1 : ℝ) / (1 + Real.exp (- (2 * L_u / T_R))) + +/-! +## Probability of a Neuron Being One + +This section defines properties related to the probability of a specific neuron being in state '1' +in a Boltzmann Machine. +-/ + +/-- The probability of a neuron `u` being in state '1' is always non-negative. -/ +lemma probNeuronIsOne_nonneg (p : ParamsBM R U) + (s : StateBM R U) (u : U) : probNeuronIsOne p s u ≥ 0 := by + simp only [probNeuronIsOne, div_nonneg_iff] + left + constructor + · norm_num + · positivity + +/-- The probability of a neuron `u` being in state '1' is always less than or equal to 1. 
-/ +lemma probNeuronIsOne_le_one (p : ParamsBM R U) + (s : StateBM R U) (u : U) : probNeuronIsOne p s u ≤ 1 := by + unfold probNeuronIsOne + apply div_le_one_of_le₀ + · have h : 0 < Real.exp (-(2 * ↑(localField p s u) / ↑p.T)) := Real.exp_pos _ + linarith + · have h1 : 0 < Real.exp (-(2 * ↑(localField p s u) / ↑p.T)) := Real.exp_pos _ + linarith + +/-- +Updates a single neuron `u` in state `s` according to the Gibbs distribution. + +The neuron's new state (1 or -1) is chosen probabilistically based on `probNeuronIsOne`. +Returns a probability mass function over the possible next states. +-/ +noncomputable def gibbsUpdateSingleNeuron (p : ParamsBM R U) + (s : StateBM R U) (u : U) : PMF (StateBM R U) := + let prob_one_R : ℝ := probNeuronIsOne p s u + let prob_one_ennreal := ENNReal.ofReal prob_one_R + have h_prob_ennreal_le_one : prob_one_ennreal ≤ 1 := + ENNReal.ofReal_le_one.mpr (probNeuronIsOne_le_one p s u) + PMF.bernoulli prob_one_ennreal h_prob_ennreal_le_one >>= fun takes_value_one => + let new_val : R := if takes_value_one then (1 : R) else (-1 : R) + PMF.pure + { act := fun v => if v = u then new_val else s.act v + , hp := fun v => by + by_cases hv : v = u + · subst hv + dsimp [new_val] + split_ifs with h + · exact Or.inl rfl + · exact Or.inr rfl + exact False.elim (h rfl) + · dsimp only; simp [if_neg hv]; exact s.hp v } + +/-- +Performs one step of Gibbs sampling. + +A neuron `u` is chosen uniformly at random from all neurons, +and then its state is updated according to `gibbsUpdateSingleNeuron`. +Returns a probability mass function over the possible next states. 
+-/ +noncomputable def gibbsSamplingStep (p : ParamsBM R U) + (s : StateBM R U) : PMF (StateBM R U) := + let neuron_pmf : PMF U := PMF.uniformOfFintype U + neuron_pmf >>= fun u => gibbsUpdateSingleNeuron p s u + +end BoltzmannMachine diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Markov.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Markov.lean new file mode 100644 index 000000000..90361556e --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Markov.lean @@ -0,0 +1,173 @@ +/- +Copyright (c) 2025 Matteo Cipollina. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Matteo Cipollina +-/ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Core +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov + +open Finset Matrix NeuralNetwork State ENNReal Real +open PMF MeasureTheory ProbabilityTheory.Kernel Set + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] + [Fintype U] [Nonempty U] [Coe R ℝ] + +noncomputable instance : Fintype ((BoltzmannMachine R U).State) := by + -- States are functions from U to {-1, 1} with a predicate + let binaryType := {x : R | x = -1 ∨ x = 1} + have binaryFintype : Fintype binaryType := by + apply Fintype.ofList [⟨-1, Or.inl rfl⟩, ⟨1, Or.inr rfl⟩] + intro ⟨x, hx⟩ + simp only [List.mem_singleton, List.mem_cons] + cases hx with + | inl h => + left + apply Subtype.ext + exact h + | inr h => + right + left + apply Subtype.ext + exact h + let f : ((BoltzmannMachine R U).State) → (U → binaryType) := fun s u => + ⟨s.act u, by + unfold binaryType + have h := s.hp u + cases h with + | inl h_pos => right; exact h_pos + | inr h_neg => left; exact h_neg⟩ + have f_inj : Function.Injective f := by + intro s1 s2 h_eq + apply @NeuralNetwork.ext + intro u + have h := congr_fun h_eq u + have 
hval : (f s1 u).val = (f s2 u).val := congr_arg Subtype.val h + exact hval + exact Fintype.ofInjective f f_inj + +noncomputable instance : Fintype (StateBM R U) := by + dsimp [StateBM] + exact inferInstance + +namespace BoltzmannMachine + +instance (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] : + MeasurableSpace ((BoltzmannMachine R U).State) := ⊤ + +instance : MeasurableSpace (StateBM R U) := inferInstance + +instance : DiscreteMeasurableSpace (StateBM R U) := + show DiscreteMeasurableSpace ((BoltzmannMachine R U).State) from inferInstance + +/-- +The Gibbs transition kernel for a Boltzmann Machine. +Given a current state `s`, it returns a probability measure over the next possible states, +determined by the `BoltzmannMachine.gibbsSamplingStep` function. +-/ +noncomputable def gibbsTransitionKernelBM (p : ParamsBM R U) : + ProbabilityTheory.Kernel (StateBM R U) (StateBM R U) where + toFun := fun state => (BoltzmannMachine.gibbsSamplingStep p state).toMeasure + -- For discrete state spaces, any function to measures is measurable + measurable' := Measurable.of_discrete + +/-- The Gibbs transition kernel for a Boltzmann Machine is a Markov Kernel +(preserves probability)-/ +instance isMarkovKernel_gibbsTransitionKernelBM (p : ParamsBM R U) : + ProbabilityTheory.IsMarkovKernel (gibbsTransitionKernelBM p) where + isProbabilityMeasure := by + intro s + simp only [gibbsTransitionKernelBM] + exact PMF.toMeasure.isProbabilityMeasure (BoltzmannMachine.gibbsSamplingStep p s) + +/-- +The unnormalized Boltzmann density function for a Boltzmann Machine state. +$ρ(s) = e^{-E(s)/T}$. 
+-/ +noncomputable def boltzmannDensityFnBM (p : ParamsBM R U) (s : StateBM R U) : ENNReal := + -- Break this into steps to help type inference + let energy_val : R := BoltzmannMachine.energy p s + let energy_real : ℝ := (↑energy_val : ℝ) + let temp_real : ℝ := (↑(p.T) : ℝ) + ENNReal.ofReal (Real.exp (-energy_real / temp_real)) + +/-- +The partition function (normalizing constant) for the Boltzmann Machine. +$Z = \sum_s e^{-E(s)/T}$. +-/ +noncomputable def partitionFunctionBM (p : ParamsBM R U) : ENNReal := + ∑ s : StateBM R U, boltzmannDensityFnBM p s -- Sum over all states in the Fintype + +/-- +The partition function is positive and finite, provided T > 0. +-/ +lemma partitionFunctionBM_pos_finite (p : ParamsBM R U) + [Nonempty (StateBM R U)] : + 0 < partitionFunctionBM p ∧ partitionFunctionBM p < ⊤ := by + constructor + · -- Proof of 0 < Z + apply ENNReal.sum_pos + · exact Finset.univ_nonempty + · intro s _hs + unfold boltzmannDensityFnBM + exact ENNReal.ofReal_pos.mpr (Real.exp_pos _) + · -- Proof of Z < ⊤ + unfold partitionFunctionBM + rw [sum_lt_top] + intro s _hs + unfold boltzmannDensityFnBM + exact ENNReal.ofReal_lt_top + +/-- +The Boltzmann distribution for a Boltzmann Machine. +This is the stationary distribution for the Gibbs sampling process. +$\pi(s) = \frac{1}{Z} e^{-E(s)/T}$. +Defined as a measure with density `boltzmannDensityFnBM / partitionFunctionBM` +with respect to the counting measure on the finite state space. 
-/
noncomputable def boltzmannDistributionBM (p : ParamsBM R U)
    [Nonempty (StateBM R U)] :
    Measure (StateBM R U) :=
  let density := fun s => boltzmannDensityFnBM p s / partitionFunctionBM p
  let Z_pos_finite := partitionFunctionBM_pos_finite p
  -- NOTE(review): both `if` branches below are unreachable — `Z_pos_finite` proves
  -- `Z ≠ 0` and `Z ≠ ⊤` — so this definition always reduces to the final
  -- `withDensity` case. Prefer `boltzmannDistributionBM'` below.
  if hZ_zero : partitionFunctionBM p = 0 then by {
    -- This case should not happen due to Z_pos_finite.1
    exfalso; exact Z_pos_finite.1.ne' hZ_zero
  } else if hZ_top : partitionFunctionBM p = ⊤ then by {
    -- This case should not happen due to Z_pos_finite.2
    exfalso; exact Z_pos_finite.2.ne hZ_top
  } else
    @Measure.withDensity (StateBM R U) _ Measure.count density

-- Cleaner definition relying on the proof that Z is good:
-- density `ρ(s)/Z` against the counting measure, with no dead case analysis.
noncomputable def boltzmannDistributionBM' (p : ParamsBM R U) : Measure (StateBM R U) :=
  @Measure.withDensity (StateBM R U) _ Measure.count (fun s => boltzmannDensityFnBM p s / partitionFunctionBM p)

-- Prove it's a probability measure: total mass is (∑ ρ(s))/Z = Z/Z = 1.
instance isProbabilityMeasure_boltzmannDistributionBM
    [Nonempty (StateBM R U)] (p : ParamsBM R U) :
    IsProbabilityMeasure (boltzmannDistributionBM' p) := by
  constructor
  -- Need to show: μ Set.univ = 1
  have h : boltzmannDistributionBM' p Set.univ =
      ∫⁻ s, boltzmannDensityFnBM p s / partitionFunctionBM p ∂Measure.count := by
    -- For withDensity μ f, applying to a set gives integral of f over that set
    simp only [boltzmannDistributionBM', withDensity_apply]
    -- This is a discrete space, so integral becomes a sum
    simp only [MeasurableSpace.measurableSet_top, withDensity_apply, Measure.restrict_univ]
  rw [h]
  -- For counting measure on finite type, integral is sum over all elements
  rw [MeasureTheory.lintegral_count]
  -- For fintype, tsum becomes finite sum
  rw [tsum_fintype]
  have h_sum_div : ∑ s, boltzmannDensityFnBM p s / partitionFunctionBM p =
      (∑ s, boltzmannDensityFnBM p s) / partitionFunctionBM p := by
    -- NOTE(review): `hpos`/`hlt` appear unused in this sub-proof; candidates for cleanup.
    have hpos := (partitionFunctionBM_pos_finite p).1.ne'
    have hlt := (partitionFunctionBM_pos_finite p).2.ne
    simp only [ENNReal.div_eq_inv_mul]
    rw [← mul_sum]
  rw [h_sum_div]
  -- The numerator sum is exactly the definition of the partition function
  rw [← partitionFunctionBM]
  -- So we get Z/Z = 1
  exact ENNReal.div_self (partitionFunctionBM_pos_finite p).1.ne' (partitionFunctionBM_pos_finite p).2.ne
diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean
new file mode 100644
index 000000000..cf3b21f0c
--- /dev/null
+++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean
@@ -0,0 +1,936 @@
/-
Copyright (c) 2024 Michail Karatarakis. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Michail Karatarakis
-/
import Mathlib.LinearAlgebra.Matrix.Symmetric
import Mathlib.Data.Matrix.Reflection
import Mathlib.Data.Vector.Defs
import Init.Data.Vector.Lemmas
import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.aux
import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork

open Finset Matrix NeuralNetwork State

variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R]
  [DecidableEq U] [Fintype U]

/--
`HNfnet` computes the weighted sum of predictions for all elements in `U`, excluding `u`
(no self-coupling: the sum ranges over `v ≠ u`).
-/
abbrev HNfnet (u : U) (wu : U → R) (pred : U → R) : R := ∑ v ∈ {v | v ≠ u}, wu v * pred v

/-- If the self-weight vanishes (`wu u = 0`), the sum over `v ≠ u` agrees with the
unrestricted sum over all of `U`. -/
lemma HNfnet_eq (u : U) (wu : U → R) (pred : U → R) (hw : wu u = 0) :
    HNfnet u wu pred = ∑ v, wu v * pred v := by
  simp_rw [sum_filter, ite_not]
  rw [Finset.sum_congr rfl]
  intros v _
  rw [ite_eq_right_iff, zero_eq_mul]
  intros hvu
  left
  rwa [hvu]

/--
`HNfact` is the threshold (sign) activation: returns `1` if `θ` is less than or equal
to `input`, otherwise `-1`.
-/
abbrev HNfact (θ input : R) : R := if θ ≤ input then 1 else -1

/--
`HNfout` is an identity function that returns its input unchanged.
-/
abbrev HNfout (act : R) : R := act

/--
`HopfieldNetwork` is a type of neural network with parameters `R` and `U`.

- `R`: A linear ordered field.
- `U`: A finite, nonempty set of neurons with decidable equality.
-/
abbrev HopfieldNetwork (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R]
    [DecidableEq U] [Nonempty U] [Fintype U] : NeuralNetwork R U where
  /- The adjacency relation between neurons `u` and `v`, defined as `u ≠ v`
     (the network is fully connected, with no self-loops). -/
  Adj u v := u ≠ v
  /- The set of input neurons, defined as the universal set. -/
  Ui := Set.univ
  /- The set of output neurons, defined as the universal set. -/
  Uo := Set.univ
  /- A proof that the intersection of the input and output sets is empty. -/
  hhio := Set.empty_inter (Set.univ ∪ Set.univ)
  /- The set of hidden neurons, defined as the empty set. -/
  Uh := ∅
  /- A proof that all neurons are in the universal set. -/
  hU := by simp only [Set.mem_univ, Set.union_self, Set.union_empty]
  /- A proof that the input set is not equal to the empty set. -/
  hUi := Ne.symm Set.empty_ne_univ
  /- A proof that the output set is not equal to the empty set. -/
  hUo := Ne.symm Set.empty_ne_univ
  /- A property that the weight matrix `w` is symmetric. -/
  pw w := w.IsSymm
  /- κ₁ is 0 for every neuron (no extra input parameters). -/
  κ1 _ := 0
  /- κ₂ is 1 for every neuron (one threshold parameter). -/
  κ2 _ := 1
  /- The network function for neuron `u`, given weights `w` and predecessor states `pred`. -/
  fnet u w pred _ := HNfnet u w pred
  /- The activation function for neuron `u`: threshold `θ_vec.get 0` against the net input. -/
  fact u _ net_input_val θ_vec := HNfact (θ_vec.get 0) net_input_val
  -- Ignoring the current_act_val argument
  /- The output function, given the activation state `act`. -/
  fout _ act := HNfout act
  /- A predicate that the activation state `act` is either 1 or -1. -/
  pact act := act = 1 ∨ act = -1
  /- A proof that the activation state of neuron `u`
     is determined by the threshold `θ` and the network function. -/
  hpact w _ _ _ θ act _ u :=
    ite_eq_or_eq ((θ u).get 0 ≤ HNfnet u (w u) fun v => HNfout (act v)) 1 (-1)

variable [Nonempty U]

/--
In a Hopfield network, two neurons are adjacent if and only if they are different.
This formalizes the fully connected nature of Hopfield networks.
-/
lemma HopfieldNetwork.all_nodes_adjacent (u v : U) :
    ¬(HopfieldNetwork R U).Adj u v → u = v := by
  intro h
  unfold HopfieldNetwork at h
  simp only [ne_eq] at h
  simp_all only [Decidable.not_not]

/-- In a Hopfield network, activation values can only be 1 or -1:
any `pact` value that is not `1` must be `-1`. -/
lemma hopfield_value_dichotomy
    (val : R) (hval : (HopfieldNetwork R U).pact val) :
    val ≠ 1 → val = -1 := by
  intro h_not_one
  unfold HopfieldNetwork at hval
  simp only at hval
  cases hval with
  | inl h_eq_one =>
    contradiction
  | inr h_eq_neg_one =>
    exact h_eq_neg_one

/--
Extracts the first element from a vector of length 1
(the single threshold parameter, since `κ2 u = 1`).
-/
def θ' : Vector R ((HopfieldNetwork R U).κ2 u) → R := fun (θ : Vector R 1) => θ.get 0

/--
Computes the outer product of two patterns in a Hopfield Network.

Returns:
- A matrix where each element `(i, j)` is the product of the
activations of `p1` at `i` and `p2` at `j`.
-/
abbrev outerProduct (p1 : (HopfieldNetwork R U).State)
    (p2 : (HopfieldNetwork R U).State) : Matrix U U R := fun i j => p1.act i * p2.act j

variable {s : (HopfieldNetwork R U).State}

/-- Every activation of a Hopfield state is `1` or `-1` (restates the state invariant `hp`). -/
lemma NeuralNetwork.State.act_one_or_neg_one (u : U) : s.act u = 1 ∨ s.act u = -1 := s.hp u

/-- Instance to establish decidability of equality for network states:
two states are equal iff their activations agree at every neuron. -/
instance decidableEqState :
    DecidableEq ((HopfieldNetwork R U).State) := by
  intro s₁ s₂
  apply decidable_of_iff (∀ u, s₁.act u = s₂.act u) ⟨fun h ↦ ext h, fun h u ↦ by rw [h]⟩

/--
Defines the Hebbian learning rule for a Hopfield Network.
+ +Given a set of patterns `ps`, this function returns the network parameters +using the Hebbian learning rule, which adjusts weights based on pattern correlations. +-/ +def Hebbian {m : ℕ} (ps : Fin m → (HopfieldNetwork R U).State) : Params (HopfieldNetwork R U) where + /- The weight matrix, calculated as the sum of the outer products of the patterns minus + a scaled identity matrix. -/ + w := ∑ k, outerProduct (ps k) (ps k) - (m : R) • (1 : Matrix U U R) + /- The threshold function, which is set to a constant value of 0 for all units. -/ + θ u := ⟨#[0], rfl⟩ + /- The state function, which is set to an empty vector. -/ + σ _ := Vector.emptyWithCapacity 0 + /- A proof that the weight matrix is symmetric and satisfies the Hebbian learning rule. -/ + hw u v huv := by + simp only [sub_apply, smul_apply, smul_eq_mul] + rw [Finset.sum_apply, Finset.sum_apply] + have : ∀ k i, (ps k).act i * (ps k).act i = 1 := by + intros k i ; rw [mul_self_eq_one_iff.mpr]; exact act_one_or_neg_one i + unfold HopfieldNetwork at huv + simp only [ne_eq, Decidable.not_not] at huv + rw [huv] + conv => enter [1, 1, 2]; + simp only [this, sum_const, card_univ, Fintype.card_fin, nsmul_eq_mul, mul_one, one_apply_eq, + sub_self] + /- A proof that the weight matrix is symmetric. 
-/ + hw' := by + simp only [Matrix.IsSymm, Fin.isValue, transpose_sub, transpose_smul, transpose_one, sub_left_inj] + rw [isSymm_sum] + intro k + refine IsSymm.ext_iff.mpr (fun i j => CommMonoid.mul_comm ((ps k).act j) ((ps k).act i)) + +variable (wθ : Params (HopfieldNetwork R U)) + +@[simp] +lemma act_up_def : (s.Up wθ u).act u = + (if (wθ.θ u : Vector R ((HopfieldNetwork R U).κ2 u)).get 0 ≤ s.net wθ u then 1 else -1) := by + simp only [Up, reduceIte, Fin.isValue] + rfl + +@[simp] +lemma act_of_non_up (huv : v2 ≠ u) : (s.Up wθ u).act v2 = s.act v2 := by + simp only [Up, if_neg huv] + +@[simp] +lemma act_new_neg_one_if_net_lt_th (hn : s.net wθ u < θ' (wθ.θ u)) : (s.Up wθ u).act u = -1 := by + rw [act_up_def]; exact ite_eq_right_iff.mpr fun hyp => (hn.not_le hyp).elim + +@[simp] +lemma actnew_neg_one_if_net_lt_th (hn : s.net wθ u < θ' (wθ.θ u)) : (s.Up wθ u).act u = -1 := + ((s.Up wθ _).act_one_or_neg_one _).elim (fun _ => act_new_neg_one_if_net_lt_th wθ hn) id + +@[simp] +lemma act_new_neg_one_if_not_net_lt_th (hn : ¬s.net wθ u < θ' (wθ.θ u)) : (s.Up wθ u).act u = 1 := by + rw [act_up_def]; exact ite_eq_left_iff.mpr fun hyp => (hn (lt_of_not_ge hyp)).elim + +@[simp] +lemma act_new_neg_one_if_net_eq_th (hn : s.net wθ u = θ' (wθ.θ u)) : (s.Up wθ u).act u = 1 := by + rw [act_up_def]; exact ite_eq_left_iff.mpr fun hyp => (hyp (le_iff_lt_or_eq.mpr (Or.inr hn.symm))).elim + +@[simp] +lemma activ_old_one (hc : (s.Up wθ u).act u ≠ s.act u) (hn : s.net wθ u < θ' (wθ.θ u)) : s.act u = 1 := + (act_one_or_neg_one _).elim id (fun h2 => (hc (actnew_neg_one_if_net_lt_th wθ hn ▸ h2.symm)).elim) + +@[simp] +lemma actnew_one (hn : ¬s.net wθ u < θ' (wθ.θ u)) : (s.Up wθ u).act u = 1 := + ((s.Up wθ _).act_one_or_neg_one _).elim id (fun _ => act_new_neg_one_if_not_net_lt_th wθ hn) + +@[simp] +lemma activ_old_neg_one (hc : (s.Up wθ u).act u ≠ s.act u) (_ : ¬s.net wθ u < θ' (wθ.θ u)) + (hnew : (s.Up wθ u).act u = 1) : s.act u = -1 := +(act_one_or_neg_one _).elim (fun h1 => (hc (hnew ▸ 
h1.symm)).elim) id + +@[simp] +lemma act_eq_neg_one_if_up_act_eq_one_and_net_eq_th (hc : (s.Up wθ u).act u ≠ s.act u) + (h2 : s.net wθ u = θ' (wθ.θ u)) (hactUp : (s.Up wθ u).act u = 1) : s.act u = -1 := +activ_old_neg_one wθ hc h2.symm.not_gt hactUp + +/-- +`NeuralNetwork.State.Wact` computes the weighted activation for neurons `u` and `v` +by multiplying the weight `wθ.w u v` with their activations `s.act u` and `s.act v`. +-/ +abbrev NeuralNetwork.State.Wact u v := wθ.w u v * s.act u * s.act v + +/-- +`NeuralNetwork.State.Eθ` computes the sum of `θ' (wθ.θ u) * s.act u` for all `u`. +-/ +def NeuralNetwork.State.Eθ := ∑ u, θ' (wθ.θ u) * s.act u + +/-- +`NeuralNetwork.State.Ew` computes the energy contribution from the weights in a state. +It is defined as `-1/2` times the sum of `s.Wact wθ u v2` for all `u` and `v2` where `v2 ≠ u`. +-/ +def NeuralNetwork.State.Ew := - 1/2 * (∑ u, (∑ v2 ∈ {v2 | v2 ≠ u}, s.Wact wθ u v2)) + +/-- +Calculates the energy `E` of a state `s` in a Hopfield Network. + +The energy is the sum of: +- `Ew` : Weighted energy component. +- `Eθ` : Threshold energy component. + +Arguments: +- `s`: A state in the Hopfield Network. 
+-/ +def NeuralNetwork.State.E (s : (HopfieldNetwork R U).State) : R := s.Ew wθ + s.Eθ wθ + +@[simp] +lemma Wact_sym (v1 v2 : U) : s.Wact wθ v1 v2 = s.Wact wθ v2 v1 := by + by_cases h : v1 = v2; + · simp_rw [mul_comm, h] + · simp_rw [mul_comm, congrFun (congrFun (id (wθ.hw').symm) v1) v2] + exact mul_left_comm (s.act v2) (s.act v1) (wθ.w v2 v1) + +@[simp] +lemma Ew_update_formula_split : s.Ew wθ = (- ∑ v2 ∈ {v2 | v2 ≠ u}, s.Wact wθ v2 u) + + - 1/2 * ∑ v1, (∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 ≠ u}, s.Wact wθ v1 v2) := by + + have Ew_sum_formula_eq : + ∑ v1, (∑ v2 ∈ {v2 | v2 ≠ v1 ∧ v2 = u}, s.Wact wθ v1 v2) = + ∑ v1, (∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 = u}, s.Wact wθ v1 v2) := by + rw [sum_congr rfl]; intro v1 _; rw [sum_congr] + ext v2; simp only [mem_filter, mem_univ, true_and, and_congr_left_iff, iff_self_and] + intro hv2 hnv1; rw [← hv2]; exact fun hv1v2 => hnv1 (id (hv1v2.symm)); intro v2 _; rfl + + calc _ = -1 / 2 * ∑ v1 : U, ∑ v2 ∈ {v2 | v2 ≠ v1 ∧ v1 = u}, s.Wact wθ v1 v2 + + -1 / 2 * ∑ v1 : U, ∑ v2 ∈ {v2 | v2 ≠ v1 ∧ v1 ≠ u}, s.Wact wθ v1 v2 := ?_ + _ = -1 / 2 * ∑ v1 : U, ∑ v2 ∈ {v2 | v2 ≠ v1 ∧ v1 = u}, s.Wact wθ v1 v2 + + -1 / 2 * (∑ v1 : U, ∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 = u}, s.Wact wθ v1 v2 + + ∑ v1 : U, ∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 ≠ u}, s.Wact wθ v1 v2) := ?_ + _ = (- ∑ v2 ∈ {v2 | v2 ≠ u}, s.Wact wθ v2 u) + + - 1/2 * ∑ v1, (∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 ≠ u}, s.Wact wθ v1 v2) := ?_ + · simp only [Ew, mem_filter, mem_univ, true_and, true_implies, mul_sum, and_imp, + ← sum_add_distrib, ← sum_split] + · simp only [← sum_add_distrib, sum_congr, div_eq_zero_iff, neg_eq_zero, + one_ne_zero, OfNat.ofNat_ne_zero, or_self, or_false, ← sum_split] + · rw [mul_add, ← add_assoc, add_right_cancel_iff] + + have sum_v1_v2_not_eq_v1_eq_u : + ∑ v1, (∑ v2 ∈ {v2 | v2 ≠ v1 ∧ v1 = u}, s.Wact wθ v1 v2) = ∑ v2 ∈ {v2 | v2 ≠ u}, s.Wact wθ u v2 := by + rw [Fintype.sum_eq_single u]; simp only [and_true]; + intro v1 hv1; simp_all only [and_false, 
filter_False, sum_empty] + rw [sum_v1_v2_not_eq_v1_eq_u] + + have sum_v1_v2_not_eq_v1_eq_u' : + ∑ v1, (∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 = u}, s.Wact wθ v1 v2) = ∑ v1 ∈ {v1 | v1 ≠ u}, s.Wact wθ u v1 := by + rw [← Ew_sum_formula_eq]; nth_rw 2 [sum_over_subset]; rw [sum_congr rfl]; intro v1 hv1 + have sum_Wact_v1_u : ∑ v2 ∈ {v2 | v2 ≠ v1 ∧ v2 = u}, s.Wact wθ v1 v2 = if v1 ≠ u then s.Wact wθ v1 u else 0 := by + split + · rw [sum_filter]; rw [sum_eq_single u] + · simp_all only [ne_eq, and_true, ite_not, ite_eq_right_iff] + intro a; subst a; simp_all only [not_true_eq_false] + · intro hv1 _ a; simp_all only [mem_univ, and_false, reduceIte] + · intro a; simp_all only [mem_univ, not_true_eq_false] + · simp_all only [Decidable.not_not, not_and_self, filter_False, sum_empty] + simp_rw [sum_Wact_v1_u, ite_not, mem_filter, mem_univ, true_and]; + split; next h => exact (if_neg fun hv1u => hv1u h).symm; ; exact Wact_sym wθ v1 u + + rw [← sum_v1_v2_not_eq_v1_eq_u', ← Ew_sum_formula_eq] + have sum_Wact_eq_sum_Wact_sym : ∑ v1, ∑ v2 ∈ {v2 | v2 ≠ v1 ∧ v2 = u}, s.Wact wθ v1 v2 = + ∑ v2 ∈ {v2 | v2 ≠ u}, s.Wact wθ v2 u := by + rw [Ew_sum_formula_eq, sum_v1_v2_not_eq_v1_eq_u']; apply sum_congr rfl (fun _ _ => Wact_sym wθ u _) + rw [sum_Wact_eq_sum_Wact_sym, mul_sum, ← sum_add_distrib, ← sum_neg_distrib]; + congr; apply funext; intro v2; + rw [← mul_add, (two_mul (Wact wθ v2 u)).symm, div_eq_mul_inv] + simp_all only [neg_mul, one_mul, isUnit_iff_ne_zero, ne_eq, + OfNat.ofNat_ne_zero, not_false_eq_true, IsUnit.inv_mul_cancel_left] + +@[simp] +lemma Ew_diff' : (s.Up wθ u).Ew wθ - s.Ew wθ = + - ∑ v2 ∈ {v2 | v2 ≠ u}, (s.Up wθ u).Wact wθ v2 u - (- ∑ v2 ∈ {v2 | v2 ≠ u}, s.Wact wθ v2 u) := by + rw [Ew_update_formula_split, Ew_update_formula_split, sub_eq_add_neg, sub_eq_add_neg] + simp only [neg_add_rev, neg_neg]; rw [mul_sum, mul_sum] + calc _ = -∑ v2 ∈ {v2 | v2 ≠ u}, (s.Up wθ u).Wact wθ v2 u + + (∑ v1, -1 / 2 * ∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 ≠ u}, (s.Up wθ u).Wact wθ v1 v2 + + -∑ v1, 
-1 / 2 * ∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 ≠ u}, s.Wact wθ v1 v2) + + ∑ v2 ∈ {v2 | v2 ≠ u}, s.Wact wθ v2 u := ?_ + _ = - ∑ v2 ∈ {v2 | v2 ≠ u}, (s.Up wθ u).Wact wθ v2 u - (- ∑ v2 ∈ {v2 | v2 ≠ u}, s.Wact wθ v2 u) := ?_ + · nth_rw 2 [← add_assoc]; rw [(add_assoc + (-∑ v2 ∈ {v2 | v2 ≠ u}, Wact wθ v2 u + ∑ v1, -1 / 2 * + ∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 ≠ u}, (s.Up wθ u).Wact wθ v1 v2) + (-∑ v1, -1 / 2 * ∑ v2 ∈ {v2 | (v2 ≠ v1 ∧ v1 ≠ u) ∧ v2 ≠ u}, s.Wact wθ v1 v2) + (∑ v2 ∈ {v2 | v2 ≠ u}, Wact wθ v2 u))] + · simp only [sub_neg_eq_add, add_left_inj, add_eq_left] + rw [← sum_neg_distrib, ← sum_add_distrib, sum_eq_zero] + simp only [mem_univ, true_implies]; intro v1 + rw [mul_sum, mul_sum, ← sum_neg_distrib, ← sum_add_distrib, sum_eq_zero] + simp only [mem_filter, mem_univ, true_and, and_imp]; intro v2 _ hv1 hvneg2 + simp_all only [Wact, Up, mul_ite, ite_mul, reduceIte, add_neg_cancel] + simp only [sub_neg_eq_add] + +@[simp] +lemma θ_stable : ∑ v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * s.act v2 = + ∑ v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * (s.Up wθ u).act v2 := by + rw [sum_congr rfl]; intro v2 hv2; rw [act_of_non_up] + simp only [mem_filter, mem_univ, true_and] at hv2; assumption + +lemma θ_formula : ∑ v2, θ' (wθ.θ v2) * s.act v2 = θ' (wθ.θ u) * s.act u + + ∑ v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * s.act v2 := by + have : ∑ v2 ∈ {v2 | v2 = u}, θ' (wθ.θ v2) * s.act v2 = θ' (wθ.θ u) * s.act u := by + rw [sum_filter]; simp only [sum_ite_eq', mem_univ, reduceIte] + rw [← this]; rw [sum_filter_add_sum_filter_not] + +@[simp] +theorem Eθ_diff : (s.Up wθ u).Eθ wθ - s.Eθ wθ = θ' (wθ.θ u) * ((s.Up wθ u).act u - s.act u) := by + calc _ = θ' (wθ.θ u) * (s.Up wθ u).act u + ∑ v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * (s.Up wθ u).act v2 + + - (θ' (wθ.θ u) * s.act u + ∑ v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * s.act v2) := ?_ + _ = θ' (wθ.θ u) * ((s.Up wθ u).act u - s.act u) := ?_ + · unfold NeuralNetwork.State.Eθ; rw [θ_formula, θ_formula, θ_stable] + rw [sub_eq_add_neg (θ' (wθ.θ u) * (s.Up wθ u).act u + + ∑ 
v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * ((s.Up wθ u).Up wθ u).act v2) + (θ' (wθ.θ u) * s.act u + ∑ v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * s.act v2)] + · rw [neg_add_rev, (add_assoc (θ' (wθ.θ u) * (s.Up wθ u).act u + + ∑ v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * (s.Up wθ u).act v2) + (-∑ v2 ∈ {v2 | v2 ≠ u}, θ' (wθ.θ v2) * s.act v2) (-(θ' (wθ.θ u) * s.act u))).symm] + simp only [add_assoc, add_right_inj, add_eq_left]; nth_rw 2 [θ_stable] + rw [sub_eq_add_neg, mul_add, mul_neg]; simp only [add_neg_cancel_left] + +@[simp] +lemma E_final_Form : (s.Up wθ u).E wθ - s.E wθ = (s.act u - (s.Up wθ u).act u) * + ((∑ v2 ∈ {v2 | v2 ≠ u}, wθ.w u v2 * s.act v2) - θ' (wθ.θ u)) := by + calc _ = (s.Up wθ u).Eθ wθ- s.Eθ wθ + (s.Up wθ u).Ew wθ - s.Ew wθ := ?_ + _ = ∑ v2 ∈ {v2 | v2 ≠ u}, (- wθ.w v2 u * (s.Up wθ u).act v2 * (s.Up wθ u).act u + + (wθ.w v2 u * s.act v2 * s.act u)) + θ' (wθ.θ u) * ((s.Up wθ u).act u - s.act u) := ?_ + _ = ∑ v2 ∈ {v2 | v2 ≠ u}, (- wθ.w v2 u * s.act v2 * (s.Up wθ u).act u + wθ.w v2 u * s.act v2 * s.act u) + + θ' (wθ.θ u) * ((s.Up wθ u).act u - s.act u) := ?_ + _ = ∑ v2 ∈ {v2 | v2 ≠ u}, - (wθ.w v2 u * s.act v2 * ((s.Up wθ u).act u - s.act u)) + + θ' (wθ.θ u) * ((s.Up wθ u).act u - s.act u) := ?_ + _ = ((s.Up wθ u).act u - s.act u) * ∑ v2 ∈ {v2 | v2 ≠ u}, - (wθ.w v2 u * s.act v2) + + θ' (wθ.θ u) * ((s.Up wθ u).act u - s.act u) := ?_ + _ = ((s.Up wθ u).act u - s.act u) * (∑ v2 ∈ {v2 | v2 ≠ u}, - (wθ.w v2 u * s.act v2) + θ' (wθ.θ u)) := ?_ + _ = ((s.Up wθ u).act u - s.act u) * - (∑ v2 ∈ {v2 | v2 ≠ u}, (wθ.w v2 u * s.act v2) + - θ' (wθ.θ u)) := ?_ + _ = - ((s.Up wθ u).act u - s.act u) * ((∑ v2 ∈ {v2 | v2 ≠ u}, wθ.w v2 u * s.act v2) - θ' (wθ.θ u)) := ?_ + _ = (s.act u - (s.Up wθ u).act u) * ((∑ v2 ∈ {v2 | v2 ≠ u}, wθ.w u v2 * s.act v2) - θ' (wθ.θ u)) := ?_ + + · simp_rw [NeuralNetwork.State.E, sub_eq_add_neg, neg_add_rev] + rw [add_assoc, add_comm, ← add_assoc, add_right_comm (Eθ wθ + -Eθ wθ) (-Ew wθ) (Ew wθ) ] + · rw [add_sub_assoc (Eθ wθ - Eθ wθ) (Ew wθ) (Ew wθ), Eθ_diff, 
Ew_diff'] + nth_rw 1 [add_comm]; simp only [sub_neg_eq_add, neg_mul, add_left_inj] + rw [← sum_neg_distrib, ← sum_add_distrib] + · rw [sum_congr rfl]; intro v2 hv2 + rw [add_left_inj, mul_eq_mul_right_iff, mul_eq_mul_left_iff] + left; left; rw [act_of_non_up] + simp only [mem_filter, mem_univ, true_and] at hv2 + assumption + · simp_rw [neg_mul, sum_neg_distrib, add_left_inj] + rw [← sum_neg_distrib, sum_congr rfl]; intro v2 _; rw [mul_sub, add_comm, neg_sub] + rw [sub_eq_neg_add (wθ.w v2 u * s.act v2 * s.act u) (wθ.w v2 u * s.act v2 * (s.Up wθ u).act u)] + rw [add_comm] + · simp only [sum_neg_distrib, mul_neg, add_left_inj, neg_inj] + rw [mul_sum, sum_congr rfl]; intro v2 _; rw [mul_comm] + · rw [mul_add]; nth_rw 2 [mul_comm] + · simp_rw [sum_neg_distrib, neg_add_rev, neg_neg, mul_eq_mul_left_iff, add_comm,true_or] + · rw [neg_mul_comm, mul_eq_mul_left_iff]; left; simp only [neg_add_rev, neg_neg, neg_sub] + rw [(sub_eq_add_neg (θ' (wθ.θ u)) (∑ v2 ∈ {v2 | v2 ≠ u}, wθ.w v2 u * s.act v2))] + · simp only [neg_sub, ne_eq]; rw [mul_eq_mul_left_iff, sub_left_inj] + left; rw [sum_congr rfl]; intro v2 hv2 + simp_all only [mem_filter, mem_univ, true_and, mul_eq_mul_right_iff] + left; exact ((congrFun (congrFun (id (wθ.hw').symm) u) v2).symm) + +@[simp] +lemma energy_diff_leq_zero (hc : (s.Up wθ u).act u ≠ s.act u) : (s.Up wθ u).E wθ ≤ s.E wθ := by + apply le_of_sub_nonpos; rw [E_final_Form] + by_cases hs : s.net wθ u < θ' (wθ.θ u) + · apply mul_nonpos_of_nonneg_of_nonpos ?_ ?_ + · apply le_of_lt; apply sub_pos_of_lt; + simp_rw [activ_old_one wθ hc hs , actnew_neg_one_if_net_lt_th wθ hs, + neg_lt_self_iff, zero_lt_one] + · apply le_of_lt; rwa [sub_neg] + · apply mul_nonpos_of_nonpos_of_nonneg ?_ ?_ + · simp only [tsub_le_iff_right, zero_add] + simp_rw [activ_old_neg_one wθ hc hs (actnew_one wθ hs), + actnew_one wθ hs, neg_le_self_iff, zero_le_one] + · apply sub_nonneg_of_le; rwa [← not_lt] + +/-- +`NeuralNetwork.State.pluses` counts the number of neurons in the state `s` with 
activation `1`.
-/
def NeuralNetwork.State.pluses := ∑ u, if s.act u = 1 then 1 else 0

/-- An effective update (one that actually changes the activation at `u`) either strictly
decreases the energy, or leaves the energy unchanged while strictly increasing the number
of `+1` activations. This disjunction is the decreasing measure used for convergence. -/
@[simp]
theorem energy_lt_zero_or_pluses_increase (hc : (s.Up wθ u).act u ≠ s.act u) :
    (s.Up wθ u).E wθ < s.E wθ ∨ (s.Up wθ u).E wθ = s.E wθ ∧ s.pluses < (s.Up wθ u).pluses :=
  (lt_or_eq_of_le (energy_diff_leq_zero wθ hc)).elim Or.inl (fun hr => Or.inr (by
    constructor; assumption; rw [← sub_eq_zero, E_final_Form, mul_eq_zero] at hr
    cases' hr with h1 h2
    · rw [sub_eq_zero] at h1; apply sum_lt_sum;
      · simp_all only [ne_eq, not_true_eq_false]
      · simp_all only [ne_eq, not_true_eq_false]
    · rw [sub_eq_zero] at h2
      have hactUp := act_new_neg_one_if_net_eq_th wθ h2
      have hactu := act_eq_neg_one_if_up_act_eq_one_and_net_eq_th wθ hc h2 hactUp
      apply sum_lt_sum
      · intro v hv; split
        · simp only [Up, HNfact]; split
          · simp_all only [mem_univ, ne_eq]
          · apply le_refl
        · simp only [Up]; split
          · split; apply zero_le_one; apply le_refl
          · apply le_refl
      · use u; simp_rw [hactUp, reduceIte]; split
        · simp_all only [not_true_eq_false]
        · simp only [zero_lt_one, true_and, mem_univ]))

variable (extu : (HopfieldNetwork R U).State) (hext : extu.onlyUi)

/--
`stateToActValMap` maps a state from a `HopfieldNetwork` to the set `{-1, 1}`.

NOTE(review): this definition ignores its argument and always returns `1`
(the proof term picks `Or.inr rfl`, i.e. the `1` branch, for every input).
Presumably a placeholder — verify the intended semantics; compare with
`stateToNeurActMap` below, which genuinely reads the state's activations.
-/
def stateToActValMap : (HopfieldNetwork R U).State → ({-1,1} : Finset R) := fun _ => by
  simp_all only [mem_insert, mem_singleton]; apply Subtype.mk; apply Or.inr; rfl

/--
`neuronToActMap` maps a neuron `u` to its activation value in the set `{-1, 1}`.

NOTE(review): the body is constant in `u` (it returns `stateToActValMap s` for every
neuron), so together with the note on `stateToActValMap` this is the constant map to `1`.
Confirm against the callers whether this is intentional.
-/
def neuronToActMap : U → ({-1,1} : Finset R) := fun _ => stateToActValMap s

/--
`stateToNeurActMap` maps a Hopfield Network state to a function that returns
the activation state (1 or -1) of a given neuron.
-/
def stateToNeurActMap : (HopfieldNetwork R U).State → (U → ({1,-1} : Finset R)) := fun s u =>
  ⟨s.act u, by simp only [mem_insert, mem_singleton, s.act_one_or_neg_one u]⟩

/--
`NeuralNetwork.stateToNeurActMap_equiv'` provides an equivalence between the `State` type
of a `HopfieldNetwork` and a function type `U → ({1, -1} : Finset R)`.
This equivalence allows for easier manipulation of neural network states
(and is used below to derive the `Fintype` instance on states).
-/
def NeuralNetwork.stateToNeurActMap_equiv' :
    (HopfieldNetwork R U).State ≃ (U → ({1,-1} : Finset R)) where
  toFun := stateToNeurActMap
  invFun := fun f =>
    { act := fun u => f u, hp := fun u => by
        simp only; cases' f u with val hval; simp only
        simp_all only [mem_insert, mem_singleton]}
  left_inv := congrFun rfl
  right_inv := congrFun rfl

-- The state space is finite: transport `Fintype` along the equivalence above.
instance : Fintype ((HopfieldNetwork R U).State) := Fintype.ofEquiv _ ((stateToNeurActMap_equiv').symm)

/--
`State'` is a type alias for the state of a `HopfieldNetwork` with given parameters
(the parameter argument is phantom — it only tags the type).
-/
def State' (_ : Params (HopfieldNetwork R U)) := (HopfieldNetwork R U).State

variable {wθ : Params (HopfieldNetwork R U)}

/--
`Up'` updates the state `s` at neuron `u`.
-/
abbrev Up' (s : State' wθ) (u : U) : State' wθ := s.Up wθ u

/--
Generates a sequence of states for a Hopfield Network.

Parameters:
- `s`: A state.
- `useq`: A sequence of neurons giving the update order.

-/
def seqStates' {wθ : Params (HopfieldNetwork R U)} (s : State' wθ) (useq : ℕ → U) : ℕ → State' wθ
  := seqStates wθ s useq

/--
Defines an ordering between two states `s1` and `s2` based on their energy `E`
and the number of pluses.
A state `s1` is before `s2` if:
- `s1` has lower energy than `s2`, or
- `s1` has the same energy as `s2`, but more pluses.
+-/ +def stateLt (s1 s2 : State' wθ) : Prop := s1.E wθ < s2.E wθ ∨ s1.E wθ = s2.E wθ ∧ s2.pluses < s1.pluses + +@[simp] +lemma stateLt_antisym (s1 s2 : State' wθ) : stateLt s1 s2 → ¬stateLt s2 s1 := by + rintro (h1 | ⟨_, h3⟩) (h2 | ⟨_, h4⟩) + · exact h1.not_lt h2 + · simp_all only [lt_self_iff_false] + · simp_all only [lt_self_iff_false] + · exact h3.not_lt h4 + +/-- +Defines a partial order on states. The relation `stateOrd` holds between two states `s1` and `s2` +if `s1` is equal to `s2` or if `s1` is before `s2` according to `stateLt`. +-/ +def stateOrd (s1 s2 : State' wθ) : Prop := s1 = s2 ∨ stateLt s1 s2 + +instance StatePartialOrder : PartialOrder (State' wθ) where + le s1 s2 := stateOrd s1 s2 + le_refl _ := Or.inl rfl + le_trans s1 s2 s3 h12 h23 := by + cases' h12 with h12 h12 + · cases' h23 with h23 h23 + · left; rw [h12, h23] + · right; rw [h12]; assumption + · cases' h23 with h23 h23; right; simp_all only; right + have : stateLt s1 s2 → stateLt s2 s3 → stateLt s1 s3 := by + rintro (h1 | ⟨h1, h2⟩) (h3 | ⟨h3, h4⟩) + · left; exact lt_trans h1 h3 + · left; rw [← h3]; assumption + · left; rw [h1]; assumption + · right; exact ⟨h1.trans h3, h4.trans h2⟩ + exact this h12 h23 + le_antisymm s1 s2 h12 h21 := by + cases' h12 with h12 h12 + · cases' h21 with h21 h21; assumption; assumption + · cases' h21 with h21 h21; exact h21.symm + by_contra; exact stateLt_antisym s1 s2 h12 h21 + +@[simp] +lemma stateLt_lt (s1 s2 : State' wθ) : s1 < s2 ↔ stateLt s1 s2 := by + simp only [LT.lt]; unfold stateOrd; simp_all only [not_or] + constructor + · intro H; obtain ⟨hl, hr⟩ := H + obtain ⟨_, hr⟩ := hr + cases' hl with hl hr + · subst hl; simp_all only [not_true_eq_false] + · simp_all only + · intro hs2; simp_all only [or_true, true_and] + constructor + · intro hs; subst hs; + have : ¬stateLt s2 s2:= fun + | Or.inl h1 => h1.not_lt h1 + | Or.inr ⟨_, h3⟩ => h3.not_lt h3 + exact this hs2 + · intro hs; apply stateLt_antisym s1 s2 hs2 hs + +@[simp] +lemma state_act_eq (s1 s2 : State' 
wθ) : s1.act = s2.act → s1 = s2 := by + intro h; cases' s1 with act1 hact1; cases' s2 with act2 hact2 + simp only at h; simp only [h] + +@[simp] +lemma state_Up_act (s : State' wθ) : (Up' s u).act u = s.act u → Up' s u = s := by + intro h; cases' s with act hact; apply state_act_eq; ext v + by_cases huv : v = u; simp only [huv, h]; simp only [Up', Up, huv, reduceIte] + +@[simp] +lemma up_act_eq_act_of_up_eq (s : State' wθ) : Up' s u = s → (Up' s u).act u = s.act u := fun hs => + congrFun (congrArg act hs) u + +@[simp] +lemma up_act_eq_iff_eq (s : State' wθ) : (Up' s u).act u = s.act u ↔ Up' s u = s := by + exact ⟨state_Up_act s, fun hs => congrFun (congrArg act hs) u⟩ + +@[simp] +lemma update_less' (s : State' wθ) : Up' s u ≠ s → Up' s u < s := fun h => by + simp only [stateLt_lt] + apply energy_lt_zero_or_pluses_increase + intros H + apply h + apply state_Up_act + assumption + +@[simp] +lemma update_le (s : State' wθ) : Up' s u ≤ s := by + by_cases h : Up' s u = s; left; assumption + right; simp only [← stateLt_lt]; exact update_less' s h + +@[simp] +lemma n_leq_n'_imp_sseq_n (n : ℕ) : + (seqStates wθ s useq (n + 1)) = (seqStates wθ s useq n).Up wθ (useq n):= by + unfold seqStates; split; rfl; simp_all only [Nat.succ_eq_add_one]; rfl + +@[simp] +lemma n_leq_n'_imp_sseq_n_k'' (n : ℕ) : + (seqStates wθ s useq (n+1)) = (seqStates wθ s useq n).Up wθ (useq n):= rfl + +@[simp] +lemma n_leq_n'_imp_sseq_n_k (n k : ℕ) : + (seqStates wθ s useq ((n + k) + 1)) = (seqStates wθ s useq (n + k)).Up wθ (useq (n + k)) := by + simp only [seqStates] + +@[simp] +lemma NeuralNetwork.n_leq_n'_imp_sseq_n'_leq_sseq'' (s : State' wθ) (n k : ℕ) : + seqStates' s useq (n + k) ≤ seqStates' s useq n := by + induction k with + | zero => simp only [Nat.add_zero]; apply le_refl + | succ k hk => rw [Nat.add_succ, seqStates', n_leq_n'_imp_sseq_n_k]; trans; apply update_le; exact hk + +@[simp] +lemma not_stable_u (s : (HopfieldNetwork R U).State) : ¬s.isStable wθ → ∃ u, (s.Up wθ u) ≠ s := by + intro 
h; + obtain ⟨u, h⟩ := not_forall.mp h + exact ⟨u, fun a => h (congrFun (congrArg act a) u)⟩ + +@[simp] +theorem seqStates_lt (s : State' wθ) (useq : ℕ → U) (n : ℕ) (m' : ℕ) (hm' : m' > n) : + seqStates' s useq m' ≤ seqStates' s useq n := by + obtain ⟨k, hk⟩ := Nat.exists_eq_add_of_le' hm' + rw [hk, Nat.add_left_comm k n 1] + exact n_leq_n'_imp_sseq_n'_leq_sseq'' s n (k + 1) + +variable (s s' : State' wθ) + +instance : DecidablePred (fun s' => s' < s) := fun s' => by + simp only; rw [stateLt_lt, stateLt]; exact instDecidableOr + +/-- +`states_less` is the set of patterns in a Hopfield Network that are less than a given state `s`. +-/ +def states_less : Finset (HopfieldNetwork R U).State := {s' : State' wθ | s' < s} + +open Fintype + +/-- +`num_of_states_less` returns the number of states that come before a given state `s`. +-/ +def num_of_states_less := Fintype.card (states_less s) + +@[simp] +lemma num_of_states_decreases (hs : s < s') : + num_of_states_less s < num_of_states_less s' := by + unfold num_of_states_less states_less + simp only [Fintype.card_coe] + apply Finset.card_lt_card + rw [Finset.ssubset_iff_of_subset] + simp only [mem_filter, mem_univ, true_and, not_lt] + use s; exact ⟨hs, gt_irrefl s⟩ + simp only [Finset.subset_iff, mem_filter, mem_univ, true_and] + exact fun _ hx => hx.trans hs + +@[simp] +lemma num_of_states_leq_zero_implies_stable (hn : num_of_states_less s = 0) : + s.isStable wθ := fun u => by + cases' update_le s with h1 h2 + · exact congrFun (congrArg act h1) u + · rw [← stateLt_lt] at h2 + unfold num_of_states_less states_less at hn + simp only [Fintype.card_eq_zero_iff] at hn + simp only [mem_filter, mem_univ, true_and, isEmpty_subtype] at hn + cases hn ((s.Up wθ u)) h2 + +@[simp] +lemma seqStates_le' (useq : ℕ → U) (n : ℕ) (m' : ℕ) (hm' : m' ≥ n) : + seqStates' s useq m' ≤ seqStates' s useq n := by + simp only [ge_iff_le, le_iff_lt_or_eq] at hm' + cases' hm' with h1 h2 + · exact seqStates_lt s useq n m' h1 + · exact le_of_eq (congrArg 
(seqStates wθ s useq) (id (h2.symm))) + +@[simp] +lemma not_stable_implies_sseqm_lt_sseqn (useq : ℕ → U) (hf : fair useq) (n : ℕ) + (hstable : ¬ (seqStates' s useq n).isStable wθ) : + ∃ m, m ≥ n ∧ (seqStates' s useq m) < (seqStates' s useq n) := by + obtain ⟨u, hc⟩ := not_forall.mp hstable + obtain ⟨m', ⟨hm', hu⟩⟩ := hf u n + have : seqStates' s useq m' ≤ (seqStates' s useq n) := seqStates_le' s useq n m' hm' + cases' (le_iff_lt_or_eq.mp this) with h1 h2 + · use m'; + · use m' + 1; constructor + · exact Nat.le_add_right_of_le hm' + · calc _ < _ := ?_ + _ = _ := h2 + · apply update_less' (seqStates' s useq m') + intro a; simp_all only [not_true_eq_false] + +@[simp] +lemma num_of_states_leq_c_implies_stable_sseq (s : (HopfieldNetwork R U).State) + (useq : ℕ → U) (hf : fair useq) (c : ℕ) : + ∀ n : ℕ, (@num_of_states_less _ _ _ _ _ _ _ _ wθ (seqStates' s useq n)) ≤ c → + ∃ m ≥ n, (@seqStates' _ _ _ _ _ _ _ _ wθ s useq m).isStable wθ := by + induction' c with c hc + · intros n hn; use n; constructor + · apply Nat.le_refl + · apply num_of_states_leq_zero_implies_stable + simp only [nonpos_iff_eq_zero] at hn; assumption + · intros n hn; + by_cases H : (@seqStates' _ _ _ _ _ _ _ _ wθ s useq n).isStable wθ + · use n + · obtain ⟨m, ⟨hm, hlt⟩⟩ := not_stable_implies_sseqm_lt_sseqn s useq hf n H + have : @num_of_states_less _ _ _ _ _ _ _ _ wθ (seqStates' s useq m) + < @num_of_states_less _ _ _ _ _ _ _ _ wθ (seqStates' s useq n) := by + apply num_of_states_decreases; assumption + have : @num_of_states_less _ _ _ _ _ _ _ _ wθ (seqStates' s useq m) ≤ c := by + apply Nat.le_of_lt_succ; + rw [← Nat.succ_eq_add_one] at hn + calc _ < @num_of_states_less _ _ _ _ _ _ _ _ wθ (seqStates' s useq n) := this + _ ≤ c.succ := hn + obtain ⟨m', ⟨hm', hstable⟩⟩ := hc m this + use m'; constructor + trans; assumption; assumption; assumption + +@[simp] +theorem HopfieldNet_convergence_fair : ∀ (useq : ℕ → U), fair useq → + ∃ N, (seqStates' s useq N).isStable wθ := fun useq hfair => by + let c := 
@num_of_states_less _ _ _ _ _ _ _ _ wθ (seqStates' s useq 0) + obtain ⟨N, ⟨_, hN⟩⟩ := num_of_states_leq_c_implies_stable_sseq s useq hfair c 0 (Nat.le_refl c) + use N + +instance (s : State' wθ): Decidable (isStable wθ s) := Fintype.decidableForallFintype + +/-- +A function that returns the stabilized state after updating. +-/ +def HopfieldNet_stabilize (wθ : Params (HopfieldNetwork R U)) + (s : State' wθ) (useq : ℕ → U) (hf : fair useq) : State' wθ := + (seqStates' s useq) (Nat.find (HopfieldNet_convergence_fair s useq hf)) + +@[simp] +lemma isStable_HN_stabilize : ∀ (s : State' wθ) (useq : ℕ → U) (hf : fair useq), + (HopfieldNet_stabilize wθ s useq hf).isStable wθ := fun s useq hf => + Nat.find_spec (HopfieldNet_convergence_fair s useq hf) + +@[simp] +lemma not_stable_implies_sseqm_lt_sseqn_cyclic (useq : ℕ → U) (hf : cyclic useq) (n : ℕ) + (hstable : ¬ (seqStates' s useq n).isStable wθ) : + ∃ m, m ≥ n ∧ m ≤ n + card U ∧ (seqStates' s useq m) < (seqStates' s useq n) := by + obtain ⟨u, hc⟩ := not_forall.mp hstable + have : (Up' (seqStates' s useq n) u).act u = (seqStates' s useq n).act u ↔ + (Up' (seqStates' s useq n) u) = (seqStates' s useq n) := up_act_eq_iff_eq (seqStates' s useq n) + rw [this] at hc + obtain ⟨m', ⟨hm', ⟨hm, hfoo⟩⟩⟩ := cyclic_Fair_bound useq hf u n + have : seqStates' s useq m' ≤ (seqStates' s useq n) := seqStates_le' s useq n m' hm' + cases' (le_iff_lt_or_eq.mp this) with h1 h2 + · use m'; constructor; exact hm'; subst hfoo + simp_all only [gt_iff_lt, and_self, and_true] + rw [le_iff_lt_or_eq]; left; exact hm + · use m' + 1; simp only [ge_iff_le] at hm'; constructor + · simp only [ge_iff_le]; exact Nat.le_add_right_of_le hm' + · constructor + · exact hm + · calc _ < _ := ?_ + _ = _ := h2 + · apply update_less' (seqStates' s useq m') + intro a; simp_all only [not_true_eq_false] + +@[simp] +lemma num_of_states_leq_c_implies_stable_sseq_cyclic (s : State' wθ) (useq : ℕ → U) + (hcy : cyclic useq) (c : ℕ) : ∀ n, num_of_states_less (seqStates' s 
useq n) ≤ c → + ∃ m ≥ n, m ≤ n + card U * c ∧ (s.seqStates wθ useq m).isStable wθ := by + induction' c with c hc + · intros n hn; use n; constructor + · exact Nat.le_refl n + · constructor + · exact Nat.le_add_right n (card U * 0) + · apply num_of_states_leq_zero_implies_stable + simp only [nonpos_iff_eq_zero] at hn; exact hn + · intros n hn + by_cases H : (s.seqStates wθ useq n).isStable wθ + · simp only [ge_iff_le]; use n; constructor + · exact Nat.le_refl n + · constructor + · exact Nat.le_add_right n (card U * (c + 1)) + · assumption + · obtain ⟨m, ⟨hm, hlt⟩⟩ := not_stable_implies_sseqm_lt_sseqn_cyclic s useq hcy n H + have : num_of_states_less (seqStates' s useq m) ≤ c := by + apply Nat.le_of_lt_succ; rw [← Nat.succ_eq_add_one] at hn + calc _ < num_of_states_less (seqStates' s useq n) := + num_of_states_decreases _ _ hlt.2 + _ ≤ c.succ := hn + obtain ⟨m', ⟨hm', hstable⟩⟩ := hc m this + use m'; constructor + · trans; assumption; assumption + · constructor + · obtain ⟨hlt', _⟩ := hlt + calc _ ≤ m + card U * c := hstable.1 + _ ≤ n + card U + card U * c := + Nat.add_le_add_right hlt' (card U * c) + _ ≤ n + card U * (c + 1) := by + rw [add_assoc, add_le_add_iff_left, + mul_add, mul_one, le_iff_lt_or_eq] + right + exact Nat.add_comm (card U) (card U * c) + · exact hstable.2 + +@[simp] +lemma num_of_states_card : card (HopfieldNetwork R U).State = 2 ^ card U := by + rw [Fintype.card_congr (stateToNeurActMap_equiv')] + have h3 : #({1,-1} : Finset R) = 2 := by + refine Finset.card_pair ?h + norm_cast + rw [Fintype.card_fun] + simp only [mem_insert, mem_singleton, Fintype.card_coe] + exact congrFun (congrArg HPow.hPow h3) (card U) + +@[simp] +lemma NeuralNetwork.initial_state_bound (useq : ℕ → U) : + num_of_states_less (seqStates' s useq 0) ≤ 2 ^ card U := by + rw [num_of_states_less, Fintype.card_of_subtype] + rw [← @num_of_states_card R _ _ _] + exact card_le_univ (states_less s); intros x; rfl + +@[simp] +theorem HopfieldNet_convergence_cyclic : ∀ (useq : ℕ → U), 
cyclic useq → + ∃ N, N ≤ card U * (2 ^ card U) ∧ + (s.seqStates wθ useq N).isStable wθ := fun useq hcy => by + obtain ⟨N, ⟨_, ⟨hN1, hN2⟩⟩⟩ := num_of_states_leq_c_implies_stable_sseq_cyclic s + useq hcy (2 ^ card U) 0 (initial_state_bound s useq) + use N; constructor; simp only [zero_add] at hN1; assumption; assumption + +/-- +`HopfieldNet_stabilize_cyclic` stabilizes a Hopfield network given an initial state `s`, +a sequence of updates `useq`, and a proof `hf` that the sequence is cyclic. +It returns the state of the network after convergence. +-/ +def HopfieldNet_stabilize_cyclic (s : State' wθ) (useq : ℕ → U) (hf : cyclic useq) : State' wθ := + (seqStates' s useq) (Nat.find (HopfieldNet_convergence_cyclic s useq hf)) + +/-- +`HopfieldNet_conv_time_steps` calculates the number of time steps required for a Hopfield Network to converge. +-/ +def HopfieldNet_conv_time_steps (wθ : Params (HopfieldNetwork R U)) (s : State' wθ) + (useq : ℕ → U) (hf : cyclic useq) : ℕ := + (Nat.find (HopfieldNet_convergence_cyclic s useq hf)) + +lemma HopfieldNet_cyclic_converg (wθ : Params (HopfieldNetwork R U)) (s : State' wθ) + (useq : ℕ → U) (hf : cyclic useq) : + (HopfieldNet_stabilize_cyclic s useq hf).isStable wθ := + (Nat.find_spec (HopfieldNet_convergence_cyclic s useq hf)).2 + +lemma patterns_pairwise_orthogonal (ps : Fin m → (HopfieldNetwork R U).State) + (horth : ∀ {i j : Fin m} (_ : i ≠ j), dotProduct (ps i).act (ps j).act = 0) : + ∀ (j : Fin m), ((Hebbian ps).w).mulVec (ps j).act = (card U - m) * (ps j).act := by + intros k + ext t + unfold Hebbian + simp only [sub_apply, smul_apply, smul_eq_mul] + rw [mulVec, dotProduct] + simp only [sub_apply, smul_apply, smul_eq_mul, Pi.natCast_def, Pi.mul_apply, Pi.sub_apply] + rw [Finset.sum_apply] + simp only [Finset.sum_apply] + unfold dotProduct at horth + have : ∀ i j, (dotProduct (ps i).act (ps j).act) = if i ≠ j then 0 else card U := by + intros i j + by_cases h : i ≠ j + · specialize horth h + simp_all only [ne_eq, 
not_false_eq_true, reduceIte, Nat.cast_zero] + assumption + · simp only [Decidable.not_not] at h + nth_rw 1 [h] + simp only [ite_not, Nat.cast_ite, Nat.cast_zero] + refine eq_ite_iff.mpr ?_ + left + constructor + · assumption + · unfold dotProduct + have hact : ∀ i, ((ps j).act i) = 1 ∨ ((ps j).act i) = -1 := fun i => act_one_or_neg_one i + have hact1 : ∀ i, ((ps j).act i) * ((ps j).act i) = 1 := fun i => mul_self_eq_one_iff.mpr (hact i) + calc _ = ∑ i, (ps j).act i * (ps j).act i := rfl + _ = ∑ i, 1 * 1 := by simp_rw [hact1]; rw [mul_one] + _ = card U := by simp only [sum_const, card_univ, Fintype.card_fin, nsmul_eq_mul, + mul_one] + simp only [dotProduct, ite_not, Nat.cast_ite, Nat.cast_zero] at this + conv => enter [1,2]; ext l; rw [sub_mul]; rw [sum_mul]; conv => enter [1,2]; ext i; rw [mul_assoc] + rw [Finset.sum_sub_distrib] + nth_rw 1 [sum_comm] + calc _= ∑ y : Fin m, (ps y).act t * ∑ x , ((ps y).act x * (ps k).act x) + - ∑ x , ↑m * (1 : Matrix U U R) t x * (ps k).act x := ?_ + _= ∑ y : Fin m, (ps y).act t * (if y ≠ k then 0 else card U) - + ∑ x , ↑m * (1 : Matrix U U R) t x * (ps k).act x := ?_ + _ = (card U - ↑m) * (ps k).act t := ?_ + · simp only [sub_left_inj]; rw [Finset.sum_congr rfl] + exact fun x _ => (mul_sum univ (fun i => (ps x).act i * (ps k).act i) ((ps x).act t)).symm + · simp only [sub_left_inj]; rw [Finset.sum_congr rfl]; intros i _ + simp_all only [reduceIte, implies_true, mem_univ, mul_ite, mul_zero, ite_not, Nat.cast_ite, Nat.cast_zero] + · simp only [ite_not, Nat.cast_ite, Nat.cast_zero, mul_ite, mul_zero, Finset.sum_ite_eq', mem_univ, reduceIte] + conv => enter [1,2,2]; ext k; rw [mul_assoc] + rw [← mul_sum, mul_comm] + simp only [one_apply, ite_mul, one_mul, zero_mul, Finset.sum_ite_eq, mem_univ, reduceIte] + exact (sub_mul (card U : R) m ((ps k).act t)).symm + +lemma stateisStablecondition (ps : Fin m → (HopfieldNetwork R U).State) + (s : (HopfieldNetwork R U).State) c (hc : 0 < c) + (hw : ∀ u, ((Hebbian ps).w).mulVec s.act u = c * 
s.act u) : s.isStable (Hebbian ps) := by + intros u + unfold Up out + simp only [reduceIte, Fin.isValue] + rw [HNfnet_eq] + simp_rw [mulVec, dotProduct] at hw u + refine ite_eq_iff.mpr ?_ + cases' s.act_one_or_neg_one u with h1 h2 + · left; rw [h1]; constructor + · rw [hw, le_iff_lt_or_eq]; left; rwa [h1, mul_one] + · rfl + · right; rw [h2]; constructor + · change ¬ 0 ≤ _ + rw [le_iff_lt_or_eq] + simp only [Left.neg_pos_iff, zero_eq_neg, not_or, not_lt] + constructor + · rw [le_iff_lt_or_eq]; left; + simpa only [hw, h2, mul_neg, mul_one, Left.neg_neg_iff] + · simp_all only [List.length_nil, Nat.succ_eq_add_one, + Nat.reduceAdd, mul_neg, mul_one, Fin.isValue, zero_eq_neg] + exact ne_of_gt hc + · rfl + exact (Hebbian ps).hw u u fun a => a rfl + +lemma Hebbian_stable (hm : m < card U) (ps : Fin m → (HopfieldNetwork R U).State) (j : Fin m) + (horth : ∀ {i j : Fin m} (_ : i ≠ j), dotProduct (ps i).act (ps j).act = 0): + isStable (Hebbian ps) (ps j) := by + unfold isStable + have := patterns_pairwise_orthogonal ps horth j + have hmn0 : 0 < (card U - m : R) := by + simpa only [sub_pos, Nat.cast_lt] + apply stateisStablecondition ps (ps j) (card U - m) hmn0 + · intros u; rw [funext_iff] at this; exact this u diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Markov.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Markov.lean new file mode 100644 index 000000000..33cf96039 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Markov.lean @@ -0,0 +1,316 @@ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic +import Mathlib.MeasureTheory.Measure.WithDensity +import Mathlib.Probability.Kernel.Invariance +import Mathlib.Probability.Kernel.Basic +import Mathlib.Probability.Kernel.Composition.MeasureComp +import Mathlib.Analysis.BoundedVariation + +open ProbabilityTheory.Kernel + +namespace ProbabilityTheory.Kernel + +/-- `Kernel.pow κ n` is the `n`-fold composition of the kernel `κ`, with `pow κ 
0 = id`. -/ +noncomputable def pow {α : Type*} [MeasurableSpace α] (κ : Kernel α α) : ℕ → Kernel α α +| 0 => Kernel.id +| n + 1 => κ ∘ₖ (pow κ n) + +end ProbabilityTheory.Kernel + +/-! +# Markov Chain Framework + +## Main definitions + +* `stochasticHopfieldMarkovProcess`: A Markov process on Hopfield network states +* `gibbsTransitionKernel`: The transition kernel for Gibbs sampling +* `DetailedBalance`: The detailed balance condition for reversible Markov chains +* `mixingTime`: The time needed to approach the stationary distribution (TODO) + +-/ + +open MeasureTheory ProbabilityTheory ENNReal Finset Function ProbabilityTheory.Kernel Set + +namespace MarkovChain + +-- Using the discrete sigma-algebra implicitly for the finite state space +instance (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] : + MeasurableSpace ((HopfieldNetwork R U).State) := ⊤ + +-- Prove all sets are measurable in the discrete sigma-algebra +lemma measurableSet_discrete {α : Type*} [MeasurableSpace α] (h : ‹_› = ⊤) + (s : Set α) : MeasurableSet s := by + rw [h] + trivial + +instance (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] : + DiscreteMeasurableSpace ((HopfieldNetwork R U).State) where + forall_measurableSet := fun s => measurableSet_discrete rfl s + +/-! +### Core Markov Chain Definitions +-/ + +/-- +A `StationaryDistribution` for a transition kernel is a measure that remains +invariant under the action of the kernel. +-/ +structure StationaryDistribution {α : Type*} [MeasurableSpace α] (K : Kernel α α) where + /-- The probability measure that is stationary with respect to the kernel K. -/ + measure : Measure α + /-- Proof that the measure is a probability measure (sums to 1). -/ + isProbability : IsProbabilityMeasure measure + /-- Proof that the measure is invariant under the kernel K. 
-/ + isStationary : ∀ s, MeasurableSet s → (Measure.bind measure K) s = measure s + +/-- +The detailed balance condition for a Markov kernel with respect to a measure. +`μ(dx) K(x,dy) = μ(dy) K(y,dx)` for all measurable sets +-/ +def DetailedBalance {α : Type*} [MeasurableSpace α] (μ : Measure α) (K : Kernel α α) : Prop := + ∀ A B : Set α, MeasurableSet A → MeasurableSet B → + ∫⁻ x in A, (K x B) ∂μ = ∫⁻ y in B, (K y A) ∂μ + +/-- When detailed balance holds, the measure is stationary -/ +def stationaryOfDetailedBalance {α : Type*} [MeasurableSpace α] {μ : Measure α} + [IsProbabilityMeasure μ] {K : Kernel α α} [IsMarkovKernel K] + (h : DetailedBalance μ K) : StationaryDistribution K where + measure := μ + isProbability := inferInstance + isStationary := by + intro s hs + have bind_def : (μ.bind K) s = ∫⁻ x, (K x s) ∂μ := by + apply Measure.bind_apply hs (Kernel.aemeasurable K) + have h_balance := h Set.univ s MeasurableSet.univ hs + rw [bind_def] + have h_univ : ∫⁻ x, K x s ∂μ = ∫⁻ x in Set.univ, K x s ∂μ := by + simp only [Measure.restrict_univ] + rw [h_univ, h_balance] + have univ_one : ∀ y, K y Set.univ = 1 := by + intro y + exact measure_univ + have h_one : ∫⁻ y in s, K y Set.univ ∂μ = ∫⁻ y in s, 1 ∂μ := by + apply lintegral_congr_ae + exact ae_of_all (μ.restrict s) univ_one + rw [h_one, MeasureTheory.lintegral_const, Measure.restrict_apply MeasurableSet.univ, + Set.univ_inter, one_mul] + +/-! 
+### Markov Chain on Hopfield Networks +-/ + +section HopfieldMarkovChain + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] + [Fintype U] [Nonempty U] [Coe R ℝ] + +instance : Nonempty ((HopfieldNetwork R U).State) := by + let defaultState : (HopfieldNetwork R U).State := { + act := fun _ => -1, + hp := fun _ => Or.inr rfl + } + exact ⟨defaultState⟩ + +-- Fintype instance for the state space +noncomputable instance : Fintype ((HopfieldNetwork R U).State) := by + let f : ((HopfieldNetwork R U).State) → (U → {r : R | r = 1 ∨ r = -1}) := + fun s u => ⟨s.act u, s.hp u⟩ + have h_inj : Function.Injective f := by + intro s1 s2 h + cases s1 with | mk act1 hp1 => + cases s2 with | mk act2 hp2 => + simp at * + ext u + have h_u := congr_fun h u + simp [f] at h_u + exact h_u + have h_surj : Function.Surjective f := by + intro g + let act := fun u => (g u).val + have hp : ∀ u, act u = 1 ∨ act u = -1 := fun u => (g u).property + exists ⟨act, hp⟩ + exact _root_.instFintypeStateHopfieldNetwork + +noncomputable def gibbsTransitionKernel (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : + Kernel ((HopfieldNetwork R U).State) ((HopfieldNetwork R U).State) where + toFun := fun state => (NN.State.gibbsSamplingStep wθ state T).toMeasure + measurable' := Measurable.of_discrete + +-- Mark the kernel as a Markov kernel (preserves probability) +instance gibbsIsMarkovKernel (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : + IsMarkovKernel (gibbsTransitionKernel wθ T) where + isProbabilityMeasure := by + intro s + simp [gibbsTransitionKernel] + exact PMF.toMeasure.isProbabilityMeasure (NN.State.gibbsSamplingStep wθ s T) + +/-- +The stochastic Hopfield Markov process, which models the evolution of Hopfield network states +over discrete time steps using Gibbs sampling at fixed temperature. +In this simplified model, the transition kernel is time-homogeneous (same for all steps). 
+-/
+noncomputable def stochasticHopfieldMarkovProcess (wθ : Params (HopfieldNetwork R U)) (T : ℝ) :
+    ℕ → Kernel ((HopfieldNetwork R U).State) ((HopfieldNetwork R U).State) :=
+  fun _ => gibbsTransitionKernel wθ T
+
+/--
+The n-step transition probability, which gives the probability of moving from
+state x to state y in exactly n steps.
+-/
+noncomputable def nStepTransition (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (n : ℕ) :
+    ((HopfieldNetwork R U).State) → ((HopfieldNetwork R U).State) → ENNReal :=
+  fun x y => (Kernel.pow (gibbsTransitionKernel wθ T) n x) {y} -- mass at {y} under the n-fold kernel
+
+/--
+The total variation distance between two probability measures on Hopfield network states.
+Defined as supremum of |μ(A) - ν(A)| over all measurable sets A.
+-/
+noncomputable def totalVariation (μ ν : Measure ((HopfieldNetwork R U).State)) : ENNReal :=
+  ⨆ (A : Set ((HopfieldNetwork R U).State)) (_ : MeasurableSet A),
+    ENNReal.ofReal (abs ((μ A).toReal - (ν A).toReal))
+
+/--
+A sufficient condition for aperiodicity: the state returns to itself in one step with positive probability.
+-/
+def IsAperiodic (wθ : Params (HopfieldNetwork R U)) (T : ℝ)
+    (s : (HopfieldNetwork R U).State) : Prop := (gibbsTransitionKernel wθ T s) {s} > 0
+
+/--
+A Markov chain is irreducible if it's possible to get from any state to any other state
+with positive probability in some finite number of steps.
+-/ +def IsIrreducible (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : Prop := + ∀ x y, ∃ n, (Kernel.pow (gibbsTransitionKernel wθ T) n x) {y} > 0 -- Use Kernel.pow correctly + +/-- The unnormalized Boltzmann density function -/ +noncomputable def boltzmannDensityFn (wθ : Params (HopfieldNetwork R U)) (T : ℝ) + (s : (HopfieldNetwork R U).State) : ENNReal := + ENNReal.ofReal (Real.exp (-(Coe.coe (NeuralNetwork.State.E wθ s) : ℝ) / T)) + +/-- The Boltzmann partition function (normalizing constant) -/ +noncomputable def boltzmannPartitionFn (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : ENNReal := + ∑ s ∈ Finset.univ, boltzmannDensityFn wθ T s + +/-- Helper lemma: A finite sum of ENNReal values is positive if the set is + nonempty and all terms are positive -/ +lemma ENNReal.sum_pos {α : Type*} (s : Finset α) (f : α → ENNReal) + (h_nonempty : s.Nonempty) (h_pos : ∀ i ∈ s, 0 < f i) : 0 < ∑ i ∈ s, f i := by + rcases h_nonempty with ⟨i, hi⟩ + have h_pos_i : 0 < f i := h_pos i hi + have h_le : f i ≤ ∑ j ∈ s, f j := Finset.single_le_sum (fun j _ => zero_le (f j)) hi + exact lt_of_lt_of_le h_pos_i h_le + +/-- The Boltzmann partition function is positive and finite -/ +lemma boltzmannPartitionFn_pos_finite (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (_hT : T ≠ 0) : + 0 < boltzmannPartitionFn wθ T ∧ boltzmannPartitionFn wθ T < ⊤ := by + simp only [boltzmannPartitionFn] + have h_pos_term : ∀ s : (HopfieldNetwork R U).State, 0 < boltzmannDensityFn wθ T s := by + intro s + simp only [boltzmannDensityFn] + exact ENNReal.ofReal_pos.mpr (Real.exp_pos _) + have h_finite_term : ∀ s : (HopfieldNetwork R U).State, boltzmannDensityFn wθ T s ≠ ⊤ := by + intro s + simp only [boltzmannDensityFn] + exact ENNReal.ofReal_ne_top + constructor + · -- Proves positivity: sum of positive terms is positive + apply ENNReal.sum_pos + · exact Finset.univ_nonempty + · intro s _hs_in_univ + exact h_pos_term s + · -- Proves finiteness: sum is finite if all terms are finite + rw [sum_lt_top] + intro s 
_hs_in_univ + rw [lt_top_iff_ne_top] + exact h_finite_term s +/-- +The Boltzmann distribution over Hopfield network states at temperature T. +-/ +noncomputable def boltzmannDistribution (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : + Measure ((HopfieldNetwork R U).State) := + let densityFn := boltzmannDensityFn wθ T + let partitionFn := boltzmannPartitionFn wθ T + let _h_part_pos_finite := boltzmannPartitionFn_pos_finite wθ T hT + let countMeasure : Measure ((HopfieldNetwork R U).State) := MeasureTheory.Measure.count + if h_part : partitionFn = 0 ∨ partitionFn = ⊤ then + 0 + else + let partitionFn_ne_zero : partitionFn ≠ 0 := by + intro h_zero + exact h_part (Or.inl h_zero) + let partitionFn_ne_top : partitionFn ≠ ⊤ := by + intro h_top + exact h_part (Or.inr h_top) + Measure.withDensity countMeasure (fun s => densityFn s / partitionFn) + +-- Helper lemma to handle the 'if' in boltzmannDistribution definition +lemma boltzmannDistribution_def_of_pos_finite (wθ : Params (HopfieldNetwork R U)) + (T : ℝ) (hT : T ≠ 0) : + boltzmannDistribution wθ T hT = + let densityFn := boltzmannDensityFn wθ T + let partitionFn := boltzmannPartitionFn wθ T + let countMeasure : Measure ((HopfieldNetwork R U).State) := MeasureTheory.Measure.count + Measure.withDensity countMeasure (fun s => densityFn s / partitionFn) := by + let h_part := boltzmannPartitionFn_pos_finite wθ T hT + simp [boltzmannDistribution, h_part.1.ne', h_part.2.ne] + -- Use the fact that partitionFn is > 0 and < ⊤ + +/-- The Boltzmann distribution measure of the universe equals the integral of density/partition -/ +lemma boltzmannDistribution_measure_univ (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : + boltzmannDistribution wθ T hT Set.univ = + ∫⁻ s in Set.univ, (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) ∂Measure.count := by + rw [boltzmannDistribution_def_of_pos_finite wθ T hT] + simp only [withDensity_apply _ MeasurableSet.univ] + +/-- The integral over the universe equals 
the sum over all states -/ +lemma boltzmannDistribution_integral_eq_sum (wθ : Params (HopfieldNetwork R U)) + (T : ℝ) (_hT : T ≠ 0) : + ∫⁻ s in Set.univ, (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) ∂Measure.count = + ∑ s ∈ Finset.univ, (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) := by + rw [Measure.restrict_univ] + trans ∑' (s : (HopfieldNetwork R U).State), + (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) + · exact MeasureTheory.lintegral_count + (fun s => (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T)) + · exact tsum_fintype fun b ↦ boltzmannDensityFn wθ T b / boltzmannPartitionFn wθ T + +/-- Division can be distributed over the sum in the Boltzmann distribution -/ +lemma boltzmannDistribution_div_sum (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : + ∑ s ∈ Finset.univ, (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) = + (∑ s ∈ Finset.univ, boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) := by + let Z := boltzmannPartitionFn wθ T + let h_part := boltzmannPartitionFn_pos_finite wθ T hT + have h_Z_pos : Z > 0 := h_part.1 + have h_Z_lt_top : Z < ⊤ := h_part.2 + have h_div_def : ∀ (a b : ENNReal), a / b = a * b⁻¹ := fun a b => by + rw [ENNReal.div_eq_inv_mul] + rw [mul_comm b⁻¹ a] + simp only [h_div_def] + rw [Finset.sum_mul] + + +/-- The sum of Boltzmann probabilities equals 1 -/ +lemma boltzmannDistribution_sum_one (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : + (∑ s ∈ Finset.univ, boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) = 1 := by + simp only [boltzmannPartitionFn] + let h_part := boltzmannPartitionFn_pos_finite wθ T hT + exact ENNReal.div_self h_part.1.ne' h_part.2.ne + +/-- +Proves that the Boltzmann distribution for a Hopfield network forms a valid probability measure. 
+-/ +theorem boltzmannDistribution_isProbability {R U : Type} + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] + [Fintype U] [Nonempty U] [Coe R ℝ] + (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : + IsProbabilityMeasure (boltzmannDistribution wθ T hT) := by + constructor + rw [boltzmannDistribution_measure_univ wθ T hT] + rw [boltzmannDistribution_integral_eq_sum wθ T hT] + rw [boltzmannDistribution_div_sum wθ T hT] + exact boltzmannDistribution_sum_one wθ T hT + +end HopfieldMarkovChain + +end MarkovChain diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NNStochastic.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NNStochastic.lean new file mode 100644 index 000000000..bcff7235c --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NNStochastic.lean @@ -0,0 +1,16 @@ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork +import Mathlib.Probability.ProbabilityMassFunction.Constructions + +/-- Probability Mass Function over Neural Network States -/ +def NeuralNetwork.StatePMF {R U : Type} [Zero R] + (NN : NeuralNetwork R U) := PMF (NN.State) + +/-- Temperature-parameterized stochastic dynamics for neural networks -/ +def NeuralNetwork.StochasticDynamics {R U : Type} [Zero R] + (NN : NeuralNetwork R U) := + ∀ (_ : ℝ), NN.State → NeuralNetwork.StatePMF NN + +/-- Metropolis acceptance decision as a probability mass function over Boolean outcomes -/ +def NN.State.metropolisDecision (p : ℝ) : PMF Bool := + PMF.bernoulli (ENNReal.ofReal (min p 1)) + (mod_cast min_le_right p 1) diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean new file mode 100644 index 000000000..05ae88be9 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean @@ -0,0 +1,115 @@ +/- +Copyright (c) 2024 Michail Karatarakis. 
All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Michail Karatarakis +-/ +import Mathlib.Combinatorics.Digraph.Basic +import Mathlib.Data.Matrix.Basic +import Mathlib.Data.Vector.Basic + +open Mathlib Finset + +/- +A `NeuralNetwork` models a neural network with: + +- `R`: Type for weights and activations. +- `U`: Type for neurons. +- `[Zero R]`: `R` has a zero element. + +It extends `Digraph U` and includes the network's architecture, activation functions, and constraints. +-/ +structure NeuralNetwork (R U : Type) [Zero R] extends Digraph U where + /-- Input neurons. -/ + (Ui Uo Uh : Set U) + /-- There is at least one input neuron. -/ + (hUi : Ui ≠ ∅) + /-- There is at least one output neuron. -/ + (hUo : Uo ≠ ∅) + /-- All neurons are either input, output, or hidden. -/ + (hU : Set.univ = (Ui ∪ Uo ∪ Uh)) + /-- Hidden neurons are not input or output neurons. -/ + (hhio : Uh ∩ (Ui ∪ Uo) = ∅) + /-- Dimensions of input vectors for each neuron. -/ + (κ1 κ2 : U → ℕ) + /-- Computes the net input to a neuron. -/ + (fnet : ∀ u : U, (U → R) → (U → R) → Vector R (κ1 u) → R) + /-- Computes the activation of a neuron. -/ + (fact : ∀ u : U, R → R → Vector R (κ2 u) → R) -- R_current_activation, R_net_input, params + /-- Computes the final output of a neuron. -/ + (fout : ∀ _ : U, R → R) + /-- Predicate on activations. -/ + (pact : R → Prop) + /-- Predicate on weight matrices. -/ + (pw : Matrix U U R → Prop) + /-- If all activations satisfy `pact`, then the activations computed by `fact` also satisfy `pact`. 
-/ + (hpact : ∀ (w : Matrix U U R) (_ : ∀ u v, ¬ Adj u v → w u v = 0) (_ : pw w) + (σ : (u : U) → Vector R (κ1 u)) (θ : (u : U) → Vector R (κ2 u)) (current_neuron_activations : U → R), + (∀ u_idx : U, pact (current_neuron_activations u_idx)) → -- Precondition on all current activations + (∀ u_target : U, pact (fact u_target (current_neuron_activations u_target) -- Pass current_act of target neuron + (fnet u_target (w u_target) (fun v => fout v (current_neuron_activations v)) (σ u_target)) + (θ u_target)))) + + +variable {R U : Type} [Zero R] + +/-- `Params` is a structure that holds the parameters for a neural network `NN`. -/ +structure Params (NN : NeuralNetwork R U) where + (w : Matrix U U R) + (hw : ∀ u v, ¬ NN.Adj u v → w u v = 0) + (hw' : NN.pw w) + (σ : ∀ u : U, Vector R (NN.κ1 u)) + (θ : ∀ u : U, Vector R (NN.κ2 u)) + +namespace NeuralNetwork + +structure State (NN : NeuralNetwork R U) where + act : U → R + hp : ∀ u : U, NN.pact (act u) + +/-- Extensionality lemma for neural network states -/ +@[ext] +lemma ext {R U : Type} [Zero R] {NN : NeuralNetwork R U} + {s₁ s₂ : NN.State} : (∀ u, s₁.act u = s₂.act u) → s₁ = s₂ := by + intro h + cases s₁ + cases s₂ + simp only [NeuralNetwork.State.mk.injEq] + apply funext + exact h + +namespace State + +variable {NN : NeuralNetwork R U} (wσθ : Params NN) (s : NN.State) + +def out (u : U) : R := NN.fout u (s.act u) +def net (u : U) : R := NN.fnet u (wσθ.w u) (fun v => s.out v) (wσθ.σ u) +def onlyUi : Prop := ∀ u : U, ¬ u ∈ NN.Ui → s.act u = 0 + +variable [DecidableEq U] + +def Up {NN_local : NeuralNetwork R U} (s : NN_local.State) (wσθ : Params NN_local) (u_upd : U) : NN_local.State := + { act := fun v => if v = u_upd then + NN_local.fact u_upd (s.act u_upd) + (NN_local.fnet u_upd (wσθ.w u_upd) (fun n => s.out n) (wσθ.σ u_upd)) + (wσθ.θ u_upd) + else + s.act v, + hp := by + intro v_target + rw [ite_eq_dite] + split_ifs with h_eq_upd_neuron + · exact NN_local.hpact wσθ.w wσθ.hw wσθ.hw' wσθ.σ wσθ.θ s.act s.hp u_upd + · 
exact s.hp v_target + } + +def workPhase (extu : NN.State) (_ : extu.onlyUi) (uOrder : List U) : NN.State := + uOrder.foldl (fun s_iter u_iter => s_iter.Up wσθ u_iter) extu + +def seqStates (useq : ℕ → U) : ℕ → NeuralNetwork.State NN + | 0 => s + | n + 1 => .Up (seqStates useq n) wσθ (useq n) + +def isStable : Prop := ∀ (u : U), (s.Up wσθ u).act u = s.act u + +end State +end NeuralNetwork diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean new file mode 100644 index 000000000..2235c1991 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean @@ -0,0 +1,940 @@ +/- +Copyright (c) 2025 Matteo Cipollina. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Matteo Cipollina +-/ + +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux +import Mathlib.Analysis.RCLike.Basic +import Mathlib.LinearAlgebra.AffineSpace.AffineMap +import Mathlib.LinearAlgebra.Dual.Lemmas + +/- +# Stochastic Hopfield Network Implementation + +This file defines and proves properties related to a stochastic Hopfield network. +It includes definitions for states, neural network parameters, energy computations, +and stochastic updates using both Gibbs sampling and Metropolis-Hastings algorithms. +- Functions (`StatePMF`, `StochasticDynamics`) representing probability measures over states. +- Key stochastic update operations, including a single-neuron Gibbs update + (`gibbsUpdateNeuron`, `gibbsUpdateSingleNeuron`) and full-network sampling steps + (`gibbsSamplingStep`, `gibbsSamplingSteps`) that iterate these updates. +- Definitions (`metropolisDecision`, `metropolisHastingsStep`, `metropolisHastingsSteps`) for + implementing a Metropolis-Hastings update rule in a Hopfield network. 
+- A simulated annealing procedure (`simulatedAnnealing`) that adaptively lowers the temperature + to guide the network into a low-energy configuration. +- Various lemmas (such as `single_site_difference`, `updateNeuron_preserves`, and + `gibbs_probs_sum_one`) ensuring correctness and consistency of the update schemes. +- Utility definitions and proofs, including creation of valid parameters + (`mkArray_creates_valid_hopfield_params`), + verification of adjacency (`all_nodes_adjacent`), total variation distance + (`total_variation_distance`), partition function (`partitionFunction`), and more. +-/ +open Finset Matrix NeuralNetwork State + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] (wθ : Params (HopfieldNetwork R U)) (s : (HopfieldNetwork R U).State) + [Coe R ℝ] (T : ℝ) + +/-- Performs a Gibbs update on a single neuron `u` of the state `s`. + The update probability depends on the energy change associated with flipping the neuron's state, + parameterized by the temperature `T`. 
-/ +noncomputable def NN.State.gibbsUpdateNeuron [Coe R ℝ] (T : ℝ) (u : U) : PMF ((HopfieldNetwork R U).State) := + let h_u := s.net wθ u + let ΔE := 2 * h_u * s.act u + let p_flip := ENNReal.ofReal (Real.exp (-(↑ΔE) / T)) / (1 + ENNReal.ofReal (Real.exp (-(↑ΔE) / T))) + let p_flip_le_one : p_flip ≤ 1 := by + simp only [p_flip] + let a := ENNReal.ofReal (Real.exp (-(↑ΔE) / T)) + have h_a_nonneg : 0 ≤ a := zero_le a + have h_denom_ne_zero : 1 + a ≠ 0 := by + intro h + have h1 : 0 ≤ 1 + a := zero_le (1 + a) + have h2 : 1 + a = 0 := h + simp_all only [zero_le, add_eq_zero, one_ne_zero, ENNReal.ofReal_eq_zero, false_and, a, ΔE, h_u, p_flip] + have h_sum_ne_top : (1 + a) ≠ ⊤ := by + apply ENNReal.add_ne_top.2 + constructor + · exact ENNReal.one_ne_top + · apply ENNReal.ofReal_ne_top + rw [ENNReal.div_le_iff h_denom_ne_zero h_sum_ne_top] + simp only [one_mul, h_u, ΔE, a, p_flip] + exact le_add_self + PMF.bind (PMF.bernoulli p_flip p_flip_le_one) $ λ should_flip => + PMF.pure $ if should_flip then s.Up wθ u else s + +/-- Update a single neuron according to Gibbs sampling rule -/ +noncomputable def NN.State.gibbsUpdateSingleNeuron (u : U) : PMF ((HopfieldNetwork R U).State) := + -- Calculate local field for the neuron + let local_field := s.net wθ u + -- Calculate probabilities based on Boltzmann distribution + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + -- Create PMF with normalized probabilities + let total : ENNReal := probs true + probs false + let norm_probs : Bool → ENNReal := λ b => probs b / total + -- Convert Bool to State + (PMF.map (λ b => if b then + NN.State.updateNeuron s u 1 (mul_self_eq_mul_self_iff.mp rfl) + else + NN.State.updateNeuron s u (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl)) + (PMF.ofFintype norm_probs (by + have h_total : total ≠ 0 := by { + simp [probs] + refine ENNReal.inv_ne_top.mp ?_ + have h_exp_pos := Real.exp_pos (local_field * 1 / T) + 
have h := ENNReal.ofReal_pos.mpr h_exp_pos + simp_all only [mul_one, ENNReal.ofReal_pos, mul_ite, mul_neg, ↓reduceIte, Bool.false_eq_true, ne_eq, + ENNReal.inv_eq_top, add_eq_zero, ENNReal.ofReal_eq_zero, not_and, not_le, isEmpty_Prop, IsEmpty.forall_iff, + local_field, total, probs]} + have h_total_ne_top : total ≠ ⊤ := by {simp [probs, total]} + have h_sum : Finset.sum Finset.univ norm_probs = 1 := by + calc Finset.sum Finset.univ norm_probs + = (probs true)/total + (probs false)/total := Fintype.sum_bool fun b ↦ probs b / total + _ = (probs true + probs false)/total := ENNReal.div_add_div_same + _ = total/total := by rfl + _ = 1 := ENNReal.div_self h_total h_total_ne_top + exact h_sum))) + +@[inherit_doc] +scoped[ENNReal] notation "ℝ≥0∞" => ENNReal + +/-- Given a Hopfield Network's parameters, temperature, and current state, performs a single step +of Gibbs sampling by: +1. Uniformly selecting a random neuron +2. Updating that neuron's state according to the Gibbs distribution +-/ +noncomputable def NN.State.gibbsSamplingStep : PMF ((HopfieldNetwork R U).State) := + -- Uniform random selection of neuron + let neuron_pmf : PMF U := + PMF.ofFintype (λ _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) + (by + rw [Finset.sum_const, Finset.card_univ] + rw [ENNReal.div_eq_inv_mul] + simp only [mul_one] + have h : (Fintype.card U : ENNReal) ≠ 0 := by + simp [Fintype.card_pos_iff.mpr inferInstance] + have h_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.coe_ne_top + rw [← ENNReal.mul_inv_cancel h h_top] + simp_all only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, not_false_eq_true, ENNReal.natCast_ne_top, + nsmul_eq_mul]) + -- Bind neuron selection with conditional update + PMF.bind neuron_pmf $ λ u => NN.State.gibbsUpdateSingleNeuron wθ s T u + +instance : Coe ℝ ℝ := ⟨id⟩ + +/-- Perform a stochastic update on a Pattern representation -/ +noncomputable def patternStochasticUpdate + {n : ℕ} [Nonempty (Fin n)] (weights : Fin n → Fin n → ℝ) (h_diag_zero : ∀ i : Fin n, 
weights i i = 0) + (h_sym : ∀ i j : Fin n, weights i j = weights j i) (T : ℝ) + (pattern : NeuralNetwork.State (HopfieldNetwork ℝ (Fin n))) (i : Fin n) : + PMF (NeuralNetwork.State (HopfieldNetwork ℝ (Fin n))) := + let wθ : Params (HopfieldNetwork ℝ (Fin n)) := { + w := weights, + hw := fun u v h => by + if h_eq : u = v then + rw [h_eq] + exact h_diag_zero v + else + have h_adj : (HopfieldNetwork ℝ (Fin n)).Adj u v := by + simp only [HopfieldNetwork]; simp only [ne_eq] + exact h_eq + contradiction + hw' := by + unfold NeuralNetwork.pw + exact IsSymm.ext_iff.mpr fun i j ↦ h_sym j i + σ := fun u => Vector.mk (Array.replicate + ((HopfieldNetwork ℝ (Fin n)).κ1 u) (0 : ℝ)) rfl, + θ := fun u => Vector.mk (Array.replicate + ((HopfieldNetwork ℝ (Fin n)).κ2 u) (0 : ℝ)) rfl + } + NN.State.gibbsUpdateSingleNeuron wθ pattern T i + +/-- Performs multiple steps of Gibbs sampling in a Hopfield network, starting from + an initial state. Each step involves: + 1. First recursively applying previous steps (if any) + 2. Then performing a single Gibbs sampling step on the resulting state + The temperature parameter T controls the randomness of the updates. -/ +noncomputable def NN.State.gibbsSamplingSteps (steps : ℕ) : PMF ((HopfieldNetwork R U).State) := + match steps with + | 0 => PMF.pure s + | steps+1 => PMF.bind (gibbsSamplingSteps steps) $ λ s' => + NN.State.gibbsSamplingStep wθ s' T + +/-- Temperature schedule for simulated annealing that decreases exponentially with each step. -/ +noncomputable def temperatureSchedule (initial_temp : ℝ) (cooling_rate : ℝ) (step : ℕ) : ℝ := + initial_temp * Real.exp (-cooling_rate * step) + +/-- Recursively applies Gibbs sampling steps with decreasing temperature according to + the cooling schedule, terminating when the step count reaches the target number of steps. 
-/
+noncomputable def applyAnnealingSteps (temp_schedule : ℕ → ℝ) (steps : ℕ)
+    (step : ℕ) (state : (HopfieldNetwork R U).State) : PMF ((HopfieldNetwork R U).State) :=
+  if h : step ≥ steps then
+    -- Annealing budget exhausted: return the current state deterministically.
+    PMF.pure state
+  else
+    -- Perform one Gibbs sampling step at the temperature prescribed for this
+    -- step, then recurse on `step + 1` (the section variable `wθ` is fixed
+    -- throughout the recursion).
+    PMF.bind (NN.State.gibbsSamplingStep wθ state (temp_schedule step))
+      (applyAnnealingSteps temp_schedule steps (step + 1))
+termination_by steps - step
+decreasing_by
+  -- From `h : ¬ step ≥ steps` we get `step < steps`, so the natural-number
+  -- measure `steps - step` strictly decreases at the recursive call.
+  -- `omega` discharges this linear-arithmetic goal directly, replacing the
+  -- previous fragile `rw`/`simp_all only [...]` proof that depended on an
+  -- exact simp-lemma set.
+  omega
+
+/-- `NN.State.simulatedAnnealing` implements the simulated annealing optimization algorithm for a
+Hopfield Network. Starting from `initial_state`, it runs `steps` Gibbs sampling steps while the
+temperature decays exponentially (`temperatureSchedule initial_temp cooling_rate`), allowing the
+system to explore the state space and eventually settle into a low-energy configuration.
+-/
+noncomputable def NN.State.simulatedAnnealing
+    (initial_temp : ℝ) (cooling_rate : ℝ) (steps : ℕ)
+    (initial_state : (HopfieldNetwork R U).State) : PMF ((HopfieldNetwork R U).State) :=
+  -- The cooling schedule is fixed up front; step `k` runs at temperature
+  -- `initial_temp * exp (-cooling_rate * k)`.
+  let temp_schedule := temperatureSchedule initial_temp cooling_rate
+  applyAnnealingSteps wθ temp_schedule steps 0 initial_state
+
+/-- Given a HopfieldNetwork with parameters `wθ` and temperature `T`, computes the acceptance probability
+for transitioning from a `current` state to a `proposed` state according to the Metropolis-Hastings algorithm.
+ +* If the energy difference (ΔE) is negative or zero, returns 1.0 (always accepts the transition) +* If the energy difference is positive, returns exp(-ΔE/T) following the Boltzmann distribution +-/ +noncomputable def NN.State.acceptanceProbability + (current : (HopfieldNetwork R U).State) (proposed : (HopfieldNetwork R U).State) : ℝ := + let energy_diff := proposed.E wθ - current.E wθ + if energy_diff ≤ 0 then + 1.0 -- Always accept if energy decreases + else + Real.exp (-energy_diff / T) -- Accept with probability e^(-ΔE/T) if energy increases + +/-- The partition function for a Hopfield network, defined as the sum over all possible states +of the Boltzmann factor `exp(-E/T)`. +-/ +noncomputable def NN.State.partitionFunction : ℝ := + ∑ s : (HopfieldNetwork R U).State, Real.exp (-s.E wθ / T) + +/-- Metropolis-Hastings single step for Hopfield networks -/ +noncomputable def NN.State.metropolisHastingsStep : PMF ((HopfieldNetwork R U).State) := + -- Uniform random selection of neuron + let neuron_pmf : PMF U := + PMF.ofFintype (λ _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) + (by + rw [Finset.sum_const, Finset.card_univ] + rw [ENNReal.div_eq_inv_mul] + simp only [mul_one] + have h : (Fintype.card U : ENNReal) ≠ 0 := by + simp [Fintype.card_pos_iff.mpr inferInstance] + have h_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.coe_ne_top + rw [← ENNReal.mul_inv_cancel h h_top] + simp_all only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, not_false_eq_true, ENNReal.natCast_ne_top, + nsmul_eq_mul]) + -- Create proposed state by flipping a randomly selected neuron + let propose : U → PMF ((HopfieldNetwork R U).State) := λ u => + let flipped_state := + if s.act u = 1 then -- Assuming 1 and -1 as valid activation values + NN.State.updateNeuron s u (-1) (Or.inr rfl) + else + NN.State.updateNeuron s u 1 (Or.inl rfl) + let p := NN.State.acceptanceProbability wθ T s flipped_state + -- Make acceptance decision + PMF.bind (NN.State.metropolisDecision p) (λ (accept : Bool) 
=> + if accept then PMF.pure flipped_state else PMF.pure s) + -- Combine neuron selection with state proposal + PMF.bind neuron_pmf propose + +/-- Multiple steps of Metropolis-Hastings algorithm for Hopfield networks -/ +noncomputable def NN.State.metropolisHastingsSteps (steps : ℕ) + : PMF ((HopfieldNetwork R U).State) := + match steps with + | 0 => PMF.pure s + | steps+1 => PMF.bind (metropolisHastingsSteps steps) $ λ s' => + NN.State.metropolisHastingsStep wθ s' T + +/-- The Boltzmann (Gibbs) distribution over neural network states -/ +noncomputable def boltzmannDistribution : ((HopfieldNetwork R U).State → ℝ) := + λ s => Real.exp (-s.E wθ / T) / NN.State.partitionFunction wθ T + +/-- The transition probability matrix for Gibbs sampling -/ +noncomputable def gibbsTransitionProb (s s' : (HopfieldNetwork R U).State) : ℝ := + ENNReal.toReal ((NN.State.gibbsSamplingStep wθ s) T s') + +/-- The transition probability matrix for Metropolis-Hastings -/ +noncomputable def metropolisTransitionProb (s s' : (HopfieldNetwork R U).State) : ℝ := + ENNReal.toReal ((NN.State.metropolisHastingsStep wθ s) T s') + +/-- Total variation distance between probability distributions -/ +noncomputable def total_variation_distance + (μ ν : (HopfieldNetwork R U).State → ℝ) : ℝ := + (1/2) * ∑ s : (HopfieldNetwork R U).State, |μ s - ν s| + +/-- For Gibbs updates, given the normalization and probabilities, the sum of normalized probabilities equals 1 -/ +lemma gibbs_probs_sum_one + (v : U) : + let local_field := s.net wθ v + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + let norm_probs := λ b => probs b / total + ∑ b : Bool, norm_probs b = 1 := by + intro local_field probs total norm_probs + have h_sum : ∑ b : Bool, probs b / total = (probs true + probs false) / total := by + simp only [Fintype.sum_bool] + exact ENNReal.div_add_div_same + rw [h_sum] + have 
h_total_eq : probs true + probs false = total := by rfl + rw [h_total_eq] + have h_total_ne_zero : total ≠ 0 := by + simp only [total, probs, ne_eq] + intro h_zero + have h1 : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr + apply Real.exp_pos + have h_sum_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) + + ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := h_zero + exact h1.ne' (add_eq_zero.mp h_sum_zero).1 + have h_total_ne_top : total ≠ ⊤ := by simp [total, probs] + exact ENNReal.div_self h_total_ne_zero h_total_ne_top + +/-- The function that maps boolean values to states in Gibbs sampling -/ +def gibbs_bool_to_state_map + (s : (HopfieldNetwork R U).State) (v : U) : Bool → (HopfieldNetwork R U).State := + λ b => if b then + NN.State.updateNeuron s v 1 (mul_self_eq_mul_self_iff.mp rfl) + else + NN.State.updateNeuron s v (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) + +/-- The total normalization constant for Gibbs sampling is positive -/ +lemma gibbs_total_positive + (local_field : ℝ) (T : ℝ) : + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + probs true + probs false ≠ 0 := by + intro probs + simp only [ne_eq] + intro h_zero + have h1 : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr + apply Real.exp_pos + have h_sum_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) + + ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := h_zero + have h_both_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) = 0 ∧ + ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := + add_eq_zero.mp h_sum_zero + exact h1.ne' h_both_zero.1 + +/-- The total normalization constant for Gibbs sampling is not infinity -/ +lemma gibbs_total_not_top + (local_field : ℝ) (T : ℝ) : + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp 
(local_field * new_act_val / T)) + probs true + probs false ≠ ⊤ := by + intro probs + simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, + ENNReal.ofReal_ne_top, or_self, not_false_eq_true, probs] + +/-- For a positive PMF.map application, there exists a preimage with positive probability -/ +lemma pmf_map_pos_implies_preimage {α β : Type} [Fintype α] [DecidableEq β] + {p : α → ENNReal} (h_pmf : ∑ a, p a = 1) (f : α → β) (y : β) : + (PMF.map f (PMF.ofFintype p h_pmf)) y > 0 → + ∃ x : α, p x > 0 ∧ f x = y := by + intro h_pos + simp only [PMF.map_apply] at h_pos + simp_all only [PMF.ofFintype_apply, tsum_eq_filter_sum, gt_iff_lt, filter_sum_pos_iff_exists_pos, + pmf_map_pos_iff_exists_pos] + +/-- For states with positive Gibbs update probability, there exists a boolean variable that + determines whether the state has activation 1 or -1 at the updated neuron -/ +lemma gibbsUpdate_exists_bool (v : U) (s_next : (HopfieldNetwork R U).State) : + (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next > 0 → + ∃ b : Bool, s_next = gibbs_bool_to_state_map s v b := by + intro h_prob_pos + unfold NN.State.gibbsUpdateSingleNeuron at h_prob_pos + let local_field_R := s.net wθ v + let local_field : ℝ := ↑local_field_R + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + let norm_probs : Bool → ENNReal := λ b => probs b / total + let map_fn : Bool → (HopfieldNetwork R U).State := gibbs_bool_to_state_map s v + have h_sum_eq_1 : ∑ b : Bool, norm_probs b = 1 := by + have h_total_ne_zero : total ≠ 0 := gibbs_total_positive local_field T + have h_total_ne_top : total ≠ ⊤ := gibbs_total_not_top local_field T + calc Finset.sum Finset.univ norm_probs + = (probs true)/total + (probs false)/total := + Fintype.sum_bool fun b ↦ probs b / total + _ = (probs true + probs false)/total := ENNReal.div_add_div_same + _ = 
total/total := by rfl + _ = 1 := ENNReal.div_self h_total_ne_zero h_total_ne_top + let base_pmf := PMF.ofFintype norm_probs h_sum_eq_1 + have ⟨b, _, h_map_eq⟩ := pmf_map_pos_implies_preimage h_sum_eq_1 map_fn s_next h_prob_pos + use b + exact Eq.symm h_map_eq + +/-- For states with positive probability under gibbsUpdateSingleNeuron, + they must be one of exactly two possible states (with neuron v set to 1 or -1) -/ +@[simp] +lemma gibbsUpdate_possible_states (v : U) (s_next : (HopfieldNetwork R U).State) : + (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next > 0 → + s_next = NN.State.updateNeuron s v 1 (mul_self_eq_mul_self_iff.mp rfl) ∨ + s_next = NN.State.updateNeuron s v (-1) + (AffineMap.lineMap_eq_lineMap_iff.mp rfl) := by + intro h_prob_pos + obtain ⟨b, h_eq⟩ := gibbsUpdate_exists_bool wθ s T v s_next h_prob_pos + cases b with + | false => + right + unfold gibbs_bool_to_state_map at h_eq + rw [@Std.Tactic.BVDecide.Normalize.if_eq_cond] at h_eq + exact h_eq + | true => + left + unfold gibbs_bool_to_state_map at h_eq + rw [@Std.Tactic.BVDecide.Normalize.if_eq_cond] at h_eq + exact h_eq + +/-- Gibbs update preserves states at non-updated sites -/ +@[simp] +lemma gibbsUpdate_preserves_other_neurons + (v w : U) (h_neq : w ≠ v) : + ∀ s_next, (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next > 0 → + s_next.act w = s.act w := by + intro s_next h_prob_pos + have h_structure := gibbsUpdate_possible_states wθ s T v s_next h_prob_pos + cases h_structure with + | inl h_pos => + rw [h_pos] + exact updateNeuron_preserves s v w 1 (mul_self_eq_mul_self_iff.mp rfl) h_neq + | inr h_neg => + rw [h_neg] + exact updateNeuron_preserves s v w (-1) + (AffineMap.lineMap_eq_lineMap_iff.mp rfl) h_neq + +/-- The probability mass function for a binary choice (true/false) + has sum 1 when properly normalized -/ +lemma pmf_binary_norm_sum_one (local_field : ℝ) (T : ℝ) : + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field 
* new_act_val / T)) + let total := probs true + probs false + let norm_probs := λ b => probs b / total + ∑ b : Bool, norm_probs b = 1 := by + intro probs total norm_probs + have h_sum : ∑ b : Bool, probs b / total = (probs true + probs false) / total := by + simp only [Fintype.sum_bool] + exact ENNReal.div_add_div_same + rw [h_sum] + have h_total_ne_zero : total ≠ 0 := by + simp only [total, probs, ne_eq] + intro h_zero + have h1 : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr + apply Real.exp_pos + have h_sum_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) + + ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := h_zero + have h_both_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) = 0 ∧ + ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := by + exact add_eq_zero.mp h_sum_zero + exact h1.ne' h_both_zero.1 + have h_total_ne_top : total ≠ ⊤ := by + simp [total, probs] + exact ENNReal.div_self h_total_ne_zero h_total_ne_top + +/-- The normalization factor in Gibbs sampling is the sum of Boltzmann + factors for both possible states -/ +lemma gibbs_normalization_factor + (local_field : ℝ) (T : ℝ) : + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + total = ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal + (Real.exp (-local_field / T)) := by + intro probs total + simp only [probs, total] + simp only [↓reduceIte, mul_one, Bool.false_eq_true, mul_neg, total, probs] + +/-- The probability mass assigned to true when using Gibbs sampling -/ +lemma gibbs_prob_true + (local_field : ℝ) (T : ℝ) : + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + let norm_probs := λ b => probs b / total + norm_probs true = ENNReal.ofReal (Real.exp 
(local_field / T)) / + (ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal + (Real.exp (-local_field / T))) := by + intro probs total norm_probs + simp only [norm_probs, probs] + have h_total : total = ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T)) := by + simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs, norm_probs] + rw [h_total] + congr + simp only [↓reduceIte, mul_one, total, norm_probs, probs] + +/-- The probability mass assigned to false when using Gibbs sampling -/ +lemma gibbs_prob_false + (local_field : ℝ) (T : ℝ) : + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + let norm_probs := λ b => probs b / total + norm_probs false = ENNReal.ofReal (Real.exp (-local_field / T)) / + (ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T))) := by + intro probs total norm_probs + simp only [norm_probs, probs] + have h_total : total = ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T)) := by + simp [total, probs] + rw [h_total] + congr + simp only [Bool.false_eq_true, ↓reduceIte, mul_neg, mul_one, norm_probs, probs, total] + + +/-- Converts the ratio of Boltzmann factors to ENNReal sigmoid form. 
-/ +@[simp] +lemma ENNReal_exp_ratio_to_sigmoid (x : ℝ) : + ENNReal.ofReal (Real.exp x) / + (ENNReal.ofReal (Real.exp x) + ENNReal.ofReal (Real.exp (-x))) = + ENNReal.ofReal (1 / (1 + Real.exp (-2 * x))) := by + have num_pos : 0 ≤ Real.exp x := le_of_lt (Real.exp_pos x) + have denom_pos : 0 < Real.exp x + Real.exp (-x) := by + apply add_pos + · exact Real.exp_pos x + · exact Real.exp_pos (-x) + have h1 : ENNReal.ofReal (Real.exp x) / + (ENNReal.ofReal (Real.exp x) + ENNReal.ofReal (Real.exp (-x))) = + ENNReal.ofReal (Real.exp x / (Real.exp x + Real.exp (-x))) := by + have h_sum : ENNReal.ofReal (Real.exp x) + ENNReal.ofReal (Real.exp (-x)) = + ENNReal.ofReal (Real.exp x + Real.exp (-x)) := by + have exp_neg_pos : 0 ≤ Real.exp (-x) := le_of_lt (Real.exp_pos (-x)) + exact Eq.symm (ENNReal.ofReal_add num_pos exp_neg_pos) + rw [h_sum] + exact Eq.symm (ENNReal.ofReal_div_of_pos denom_pos) + have h2 : Real.exp x / (Real.exp x + Real.exp (-x)) = 1 / (1 + Real.exp (-2 * x)) := by + have h_denom : Real.exp x + Real.exp (-x) = Real.exp x * (1 + Real.exp (-2 * x)) := by + have h_exp_diff : Real.exp (-x) = Real.exp x * Real.exp (-2 * x) := by + rw [← Real.exp_add]; congr; ring + calc Real.exp x + Real.exp (-x) + = Real.exp x + Real.exp x * Real.exp (-2 * x) := by rw [h_exp_diff] + _ = Real.exp x * (1 + Real.exp (-2 * x)) := by rw [mul_add, mul_one] + rw [h_denom, div_mul_eq_div_div] + have h_exp_ne_zero : Real.exp x ≠ 0 := ne_of_gt (Real.exp_pos x) + field_simp + rw [h1, h2] + +@[simp] +lemma ENNReal.div_ne_top' {a b : ENNReal} (ha : a ≠ ⊤) (hb : b ≠ 0) : + a / b ≠ ⊤ := by + intro h_top + rw [ENNReal.div_eq_top] at h_top + rcases h_top with (⟨_, h_right⟩ | ⟨h_left, _⟩); + exact hb h_right; exact ha h_left + +lemma gibbs_prob_positive + (local_field : ℝ) (T : ℝ) : + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + ENNReal.ofReal (Real.exp 
(local_field / T)) / total = + ENNReal.ofReal (1 / (1 + Real.exp (-2 * local_field / T))) := by + intro probs total + have h_total : total = ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T)) := by + simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs] + rw [h_total] + have h_temp : ∀ x, Real.exp (x / T) = Real.exp (x * (1/T)) := by + intro x; congr; field_simp + rw [h_temp local_field, h_temp (-local_field)] + have h_direct : + ENNReal.ofReal (Real.exp (local_field * (1 / T))) / + (ENNReal.ofReal (Real.exp (local_field * (1 / T))) + + ENNReal.ofReal (Real.exp (-local_field * (1 / T)))) = + ENNReal.ofReal (1 / (1 + Real.exp (-2 * local_field / T))) := by + have h := ENNReal_exp_ratio_to_sigmoid (local_field * (1 / T)) + have h_rhs : -2 * (local_field * (1 / T)) = -2 * local_field / T := by + field_simp + rw [h_rhs] at h + have neg_equiv : ENNReal.ofReal (Real.exp (-(local_field * (1 / T)))) = + ENNReal.ofReal (Real.exp (-local_field * (1 / T))) := by + congr; ring + rw [neg_equiv] at h + exact h + exact h_direct + +/-- The probability of setting a neuron to -1 under Gibbs sampling -/ +lemma gibbs_prob_negative + (local_field : ℝ) (T : ℝ) : + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + ENNReal.ofReal (Real.exp (-local_field / T)) / total = + ENNReal.ofReal (1 / (1 + Real.exp (2 * local_field / T))) := by + intro probs total + have h_total : total = ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T)) := by + simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs] + rw [h_total] + have h_neg2_neg : -2 * (-local_field / T) = 2 * local_field / T := by ring + have h_neg_neg : -(-local_field / T) = local_field / T := by ring + have h_ratio_final : ENNReal.ofReal (Real.exp (-local_field / T)) / + 
(ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T))) = + ENNReal.ofReal (1 / (1 + Real.exp (2 * local_field / T))) := by + have h := ENNReal_exp_ratio_to_sigmoid (-local_field / T) + have h_exp_neg_neg : ENNReal.ofReal (Real.exp (-(-local_field / T))) = + ENNReal.ofReal (Real.exp (local_field / T)) := by congr + rw [h_exp_neg_neg] at h + have h_comm : ENNReal.ofReal (Real.exp (-local_field / T)) + + ENNReal.ofReal (Real.exp (local_field / T)) = + ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T)) := by + rw [add_comm] + rw [h_neg2_neg] at h + rw [h_comm] at h + exact h + exact h_ratio_final + +-- Lemma for the probability calculation in the positive case +lemma gibbs_prob_positive_case + (u : U) : + let local_field := s.net wθ u + let Z := ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T)) + let norm_probs := λ b => if b then + ENNReal.ofReal (Real.exp (local_field / T)) / Z + else + ENNReal.ofReal (Real.exp (-local_field / T)) / Z + (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by + have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by + exact Fintype.sum_bool (λ b => norm_probs b) + rw [h_sum] + simp only [norm_probs] + have h_ratio_sum : ENNReal.ofReal (Real.exp (local_field / T)) / Z + + ENNReal.ofReal (Real.exp (-local_field / T)) / Z = + (ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T))) / Z := by + exact ENNReal.div_add_div_same + simp only [Bool.false_eq_true] + have h_if_true : (if True then ENNReal.ofReal (Real.exp (local_field / T)) / Z + else ENNReal.ofReal (Real.exp (-local_field / T)) / Z) = + ENNReal.ofReal (Real.exp (local_field / T)) / Z := by simp + + have h_if_false : (if False then ENNReal.ofReal (Real.exp (local_field / T)) / Z + else ENNReal.ofReal (Real.exp (-local_field / T)) / Z) = + ENNReal.ofReal (Real.exp (-local_field / T)) / 
Z := by simp + rw [h_if_true, h_if_false] + rw [h_ratio_sum] + have h_Z_ne_zero : Z ≠ 0 := by + simp only [ne_eq, add_eq_zero, ENNReal.ofReal_eq_zero, not_and, not_le, Z, norm_probs] + intros + exact Real.exp_pos (-Coe.coe local_field / T) + have h_Z_ne_top : Z ≠ ⊤ := by simp [Z] + exact ENNReal.div_self h_Z_ne_zero h_Z_ne_top + ))) (NN.State.updateNeuron s u 1 (Or.inl rfl)) = norm_probs true := by + intro + apply pmf_map_update_one + +-- Lemma for the probability calculation in the negative case +lemma gibbs_prob_negative_case + (u : U) : + let local_field := s.net wθ u + let Z := ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T)) + let norm_probs := λ b => if b then + ENNReal.ofReal (Real.exp (local_field / T)) / Z + else + ENNReal.ofReal (Real.exp (-local_field / T)) / Z + (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by + have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by + exact Fintype.sum_bool (λ b => norm_probs b) + rw [h_sum] + simp only [norm_probs] + have h_ratio_sum : ENNReal.ofReal (Real.exp (local_field / T)) / Z + + ENNReal.ofReal (Real.exp (-local_field / T)) / Z = + (ENNReal.ofReal (Real.exp (local_field / T)) + + ENNReal.ofReal (Real.exp (-local_field / T))) / Z := by + exact ENNReal.div_add_div_same + simp only [Bool.false_eq_true] + simp only [↓reduceIte, norm_probs] + rw [h_ratio_sum] + have h_Z_ne_zero : Z ≠ 0 := by + simp only [Z, ne_eq, add_eq_zero] + intro h + have h_exp_pos : ENNReal.ofReal (Real.exp (local_field / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr + apply Real.exp_pos + exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h + have h_Z_ne_top : Z ≠ ⊤ := by + simp only [ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, not_false_eq_true, Z, + norm_probs] + exact ENNReal.div_self h_Z_ne_zero h_Z_ne_top))) + (NN.State.updateNeuron s u (-1) (Or.inr rfl)) = norm_probs false := by + intro + apply pmf_map_update_neg_one + +/-- PMF map from 
boolean values to updated states preserves probability structure -/ +lemma gibbsUpdate_pmf_structure + (u : U) : + let local_field := s.net wθ u + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + let norm_probs := λ b => probs b / total + ∀ b : Bool, (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by + have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by + exact Fintype.sum_bool (λ b => norm_probs b) + rw [h_sum] + have h_ratio_sum : probs true / total + probs false / total = + (probs true + probs false) / total := by + exact ENNReal.div_add_div_same + rw [h_ratio_sum] + have h_total_ne_zero : total ≠ 0 := by + simp only [total, probs, ne_eq, add_eq_zero] + intro h + have h_exp_pos : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr + apply Real.exp_pos + exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h + have h_total_ne_top : total ≠ ⊤ := by simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, + Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, + not_false_eq_true, total, probs] + exact ENNReal.div_self h_total_ne_zero h_total_ne_top + ))) (gibbs_bool_to_state_map s u b) = norm_probs b := by + intro local_field probs total norm_probs b_bool + exact pmf_map_binary_state s u b_bool (fun b => norm_probs b) (by + have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by + exact Fintype.sum_bool (λ b => norm_probs b) + rw [h_sum] + have h_ratio_sum : probs true / total + probs false / total = + (probs true + probs false) / total := by + exact ENNReal.div_add_div_same + rw [h_ratio_sum] + have h_total_ne_zero : total ≠ 0 := by + simp only [total, probs, ne_eq, add_eq_zero] + intro h + have h_exp_pos : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr + apply Real.exp_pos 
+ exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h + have h_total_ne_top : total ≠ ⊤ := by simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, + Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, + not_false_eq_true, total, probs] + exact ENNReal.div_self h_total_ne_zero h_total_ne_top) + +/-- The probability of updating a neuron to 1 using Gibbs sampling -/ +lemma gibbsUpdate_prob_positive + (u : U) : + let local_field := s.net wθ u + let Z := ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T)) + (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u 1 (Or.inl rfl)) = + ENNReal.ofReal (Real.exp (local_field / T)) / Z := by + intro local_field Z + unfold NN.State.gibbsUpdateSingleNeuron + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + have h_total_eq_Z : total = Z := by + simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs, Z] + have h_result := pmf_map_update_one s u (fun b => probs b / total) (by + have h_sum : ∑ b : Bool, probs b / total = (probs true + probs false) / total := by + simp only [Fintype.univ_bool, mem_singleton, Bool.true_eq_false, + not_false_eq_true,sum_insert, sum_singleton, total, probs, Z] + exact + ENNReal.div_add_div_same + rw [h_sum] + have h_total_ne_zero : total ≠ 0 := by + simp only [total, probs, ne_eq, add_eq_zero] + intro h + have h_exp_pos : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr + apply Real.exp_pos + exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h + have h_total_ne_top : total ≠ ⊤ := by simp [total, probs] + exact ENNReal.div_self h_total_ne_zero h_total_ne_top) + rw [h_result] + simp only [probs, mul_one_div] + rw [h_total_eq_Z] + simp only [if_true, mul_one] + +/-- The probability of updating a neuron to -1 using Gibbs sampling -/ +lemma 
gibbsUpdate_prob_negative + (u : U) : + let local_field := s.net wθ u + let Z := ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T)) + (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u (-1) (Or.inr rfl)) = + ENNReal.ofReal (Real.exp (-local_field / T)) / Z := by + intro local_field Z + unfold NN.State.gibbsUpdateSingleNeuron + let probs : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) + let total := probs true + probs false + have h_total_eq_Z : total = Z := by + simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs, Z] + have h_result := pmf_map_update_neg_one s u (fun b => probs b / total) (by + have h_sum : ∑ b : Bool, probs b / total = (probs true + probs false) / total := by + simp only [Fintype.univ_bool, mem_singleton, Bool.true_eq_false, + not_false_eq_true, sum_insert, sum_singleton, total, probs, Z] + exact ENNReal.div_add_div_same + rw [h_sum] + have h_total_ne_zero : total ≠ 0 := by + simp only [total, probs, ne_eq, add_eq_zero] + intro h + have h_exp_pos : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr + apply Real.exp_pos + exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h + have h_total_ne_top : total ≠ ⊤ := by + simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, + Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, + not_false_eq_true, total, probs, Z] + exact ENNReal.div_self h_total_ne_zero h_total_ne_top) + rw [h_result] + simp only [probs, one_div_neg_one_eq_neg_one, one_div_neg_one_eq_neg_one] + rw [h_total_eq_Z] + simp only [Bool.false_eq_true, ↓reduceIte, mul_neg, mul_one, probs, Z, total] + +/-- Computes the probability of updating a neuron to a specific value using Gibbs sampling. 
+- If new_val = 1: probability = exp(local_field/T)/Z +- If new_val = -1: probability = exp(-local_field/T)/Z +where Z is the normalization constant (partition function). +-/ +@[simp] +lemma gibbs_update_single_neuron_prob (u : U) (new_val : R) + (hval : (HopfieldNetwork R U).pact new_val) : + let local_field := s.net wθ u + let Z := ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T)) + (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u new_val hval) = + if new_val = 1 then + ENNReal.ofReal (Real.exp (local_field / T)) / Z + else + ENNReal.ofReal (Real.exp (-local_field / T)) / Z := by + intro local_field Z + by_cases h_val : new_val = 1 + · rw [if_pos h_val] + have h_update_equiv := gibbs_bool_to_state_map_positive s u new_val hval h_val + rw [h_update_equiv] + exact gibbsUpdate_prob_positive wθ s T u + · rw [if_neg h_val] + have h_neg_val : new_val = -1 := hopfield_value_dichotomy new_val hval h_val + have h_update_equiv := gibbs_bool_to_state_map_negative s u new_val hval h_neg_val + rw [h_update_equiv] + exact gibbsUpdate_prob_negative wθ s T u + +/-- When states differ at site u, the probability of transitioning to s' by updating + any other site v is zero -/ +lemma gibbs_update_zero_other_sites (s s' : (HopfieldNetwork R U).State) + (u v : U) (h : ∀ w : U, w ≠ u → s.act w = s'.act w) (h_diff : s.act u ≠ s'.act u) : + v ≠ u → (NN.State.gibbsUpdateSingleNeuron wθ s T v) s' = 0 := by + intro hv + have h_act_diff : s'.act u ≠ s.act u := by + exact Ne.symm h_diff + have h_s'_diff_update : ∀ new_val hval, + s' ≠ NN.State.updateNeuron s v new_val hval := by + intro new_val hval + by_contra h_eq + have h_u_eq : s'.act u = (NN.State.updateNeuron s v new_val hval).act u := by + rw [←h_eq] + have h_u_preserved : (NN.State.updateNeuron s v new_val hval).act u = s.act u := by + exact updateNeuron_preserves s v u new_val hval (id (Ne.symm hv)) + rw [h_u_preserved] at h_u_eq + -- Use h to show contradiction + have 
h_s'_neq_s : s' ≠ s := by + by_contra h_s_eq + rw [h_s_eq] at h_diff + exact h_diff rfl + have h_same_elsewhere := h v hv + -- Now we have a contradiction: s' differs from s at u but also equals s.act u there + exact h_act_diff h_u_eq + by_contra h_pmf_nonzero + have h_pos_gt_zero : (NN.State.gibbsUpdateSingleNeuron wθ s T v) s' > 0 := by + exact (PMF.apply_pos_iff (NN.State.gibbsUpdateSingleNeuron wθ s T v) s').mpr h_pmf_nonzero + have h_structure := gibbsUpdate_possible_states wθ s T v s' h_pos_gt_zero + cases h_structure with + | inl h_pos_case => + apply h_s'_diff_update 1 (mul_self_eq_mul_self_iff.mp rfl) + exact h_pos_case + | inr h_neg_case => + apply h_s'_diff_update (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) + exact h_neg_case + +/-- When calculating the transition probability sum, only the term for the + differing site contributes -/ +lemma gibbs_transition_sum_simplification (s s' : (HopfieldNetwork R U).State) + (u : U) (h : ∀ v : U, v ≠ u → s.act v = s'.act v) (h_diff : s.act u ≠ s'.act u) : + let neuron_pmf : PMF U := PMF.ofFintype + (λ _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) + (by + simp only [one_div, sum_const, card_univ, nsmul_eq_mul] + have h_card_ne_zero : (Fintype.card U : ENNReal) ≠ 0 := by + simp only [ne_eq, Nat.cast_eq_zero] + exact Fintype.card_ne_zero + have h_card_ne_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top (Fintype.card U) + rw [← ENNReal.mul_inv_cancel h_card_ne_zero h_card_ne_top]) + let update_prob (v : U) : ENNReal := (NN.State.gibbsUpdateSingleNeuron wθ s T v) s' + ∑ v ∈ Finset.univ, neuron_pmf v * update_prob v = neuron_pmf u * update_prob u := by + intro neuron_pmf update_prob + have h_zero : ∀ v ∈ Finset.univ, v ≠ u → update_prob v = 0 := by + intro v _ hv + exact gibbs_update_zero_other_sites wθ T s s' u v h h_diff hv + apply Finset.sum_eq_single u + · intro v hv hvu + rw [h_zero v hv hvu] + simp only [mul_zero] + · intro hu + exfalso + apply hu + simp only [mem_univ] + +@[simp] +lemma 
gibbs_update_preserves_other_sites (v u : U) (hvu : v ≠ u) : + ∀ s_next, (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next > 0 → s_next.act u = s.act u := by + intro s_next h_pos + have h_supp : s_next ∈ PMF.support (NN.State.gibbsUpdateSingleNeuron wθ s T v) := by + exact (PMF.apply_pos_iff (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next).mp h_pos + have h_structure := gibbsUpdate_possible_states wθ s T v s_next h_pos + cases h_structure with + | inl h_pos => + -- Case s_next = updateNeuron s v 1 + rw [h_pos] + exact updateNeuron_preserves s v u 1 (mul_self_eq_mul_self_iff.mp rfl) (id (Ne.symm hvu)) + | inr h_neg => + -- Case s_next = updateNeuron s v (-1) + rw [h_neg] + exact + updateNeuron_preserves s v u (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) (id (Ne.symm hvu)) + +@[simp] +lemma uniform_neuron_prob {U : Type} [Fintype U] [Nonempty U] (u : U) : + (1 : ENNReal) / (Fintype.card U : ENNReal) = + PMF.ofFintype (λ _ : U => (1 : ENNReal) / (Fintype.card U : ENNReal)) + (by + rw [Finset.sum_const, Finset.card_univ] + simp only [nsmul_eq_mul] + have h_card_ne_zero : (Fintype.card U : ENNReal) ≠ 0 := by + simp only [ne_eq, Nat.cast_eq_zero] + exact Fintype.card_ne_zero + have h_card_ne_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top _ + rw [ENNReal.div_eq_inv_mul] + rw [mul_comm] + rw [← ENNReal.mul_inv_cancel h_card_ne_zero h_card_ne_top] + rw [ENNReal.inv_mul_cancel_left h_card_ne_zero h_card_ne_top] + simp_all only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, + not_false_eq_true, ENNReal.natCast_ne_top] + rw [mul_comm] + ) u := by + simp only [one_div, PMF.ofFintype_apply] diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean new file mode 100644 index 000000000..e570c868c --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean @@ -0,0 +1,471 @@ +import 
PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core +import Mathlib.Analysis.Normed.Ring.Basic +import Mathlib.Data.Complex.Exponential + +/- Several helper lemmas to support proofs of correctness, such as: +Lemmas (`energy_decomposition`, `weight_symmetry`, `energy_sum_split`) connecting the local +parameters (weights, biases) to the global energy function. -/ + +open Finset Matrix NeuralNetwork State + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [Fintype U] [Nonempty U] + +/-- The probability of selecting a specific neuron in the uniform distribution is 1/|U| -/ +lemma uniform_neuron_selection_prob (u : U) : + let p := λ _ => (1 : ENNReal) / (Fintype.card U : ENNReal) + let neuron_pmf := PMF.ofFintype p (by + rw [Finset.sum_const, Finset.card_univ] + rw [ENNReal.div_eq_inv_mul] + simp only [mul_one] + have h_card_ne_zero : (Fintype.card U : ENNReal) ≠ 0 := by + simp only [ne_eq, Nat.cast_eq_zero] + exact Fintype.card_ne_zero + have h_card_ne_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top (Fintype.card U) + rw [← ENNReal.mul_inv_cancel h_card_ne_zero h_card_ne_top] + simp only [nsmul_eq_mul]) + neuron_pmf u = (1 : ENNReal) / (Fintype.card U : ENNReal) := by + intro p neuron_pmf + simp only [PMF.ofFintype_apply, one_div, neuron_pmf, p] + +/-- Uniform neuron selection gives a valid PMF -/ +lemma uniform_neuron_selection_prob_valid : + let p := λ (_ : U) => (1 : ENNReal) / (Fintype.card U : ENNReal) + ∑ a ∈ Finset.univ, p a = 1 := by + intro p + rw [Finset.sum_const, Finset.card_univ] + have h_card_pos : 0 < Fintype.card U := Fintype.card_pos_iff.mpr inferInstance + have h_card_ne_zero : (Fintype.card U : ENNReal) ≠ 0 := by + simp only [ne_eq, Nat.cast_eq_zero] + exact ne_of_gt h_card_pos + have h_card_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top (Fintype.card U) + rw [ENNReal.div_eq_inv_mul] + rw [nsmul_eq_mul] + simp only [mul_one] + rw [ENNReal.mul_inv_cancel h_card_ne_zero h_card_top] + 
+variable [DecidableEq U] (wθ : Params (HopfieldNetwork R U)) + (s : (HopfieldNetwork R U).State) +/-- Decompose energy into weight component and bias component -/ +@[simp] +lemma energy_decomposition : + s.E wθ = s.Ew wθ + s.Eθ wθ := by + rw [NeuralNetwork.State.E] + --rw [← @add_neg_eq_iff_eq_add]; exact add_neg_eq_of_eq_add rfl + +/-- Weight matrix is symmetric in a Hopfield network -/ +lemma weight_symmetry (v1 v2 : U) : + wθ.w v1 v2 = wθ.w v2 v1 := (congrFun (congrFun (id (wθ.hw').symm) v1) v2) + +/-- Energy sum can be split into terms with u and terms without u -/ +lemma energy_sum_split (u : U): + ∑ v : U, ∑ v2 ∈ {v2 | v2 ≠ v}, wθ.w v v2 * s.act v * s.act v2 = + (∑ v2 ∈ {v2 | v2 ≠ u}, wθ.w u v2 * s.act u * s.act v2) + + (∑ v ∈ univ.erase u, ∑ v2 ∈ {v2 | v2 ≠ v}, wθ.w v v2 * s.act v * s.act v2) := by + rw [← sum_erase_add _ _ (mem_univ u)] + simp only [ne_eq, mem_univ, sum_erase_eq_sub, sub_add_cancel, add_sub_cancel] + +/-- When states differ at exactly one site, we can identify that site -/ +@[simp] +lemma single_site_difference (s s' : (HopfieldNetwork R U).State) + (h : s ≠ s' ∧ ∃ u : U, ∀ v : U, v ≠ u → s.act v = s'.act v) : + ∃! 
u : U, s.act u ≠ s'.act u := by + obtain ⟨s_neq, hu_all⟩ := h + obtain ⟨u, hu⟩ := hu_all + have diff_at_u : s.act u ≠ s'.act u := by { + by_contra h_eq + have all_same : ∀ v, s.act v = s'.act v := by { + intro v + by_cases hv : v = u + { rw [hv, h_eq]} + { exact hu v hv }} + have s_eq : s = s' := ext all_same + exact s_neq s_eq} + use u + constructor + { exact diff_at_u } + { intros v h_diff + by_contra h_v + have eq_v : s.act v = s'.act v := by { + by_cases hv : v = u + { rw [hv]; exact hu u fun a ↦ h_diff (hu v h_v) } + { exact hu v hv }} + exact h_diff eq_v } + +/-- States that are equal at all sites are equal -/ +lemma state_equality_from_sites + (s s' : (HopfieldNetwork R U).State) + (h : ∀ u : U, s.act u = s'.act u) : s = s' := by + apply NeuralNetwork.ext + exact h + +/-- Function to set a specific neuron state -/ +def NN.State.updateNeuron (u : U) (val : R) + (hval : (HopfieldNetwork R U).pact val) : (HopfieldNetwork R U).State := +{ act := fun u' => if u' = u then val else s.act u', + hp := by + intro u' + by_cases h : u' = u + · simp only [h, ↓reduceIte] + exact hval + · simp only [h, ↓reduceIte] + exact s.hp u' } + +/-- The updateNeuron function only changes the specified neuron, leaving others unchanged -/ +@[simp] +lemma updateNeuron_preserves + (s : (HopfieldNetwork R U).State) (v w : U) (val : R) (hval : (HopfieldNetwork R U).pact val) : + w ≠ v → (NN.State.updateNeuron s v val hval).act w = s.act w := by + intro h_neq + unfold NN.State.updateNeuron + exact if_neg h_neq + +/-- For states differing at only one site, that site must be u -/ +@[simp] +lemma single_site_difference_unique + (s s' : (HopfieldNetwork R U).State) + (u : U) (h : ∀ v : U, v ≠ u → s.act v = s'.act v) (h_diff : s ≠ s') : + ∃! 
v : U, s.act v ≠ s'.act v := by + use u + constructor + · by_contra h_eq + have all_equal : ∀ v, s.act v = s'.act v := by + intro v + by_cases hv : v = u + · rw [hv] + exact h_eq + · exact h v hv + exact h_diff (ext all_equal) + · intro v hv + by_contra h_neq + have v_diff_u : v ≠ u := by + by_contra h_eq + rw [h_eq] at hv + contradiction + exact hv (h v v_diff_u) + +/-- Given a single-site difference, the destination state is + an update of the source state -/ +lemma single_site_is_update + (s s' : (HopfieldNetwork R U).State) (u : U) + (h : ∀ v : U, v ≠ u → s.act v = s'.act v) : + s' = NN.State.updateNeuron s u (s'.act u) (s'.hp u) := by + apply NeuralNetwork.ext + intro v + by_cases hv : v = u + · rw [hv] + exact Eq.symm (if_pos rfl) + · rw [← h v hv] + exact Eq.symm (if_neg hv) + +/-- When updating a neuron with a value that equals one of the + standard values (1 or -1), the result equals the standard update -/ +@[simp] +lemma update_neuron_equiv + (s : (HopfieldNetwork R U).State) (u : U) (val : R) + (hval : (HopfieldNetwork R U).pact val) : + val = 1 → NN.State.updateNeuron s u val hval = + NN.State.updateNeuron s u 1 (Or.inl rfl) := by + intro h_val + apply NeuralNetwork.ext + intro v + unfold NN.State.updateNeuron + by_cases h_v : v = u + · exact ite_congr rfl (fun a ↦ h_val) (congrFun rfl) + · exact ite_congr rfl (fun a ↦ h_val) (congrFun rfl) + +/-- Updates with different activation values produce different states -/ +@[simp] +lemma different_activation_different_state + (s : (HopfieldNetwork R U).State) (u : U) : + NN.State.updateNeuron s u 1 (Or.inl rfl) ≠ + NN.State.updateNeuron s u (-1) (Or.inr rfl) := by + intro h_contra + have h_values : + (NN.State.updateNeuron s u 1 (Or.inl rfl)).act u = + (NN.State.updateNeuron s u (-1) (Or.inr rfl)).act u := by + congr + unfold NN.State.updateNeuron at h_values + simp at h_values + have : (1 : R) ≠ -1 := by + simp only [ne_eq] + norm_num + exact this h_values + +/-- Two neuron updates at the same site are equal 
if and only if + their new values are equal -/ +lemma update_neuron_eq_iff + (s : (HopfieldNetwork R U).State) (u : U) (val₁ val₂ : R) + (hval₁ : (HopfieldNetwork R U).pact val₁) (hval₂ : (HopfieldNetwork R U).pact val₂) : + NN.State.updateNeuron s u val₁ hval₁ = NN.State.updateNeuron s u val₂ hval₂ ↔ val₁ = val₂ := by + constructor + · intro h + have h_act : (NN.State.updateNeuron s u val₁ hval₁).act u = (NN.State.updateNeuron s u val₂ hval₂).act u := by + rw [h] + unfold NN.State.updateNeuron at h_act + simp at h_act + exact h_act + · intro h_val + subst h_val + apply NeuralNetwork.ext + intro v + by_cases hv : v = u + · subst hv; unfold NN.State.updateNeuron; exact rfl + · unfold NN.State.updateNeuron; exact rfl + +/-- Determines when a boolean-indexed update equals a specific update -/ +@[simp] +lemma bool_update_eq_iff + (s : (HopfieldNetwork R U).State) (u : U) (b : Bool) (val : R) + (hval : (HopfieldNetwork R U).pact val) : + (if b then NN.State.updateNeuron s u 1 (Or.inl rfl) + else NN.State.updateNeuron s u (-1) (Or.inr rfl)) = + NN.State.updateNeuron s u val hval ↔ + (b = true ∧ val = 1) ∨ (b = false ∧ val = -1) := by + cases b + · simp only [Bool.false_eq_true, ↓reduceIte, update_neuron_eq_iff, + false_and, true_and, false_or] + constructor + · intro h + exact id (Eq.symm h) + · intro h_cases + cases h_cases + trivial + · simp only [↓reduceIte, update_neuron_eq_iff, true_and, Bool.true_eq_false, + false_and, or_false] + constructor + · intro h + exact id (Eq.symm h) + · intro h_cases + cases h_cases + ·exact rfl + +/-- When filtering a PMF with binary support to states matching a given state's update, + the result reduces to a singleton if the update site matches -/ +lemma pmf_filter_update_neuron + (s : (HopfieldNetwork R U).State) (u : U) (val : R) + (hval : (HopfieldNetwork R U).pact val) : + let f : Bool → (HopfieldNetwork R U).State := λ b => + if b then NN.State.updateNeuron s u 1 (Or.inl rfl) + else NN.State.updateNeuron s u (-1) (Or.inr rfl) + 
filter (fun b => f b = NN.State.updateNeuron s u val hval) univ = + if val = 1 then {true} else + if val = -1 then {false} else ∅ := by + intro f + by_cases h1 : val = 1 + · simp only [h1] + ext b + simp only [mem_filter, mem_univ, true_and, mem_singleton] + rw [@bool_update_eq_iff] + simp only [and_true, ↓reduceIte, mem_singleton, or_iff_left_iff_imp, and_imp] + cases b + · simp only [Bool.false_eq_true, imp_false, forall_const] + norm_num + · simp only [Bool.true_eq_false, implies_true] + · by_cases h2 : val = -1 + · simp only [h1, h2] + ext b + simp only [mem_filter, mem_univ, true_and, mem_singleton] + rw [@bool_update_eq_iff] + simp only [and_true, ↓reduceIte] + cases b + · simp only [Bool.false_eq_true, false_and, or_true, true_iff] + norm_num + · simp only [true_and, Bool.true_eq_false, or_false] + norm_num + · simp only [h1, h2] + ext b + simp only [mem_filter, mem_univ, true_and] + rw [@bool_update_eq_iff] + simp only [h1, and_false, h2, or_self, ↓reduceIte, not_mem_empty] + +/-- For a PMF over binary values mapped to states, the probability of a specific state + equals the probability of its corresponding binary value -/ +lemma pmf_map_binary_state + (s : (HopfieldNetwork R U).State) (u : U) (b : Bool) (p : Bool → ENNReal) (h_sum : ∑ b, p b = 1) : + let f : Bool → (HopfieldNetwork R U).State := λ b => + if b then NN.State.updateNeuron s u 1 (Or.inl rfl) + else NN.State.updateNeuron s u (-1) (Or.inr rfl) + PMF.map f (PMF.ofFintype p h_sum) (f b) = p b := by + intro f + simp only [PMF.map_apply] + have h_inj : ∀ b₁ b₂ : Bool, b₁ ≠ b₂ → f b₁ ≠ f b₂ := by + intro b₁ b₂ hneq + unfold f + cases b₁ <;> cases b₂ + · contradiction + · simp only [Bool.false_eq_true, ↓reduceIte, ne_eq] + apply Ne.symm + exact different_activation_different_state s u + · dsimp only [↓dreduceIte, Bool.false_eq_true, ne_eq] + have h_values_diff : (1 : R) ≠ (-1 : R) := by + simp only [ne_eq] + norm_num + exact (update_neuron_eq_iff s u 1 (-1) + (Or.inl rfl) (Or.inr rfl)).not.mpr 
h_values_diff + · contradiction + have h_unique : ∀ b' : Bool, f b' = f b ↔ b' = b := by + intro b' + by_cases h : b' = b + · constructor + · intro _ + exact h + · intro _ + rw [h] + · have : f b' ≠ f b := h_inj b' b h + constructor + · intro h_eq + contradiction + · intro h_eq + contradiction + have h_filter : (∑' (b' : Bool), if f b = f b' then (PMF.ofFintype p h_sum) b' else 0) = + (PMF.ofFintype p h_sum) b := by + rw [tsum_fintype] + have h_iff : ∀ b' : Bool, f b = f b' ↔ b = b' := by + intro b' + constructor + · intro h_eq + by_contra h_neq + have h_different : f b ≠ f b' := by + apply h_inj + exact h_neq + contradiction + · intro h_eq + rw [h_eq] + have h_eq : ∑ b' : Bool, ite (f b = f b') ((PMF.ofFintype p h_sum) b') 0 = + ∑ b' : Bool, ite (b = b') ((PMF.ofFintype p h_sum) b') 0 := by + apply Finset.sum_congr rfl + intro b' _ + have hcond : (f b = f b') ↔ (b = b') := h_iff b' + simp only [hcond] + rw [h_eq] + simp [h_eq, Finset.sum_ite_eq] + rw [@tsum_bool] + simp only [PMF.ofFintype_apply] + cases b + · have h_true_neq : f false ≠ f true := by + apply h_inj + simp only [ne_eq, Bool.false_eq_true, not_false_eq_true] + simp only [h_true_neq, if_true, if_false, add_zero] + · have h_false_neq : f true ≠ f false := by + apply h_inj + simp only [ne_eq, Bool.true_eq_false, not_false_eq_true] + simp only [h_false_neq, if_true, if_false, zero_add] + +/-- A specialized version of the previous lemma for the case where the state is an update with new_val = 1 -/ +lemma pmf_map_update_one + (s : (HopfieldNetwork R U).State) (u : U) (p : Bool → ENNReal) (h_sum : ∑ b, p b = 1) : + let f : Bool → (HopfieldNetwork R U).State := λ b => + if b then NN.State.updateNeuron s u 1 (Or.inl rfl) + else NN.State.updateNeuron s u (-1) (Or.inr rfl) + PMF.map f (PMF.ofFintype p h_sum) (NN.State.updateNeuron s u 1 (Or.inl rfl)) = p true := by + intro f + apply pmf_map_binary_state s u true p h_sum + +/-- A specialized version for the case where the state is an update with new_val = -1 -/ 
+lemma pmf_map_update_neg_one + (s : (HopfieldNetwork R U).State) (u : U) (p : Bool → ENNReal) (h_sum : ∑ b, p b = 1) : + let f : Bool → (HopfieldNetwork R U).State := λ b => + if b then NN.State.updateNeuron s u 1 (Or.inl rfl) + else NN.State.updateNeuron s u (-1) (Or.inr rfl) + PMF.map f (PMF.ofFintype p h_sum) (NN.State.updateNeuron s u (-1) (Or.inr rfl)) = p false := by + intro f + apply pmf_map_binary_state s u false p h_sum + +/-- Expresses a ratio of exponentials in terms of the sigmoid function format. +-/ +@[simp] +lemma exp_ratio_to_sigmoid (x : ℝ) : + Real.exp x / (Real.exp x + Real.exp (-x)) = 1 / (1 + Real.exp (-2 * x)) := by + have h_denom : Real.exp x + Real.exp (-x) = Real.exp x * (1 + Real.exp (-2 * x)) := by + have rhs_expanded : Real.exp x * (1 + Real.exp (-2 * x)) = + Real.exp x + Real.exp x * Real.exp (-2 * x) := by + rw [mul_add, mul_one] + have exp_identity : Real.exp x * Real.exp (-2 * x) = Real.exp (-x) := by + rw [← Real.exp_add] + congr + ring + rw [rhs_expanded, exp_identity] + rw [h_denom, div_mul_eq_div_div] + have h_exp_ne_zero : Real.exp x ≠ 0 := ne_of_gt (Real.exp_pos x) + field_simp + +/-- Local field is the weighted sum of incoming activations -/ +lemma local_field_eq_weighted_sum + (wθ : Params (HopfieldNetwork R U)) (s : (HopfieldNetwork R U).State) (u : U) : + s.net wθ u = ∑ v ∈ univ.erase u, wθ.w u v * s.act v := by + unfold NeuralNetwork.State.net + unfold NeuralNetwork.fnet HopfieldNetwork + simp only [ne_eq] + have sum_filter_eq : ∑ v ∈ filter (fun v => v ≠ u) univ, wθ.w u v * s.act v = + ∑ v ∈ univ.erase u, wθ.w u v * s.act v := by + apply Finset.sum_congr + · ext v + simp only [mem_filter, mem_erase, mem_univ, true_and, and_true] + · intro v _ + simp_all only [mem_erase, ne_eq, mem_univ, and_true] + --rw [@OrderedCommSemiring.mul_comm] + exact sum_filter_eq + +@[simp] +lemma gibbs_bool_to_state_map_positive + (s : (HopfieldNetwork R U).State) (u : U) (val : R) (hval : (HopfieldNetwork R U).pact val) : + val = 1 → 
NN.State.updateNeuron s u val hval = + NN.State.updateNeuron s u 1 (Or.inl rfl) := by + intro h_val + apply NeuralNetwork.ext + intro v + unfold NN.State.updateNeuron + by_cases h_v : v = u + · rw [h_v] + exact ite_congr rfl (fun a ↦ h_val) (congrFun rfl) + · simp only [h_v, if_neg] + exact rfl + +@[simp] +lemma gibbs_bool_to_state_map_negative + (s : (HopfieldNetwork R U).State) (u : U) (val : R) (hval : (HopfieldNetwork R U).pact val) : + val = -1 → NN.State.updateNeuron s u val hval = + NN.State.updateNeuron s u (-1) (Or.inr rfl) := by + intro h_val + apply NeuralNetwork.ext + intro v + unfold NN.State.updateNeuron + by_cases h_v : v = u + · rw [h_v] + dsimp only; exact congrFun (congrArg (ite (u = u)) h_val) (s.act u) + · dsimp only [h_v]; exact congrFun (congrArg (ite (v = u)) h_val) (s.act v) + +/-- When states differ at exactly one site, the later state can be expressed as + an update of the first state at that site -/ +lemma single_site_transition_as_update + (s s' : (HopfieldNetwork R U).State) (u : U) + (h : ∀ v : U, v ≠ u → s.act v = s'.act v) : + s' = NN.State.updateNeuron s u (s'.act u) (s'.hp u) := by + apply NeuralNetwork.ext + intro v + by_cases hv : v = u + · rw [hv] + unfold NN.State.updateNeuron + exact Eq.symm (if_pos rfl) + · unfold NN.State.updateNeuron + rw [← h v hv] + exact Eq.symm (if_neg hv) + +/-- When states differ at exactly one site, the later state can be expressed as + an update of the first state at that site -/ +@[simp] +lemma single_site_difference_as_update (s s' : (HopfieldNetwork R U).State) (u : U) + (h_diff_at_u : s.act u ≠ s'.act u) (h_same_elsewhere : ∀ v : U, v ≠ u → s.act v = s'.act v) : + s' = NN.State.updateNeuron s u (s'.act u) (s'.hp u) := by + apply NeuralNetwork.ext + intro v + by_cases hv : v = u + · rw [hv] + unfold NN.State.updateNeuron + simp only [if_pos rfl] + have _ := h_diff_at_u + exact rfl + · unfold NN.State.updateNeuron + simp only [if_neg hv] + exact Eq.symm (h_same_elsewhere v hv) diff --git 
a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/aux.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/aux.lean new file mode 100644 index 000000000..ed8c21ae6 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/aux.lean @@ -0,0 +1,282 @@ +/- +Copyright (c) 2024 Michail Karatarakis. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Michail Karatarakis, Matteo Cipollina +-/ +import Mathlib.Algebra.EuclideanDomain.Field +import Mathlib.Algebra.Order.Star.Basic +import Mathlib.LinearAlgebra.Matrix.Symmetric +import Mathlib.Probability.ProbabilityMassFunction.Constructions +set_option checkBinderAnnotations false + +variable {U : Type*} [Field R] --[LinearOrder R] [IsStrictOrderedRing R] +open Finset Fintype Matrix + +@[simp] +lemma isSymm_sum (f : Fin m → Matrix U U R) (hi : ∀ i, (f i).IsSymm) : + (∑ x : Fin m, f x).IsSymm := by + rw [Matrix.IsSymm] + simp only [Matrix.transpose_sum] + apply Finset.sum_congr rfl + intro x _ + exact hi x + +--def fair' (sseq : Nat → NN.State) : Prop := ∀ n, ∃ m > n, sseq (m + 1) ≠ sseq m + +/-- +A sequence `useq` is fair if for every element `u` and every index `n`, +there exists an index `m` greater than or equal to `n` such that `useq m = u`. +--/ +def fair (useq : ℕ → U) : Prop := ∀ u n, ∃ m ≥ n, useq m = u + +/-- +`cyclic useq` is a property that holds if `useq` is a sequence such that: +1. Every element of type `U` appears at least once in the sequence. +2. The sequence is periodic with a period equal to the cardinality of `U`. 
+--/ +def cyclic [Fintype U] (useq : ℕ → U) : Prop := + (∀ u : U, ∃ i, useq i = u) ∧ (∀ i j, i % card U = j % card U → useq i = useq j) + +@[simp] +lemma Fintype.exists_mod_eq_within_bounds [Fintype U] [Nonempty U] : ∀ i (n : ℕ), + ∃ m ≥ n, m < n + card U ∧ m % card U = i % (card U) := by { + simp only [ge_iff_le] + intros i n + let a := (n / card U) + 1 + have : a * card U ≥ n := by { + simp only [ge_iff_le] + simp only [a] + rw [add_mul] + simp only [one_mul] + rw [le_iff_lt_or_eq] + left + exact Nat.lt_div_mul_add Fintype.card_pos} + let j := i + a * card U + have hj : j ≥ n := by { + simp only [j] + exact le_add_of_le_right this} + use (n + (j - n) % card U) + constructor + · simp only [ge_iff_le, a, j] at * + exact Nat.le_add_right n ((i + (n / Fintype.card U + 1) * Fintype.card U - n) % Fintype.card U) + · constructor + · simp only [add_lt_add_iff_left] + exact Nat.mod_lt _ Fintype.card_pos + · simp only [Nat.add_mod_mod] + simp_all only [ge_iff_le, add_tsub_cancel_of_le, Nat.add_mul_mod_self_right, a, j]} + +@[simp] +lemma cycl_Fair [Fintype U] [Nonempty U] (useq : ℕ → U) : cyclic useq → fair useq := + fun ⟨hcov, hper⟩ u n => + let ⟨i, hi⟩ := hcov u + let ⟨m, hm, hkmod⟩ := @Fintype.exists_mod_eq_within_bounds U _ _ i n + ⟨m, hm, hi ▸ hper m i hkmod.2⟩ + +@[simp] +lemma Fintype.cyclic_Fair_bound [Fintype U] [Nonempty U] (useq : ℕ → U) : + cyclic useq → ∀ u (n : ℕ), + ∃ m ≥ n, m < n + card U ∧ useq m = u := fun ⟨hcov, hper⟩ u n => by { + obtain ⟨i, hi⟩ := hcov u + have := @Fintype.exists_mod_eq_within_bounds U _ _ i n + obtain ⟨m, hm, hle, hkmod⟩ := this + use m + constructor + · exact hm + · constructor + · exact hle + · rw [← hi] + exact hper m i hkmod} + +/-- Split a sum over elements satisfying P into sums over elements satisfying P∧Q and P∧¬Q -/ +lemma sum_split (P Q : α → Prop) [DecidablePred P] [AddCommMonoid β] + [DecidablePred Q] (f : α → β) : + ∑ u ∈ filter (fun u => P u) s, f u = ∑ u ∈ filter (fun u => P u ∧ Q u) s, f u + + ∑ u ∈ filter (fun u => P 
u ∧ ¬ Q u) s, f u := by + simp only [sum_filter, ← sum_add_distrib, ite_and, ite_add, ite_not, zero_add, add_zero, zero_add] + simp_all only [↓reduceIte, add_zero, ite_self] + +lemma sum_over_subset (f : α → β) (s : Finset α) [Fintype α] [AddCommMonoid β] + [DecidablePred (fun x => x ∈ s)] : + ∑ x ∈ s, f x = ∑ x, if x ∈ s then f x else 0 := by + simp_rw [← sum_filter]; congr; + ext; simp only [mem_filter, mem_univ, true_and] + +/-- Convert telescope sum to filtered sum --/ +@[simp] +lemma tsum_to_filter {α : Type} [Fintype α] {p : α → ENNReal} {y : β} (f : α → β) : + ∑' (a : α), p a = ∑ a ∈ filter (fun a ↦ f a = y) univ, p a := by + rw [tsum_fintype] + exact Eq.symm (sum_filter (fun a ↦ f a = y) p) + +/-- If filtered sum is positive, there exists a positive element satisfying filter condition --/ +@[simp] +lemma filter_sum_pos_exists {α β : Type} [Fintype α] [DecidableEq β] {p : α → ENNReal} + {f : α → β} {y : β} : + ∑ a ∈ filter (fun a ↦ f a = y) univ, p a > 0 → + ∃ x : α, p x > 0 ∧ f x = y := by + intro sum_pos + have exists_pos : ∃ a ∈ filter (fun a ↦ f a = y) univ, p a > 0 := by + by_contra h + have all_zero : ∀ a ∈ filter (fun a ↦ f a = y) univ, p a = 0 := by + intro a ha + apply le_antisymm + · push_neg at h + exact h a ha + · simp_all only [gt_iff_lt, mem_filter, mem_univ, true_and, not_exists, not_and, not_lt, + nonpos_iff_eq_zero, le_refl] + simp_all only [Finset.sum_eq_zero all_zero, sum_const_zero, gt_iff_lt, lt_self_iff_false] + rcases exists_pos with ⟨x, h_mem, h_p_pos⟩ + -- Membership in filter means f x = y + simp only [filter_subset, mem_filter, mem_univ, true_and] at h_mem + subst h_mem + simp_all only [gt_iff_lt] + apply Exists.intro + · apply And.intro + on_goal 2 => {rfl + } + · simp_all only + +/-- Any element in finset contributes to supremum over filtered sums --/ +@[simp] +lemma ENNReal.le_iSup_finset {α : Type} {s : Finset α} {f : α → ENNReal} : + ∑ a ∈ s, f a ≤ ⨆ (t : Finset α), ∑ a ∈ t, f a := by + -- Use s as witness for supremum bound + 
exact le_iSup_iff.mpr fun b a ↦ a s + +/-- If there exists an element with positive value, then the telescope sum is positive --/ +@[simp] +lemma ENNReal.tsum_pos_of_exists {α : Type} {f : α → ENNReal} (h : ∃ a : α, f a > 0) : + (∑' a, f a) > 0 := by + -- Extract the element with positive value + rcases h with ⟨a₀, h_pos⟩ + + -- Show that sum is at least as large as the term f a₀ + have h_le : f a₀ ≤ ∑' a, f a := by + apply ENNReal.le_tsum + + -- Since f a₀ > 0 and sum ≥ f a₀, the sum must be positive + exact lt_of_lt_of_le h_pos h_le + +/-- If there exists a positive element satisfying filter condition, filtered sum is positive --/ +@[simp] +lemma exists_pos_filter_sum_pos {α β : Type} {p : α → ENNReal} {f : α → β} {y : β} : + (∃ x : α, 0 < p x ∧ f x = y) → + ∃ x : α, 0 < p x ∧ f x = y := by + rintro ⟨x, h_p_pos, h_fx_eq_y⟩ + exact ⟨x, h_p_pos, h_fx_eq_y⟩ + +/-- Filter of non-equal elements is equivalent to erase operation --/ +@[simp] +lemma filter_erase_equiv {U : Type} + [DecidableEq U] [Fintype U] (u : U) : + filter (fun v => v ≠ u) univ = univ.erase u := by + ext v + simp only [mem_filter, mem_erase, mem_univ, true_and] + exact Iff.symm (and_iff_left_of_imp fun a ↦ trivial) + +lemma pmf_map_apply_eq_tsum {α β : Type} [Fintype α] [DecidableEq β] {p : α → ENNReal} + (h_pmf : ∑ a, p a = 1) (f : α → β) (y : β) : + (PMF.map f (PMF.ofFintype p h_pmf)) y = ∑' (a : α), if y = f a + then (PMF.ofFintype p h_pmf) a else 0 := by + simp only [PMF.map_apply]; simp_all only [PMF.ofFintype_apply] + +lemma filter_mem_iff {α β : Type} [Fintype α] [DecidableEq β] {f : α → β} {y : β} {x : α} : + x ∈ filter (fun a ↦ f a = y) univ ↔ f x = y := by + simp only [mem_filter, mem_univ, true_and] + +@[simp] +lemma tsum_eq_filter_sum {α β : Type} [Fintype α] [DecidableEq β] + {p : α → ENNReal} {y : β} (f : α → β) : + (∑' (a : α), if y = f a then p a else 0) = ∑ a ∈ filter (fun a ↦ f a = y) univ, p a := by + rw [tsum_fintype] + rw [← Finset.sum_filter] + apply Finset.sum_congr + · ext a + 
simp only [mem_filter, mem_univ, true_and] + by_cases h : f a = y + · simp [h] + · simp [h]; exact fun a_1 ↦ h (id (Eq.symm a_1)) + · intro a ha; simp_all only [mem_filter, mem_univ, true_and] + +@[simp] +lemma pmf_ofFintype_apply_eq {α : Type} [Fintype α] + {p : α → ENNReal} (h_pmf : ∑ a, p a = 1) (a : α) : + (PMF.ofFintype p h_pmf) a = p a := by + simp only [PMF.ofFintype_apply] + +@[simp] +lemma filter_sum_pos_iff_exists_pos {α β : Type} [Fintype α] + [DecidableEq β] {p : α → ENNReal} {f : α → β} {y : β} : + (∑ a ∈ filter (fun a ↦ f a = y) univ, p a) > 0 ↔ + ∃ x : α, f x = y ∧ p x > 0 := by + constructor + · intro h_pos + have exists_pos : ∃ a ∈ filter (fun a ↦ f a = y) univ, p a > 0 := by + by_contra h + have all_zero : ∀ a ∈ filter (fun a ↦ f a = y) univ, p a = 0 := by + intro a ha + apply le_antisymm + · push_neg at h + exact h a ha + · exact zero_le (p a) + have sum_zero := Finset.sum_eq_zero all_zero + exact not_lt_of_le (by exact nonpos_iff_eq_zero.mpr sum_zero) h_pos + rcases exists_pos with ⟨x, hx_mem, hx_pos⟩ + exact ⟨x, filter_mem_iff.mp hx_mem, hx_pos⟩ + · rintro ⟨x, hx_mem, hx_pos⟩ + simp_all only [mem_filter, mem_univ, true_and, gt_iff_lt] + subst hx_mem + have x_in_filter : x ∈ filter (fun a ↦ f a = f x) univ := by + simp only [filter_mem_iff] + have sum_ge_x : ∑ a ∈ filter (fun a ↦ f a = f x) univ, p a ≥ p x := by + exact CanonicallyOrderedAddCommMonoid.single_le_sum x_in_filter + exact lt_of_lt_of_le hx_pos sum_ge_x + +/-- Main aux lemma: + For a mapped PMF, the probability of a state is positive if and only if + there exists a preimage with positive probability --/ +@[simp] +lemma pmf_map_pos_iff_exists_pos {α β : Type} + {p : α → ENNReal} (f : α → β) (y : β) : + (∃ x : α, f x = y ∧ 0 < p x) ↔ + ∃ x : α, p x > 0 ∧ f x = y := by + constructor + · intro h_pos + rcases h_pos with ⟨x, hx_eq, hx_pos⟩ + use x + · intro h_exists + rcases h_exists with ⟨x, hx_pos, hx_eq⟩ + rw [← hx_eq] + rw [hx_eq] + rw [← hx_eq] + use x +lemma Array.mkArray_size {α : 
Type} (n : ℕ) (a : α) : + (Array.replicate n a).size = n := by simp only [size_replicate] + +lemma Array.mkArray_get {α : Type} (n : ℕ) (a : α) (i : Nat) (h : i < n) : + (Array.replicate n a)[i]'(by rw [Array.size_replicate]; exact h) = a := + getElem_replicate (Eq.mpr (id (congrArg (fun _a ↦ i < _a) size_replicate)) h) + +/-- +Proves that `Array.mkArray` creates valid parameters for a Hopfield network. +Given a vertex `u` in a Hopfield network with `n` nodes, this lemma establishes that: +1. The array `σ_u` has size equal to `κ1 u` +2. The array `θ_u` has size equal to `κ2 u` +3. All elements in `σ_u` are initialized to 0 +4. All elements in `θ_u` are initialized to 0 +where `κ1` and `κ2` are dimension functions defined in the `HopfieldNetwork` structure. +-/ +lemma Array.mkArray_creates_valid_nn_params {α : Type} (κ₁ κ₂ : α → ℕ) (u : α) (a₁ a₂ : ℝ) : + let σ_u := Array.replicate (κ₁ u) a₁ + let θ_u := Array.replicate (κ₂ u) a₂ + σ_u.size = κ₁ u ∧ + θ_u.size = κ₂ u ∧ + (∀ i : Nat, ∀ h : i < κ₁ u, σ_u[i]'(by { + simp only [σ_u]; rw [Array.mkArray_size]; exact h }) = a₁) ∧ + (∀ i : Nat, ∀ h : i < κ₂ u, θ_u[i]'(by { + simp only [θ_u]; rw [Array.mkArray_size]; exact h }) = a₂) := by + let σ_u := Array.replicate (κ₁ u) a₁ + let θ_u := Array.replicate (κ₂ u) a₂ + refine ⟨Array.size_replicate .., Array.size_replicate .., ?_, ?_⟩ + · intro i h; exact Array.getElem_replicate .. + · intro i h; exact Array.getElem_replicate .. diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/test.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/test.lean new file mode 100644 index 000000000..4401ff0c8 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/test.lean @@ -0,0 +1,260 @@ +/- +Copyright (c) 2024 Michail Karatarakis. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. 
+Authors: Michail Karatarakis +-/ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core + +set_option linter.unusedVariables false +set_option maxHeartbeats 500000 + +open Mathlib Finset + +variable {R U : Type} [Zero R] + +/-- Two neurons `u` and `v` are connected in the graph if `w u v` is not zero. -/ +def Matrix.Adj (w : Matrix U U R) : U → U → Prop := fun u v => w u v ≠ 0 + +/-- `Matrix.w` returns the value of the matrix `w` at position `(u, v)` if `u` and `v` are connected. -/ +def Matrix.w (w : Matrix U U R) : ∀ u v : U, w.Adj u v → R := fun u v _ => w u v + +/-- A 3x3 matrix of rational numbers. --/ +def test.M : Matrix (Fin 3) (Fin 3) ℚ := Matrix.of ![![0,0,4], ![1,0,0], ![-2,3,0]] + +def test : NeuralNetwork ℚ (Fin 3) where + Adj := test.M.Adj + Ui := {0,1} + Uo := {2} + hU := by + ext x + simp only [Set.mem_univ, Fin.isValue, Set.union_singleton, + Set.union_empty, Set.mem_insert_iff, + Set.mem_singleton_iff, true_iff] + revert x + decide + Uh := ∅ + hUi := Ne.symm (Set.ne_insert_of_not_mem {1} fun a ↦ a) + hUo := Set.singleton_ne_empty 2 + hhio := by + simp only [Fin.isValue, Set.union_singleton, Set.empty_inter] + pw W := True + κ1 u := 0 + κ2 u := 1 + fnet u w pred σ := ∑ v, w v * pred v + fact u input θ := if input ≥ θ then 1 else 0 + fout u act := act + pact u := True + hpact w _ _ σ θ _ pact u := pact u + +def wθ : Params test where + w := Matrix.of ![![0,0,4], ![1,0,0], ![-2,3,0]] + θ u := ⟨#[1], by + simp only [List.size_toArray, List.length_cons, List.length_nil, zero_add] + unfold test + simp only⟩ + σ _ := Vector.emptyWithCapacity 0 + hw u v hv := by by_contra h; exact hv h + hw' := by simp only [test] + +instance : Repr (NeuralNetwork.State test) where + reprPrec state _ := + ("acts: " ++ repr (state.act)) ++ ", outs: " ++ + repr (state.out) ++ ", nets: " ++ repr (state.net wθ) + +/-- +`test.extu` is the initial state for the `test` neural network with activations `[1, 0, 0]`. 
+-/ +def test.extu : test.State := {act := ![1,0,0], hp := fun u => trivial} + +lemma zero_if_not_mem_Ui : ∀ u : Fin 3, + ¬ u ∈ ({0,1} : Finset (Fin 3)) → test.extu.act u = 0 := by decide + +/--If `u` is not in the input neurons `Ui`, then `test.extu.act u` is zero.-/ +lemma test.onlyUi : test.extu.onlyUi := by { + intros u hu + apply zero_if_not_mem_Ui u + simp only [Fin.isValue, mem_insert, mem_singleton, not_or] + exact not_or.mp hu} + +/-The workphase for the asynchronous update of the sequence of neurons u3 , u1 , u2 , u3 , u1 , u2 , u3. -/ +#eval NeuralNetwork.State.workPhase wθ test.extu test.onlyUi [2,0,1,2,0,1,2] + +/-The workphase for the asynchronous update of the sequence of neurons u3 , u2 , u1 , u3 , u2 , u1 , u3. -/ +#eval NeuralNetwork.State.workPhase wθ test.extu test.onlyUi [2,1,0,2,1,0,2] + +/-Hopfield Networks-/ + +/-- A 4x4 matrix of rational numbers. --/ +def W1 : Matrix (Fin 4) (Fin 4) ℚ := + Matrix.of ![![0,1,-1,-1], ![1,0,-1,-1], ![-1,-1,0,1], ![-1,-1,1,0]] + +/-- +`HebbianParamsTest` defines a Hopfield Network with 4 neurons and rational weights. +- `w`: The weight matrix `W1`. +- `hw`: Proof that the weights are symmetric. +- `hw'`: Proof that the weights are zero on the diagonal. +- `σ`: Always an empty vector. +- `θ`: Always returns a list with a single 0. +--/ +def HebbianParamsTest : Params (HopfieldNetwork ℚ (Fin 4)) where + w := W1 + hw u v huv := by + unfold HopfieldNetwork at huv + simp only [ne_eq, Decidable.not_not] at huv + rw [huv] + revert v v + simp only [forall_eq'] + revert u u + decide + hw' := by { + unfold HopfieldNetwork + simp only + decide} + σ := fun u => Vector.emptyWithCapacity 0 + θ u := ⟨#[0],by + simp only [List.size_toArray, List.length_cons, List.length_nil, zero_add]⟩ + +/-- `extu` is the initial state for our `HebbianParamsTest` Hopfield network. +- `act`: `[1, -1, -1, 1]` - initial activations. + +This initializes the state for a Hopfield network test. 
+--/ +def extu : State' HebbianParamsTest where + act := ![1,-1,-1,1] + hp := by + intros u + unfold HopfieldNetwork + simp only + revert u + decide + +instance : Repr (HopfieldNetwork ℚ (Fin 4)).State where + reprPrec state _ := ("acts: " ++ repr (state.act)) + +-- Computations + +-- lemma zero_if_not_mem_Ui' : ∀ u : Fin 4, +-- ¬ u ∈ ({0,1,2,3} : Finset (Fin 4)) → extu.act u = 0 := by {decide} + +-- def HN.hext : extu.onlyUi := by {intros u; tauto} + +-- #eval NeuralNetwork.State.workPhase HebbianParamsTest extu HN.hext [2,0,1,2,0,1,2] + + +/-- +`pattern_ofVec` converts a pattern vector from `Fin n` to `ℚ` into a `State` +for a `HopfieldNetwork` with `n` neurons. +It checks if all elements are either 1 or -1. If they are, it returns `some` + pattern; otherwise, it returns `none`. +--/ +def pattern_ofVec {n} [NeZero n] (pattern : Fin n → ℚ) : + Option (HopfieldNetwork ℚ (Fin n)).State := + if hp : ∀ i, pattern i = 1 ∨ pattern i = -1 then + some {act := pattern, hp := by { + intros u + unfold HopfieldNetwork + simp only + apply hp + }} + else + none + +/-- +`obviousFunction` tries to convert a function `f` from `α` to `Option β` into a regular function +from `α` to `β`. If `f` returns `some` for every input, it returns `some` function that extracts +these values. Otherwise, it returns `none`. +--/ +def obviousFunction [Fintype α] (f : α → Option β) : Option (α → β) := + if h : ∀ x, (f x).isSome then + some (fun a => (f a).get (h a)) + else + none + + +/-- +Converts a matrix of patterns `V` into Hopfield network states. + +Each row of `V` (a function `Fin m → Fin n → ℚ`) is mapped to a Hopfield network state +if all elements are either `1` or `-1`. Returns `some` mapping if successful, otherwise `none`. +-/ +def patternsOfVecs (V : Fin m → Fin n → ℚ) [NeZero n] (hmn : m < n) : + Option (Fin m → (HopfieldNetwork ℚ (Fin n)).State) := + obviousFunction (fun i => pattern_ofVec (V i)) + +/-- +`ZeroParams_4` defines a simple Hopfield Network with 4 neurons. 
+ +- `w`: A 4x4 weight matrix filled with zeros. +- `hw`: Proof that the weight matrix is symmetric. +- `hw'`: Proof that the weight matrix has zeros on the diagonal. +- `σ`: An empty vector for each neuron. +- `θ`: A threshold of 0 for each neuron, with proof that the list has length 1. +--/ +def ZeroParams_4 : Params (HopfieldNetwork ℚ (Fin 4)) where + w := (Matrix.of ![![0,0,0,0], ![0,0,0,0], ![0,0,0,0], ![0,0,0,0]]) + hw u v huv := by { + unfold HopfieldNetwork at huv + simp only [ne_eq, Decidable.not_not] at huv + rw [huv] + revert v v + simp only [forall_eq'] + revert u u + decide} + hw' := by { + unfold HopfieldNetwork + simp only + decide} + σ u := Vector.emptyWithCapacity 0 + θ u := ⟨#[0], by {simp only [List.size_toArray, List.length_cons, + List.length_nil, zero_add]}⟩ + +/-- +`ps` are two patterns represented by a 2x4 matrix of rational numbers. +--/ +def ps : Fin 2 → Fin 4 → ℚ := ![![1,1,-1,-1], ![-1,1,-1,1]] + +/-- +`test_params` sets up a `HopfieldNetwork` with 4 neurons. +It converts the patterns `ps` into a network using Hebbian learning if possible. +If not, it defaults to `ZeroParams_4`. +--/ +def test_params : Params (HopfieldNetwork ℚ (Fin 4)) := + match (patternsOfVecs ps (by simp only [Nat.succ_eq_add_one, zero_add, + Nat.reduceAdd, Nat.reduceLT])) with + | some patterns => Hebbian patterns + | none => ZeroParams_4 + +/-- +`useq_Fin n` maps a natural number `i` to an element of `Fin n` (a type with `n` elements). +It wraps `i` around using modulo `n`. + +Arguments: +- `n`: The size of the finite type (must be non-zero). +- `i`: The natural number to convert. + +Returns: +- An element of `Fin n`. 
+--/ +def useq_Fin n [NeZero n] : ℕ → Fin n := fun i => + ⟨_, Nat.mod_lt i (Nat.pos_of_neZero n)⟩ + +lemma useq_Fin_cyclic n [NeZero n] : cyclic (useq_Fin n) := by + unfold cyclic + unfold useq_Fin + simp only [Fintype.card_fin] + apply And.intro + · intros u + use u.val + cases' u with u hu + simp only + simp_all only [Fin.mk.injEq] + exact Nat.mod_eq_of_lt hu + · intros i j hij + exact Fin.mk.inj_iff.mpr hij + +lemma useq_Fin_fair n [NeZero n] : fair (useq_Fin n) := + cycl_Fair (useq_Fin n) (useq_Fin_cyclic n) + +#eval HopfieldNet_stabilize test_params extu (useq_Fin 4) (useq_Fin_fair 4) + +#eval HopfieldNet_conv_time_steps test_params extu (useq_Fin 4) (useq_Fin_cyclic 4) From dd0960b74c2096e28b92665998ad3032bca17c51 Mon Sep 17 00:00:00 2001 From: mkaratarakis Date: Mon, 28 Jul 2025 18:12:10 +0200 Subject: [PATCH 02/15] imports --- PhysLean.lean | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/PhysLean.lean b/PhysLean.lean index 5ad208927..cbd7e65bb 100644 --- a/PhysLean.lean +++ b/PhysLean.lean @@ -335,3 +335,14 @@ import PhysLean.StringTheory.FTheory.SU5U1.Quanta.IsViable.Elems import PhysLean.StringTheory.FTheory.SU5U1.Quanta.ToList import PhysLean.StringTheory.FTheory.SU5U1.Quanta.YukawaRegeneration import PhysLean.Thermodynamics.Basic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Asym +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Core +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Markov +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux +import 
PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.aux +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.test From 406adf2e98df8ae4361b76d96eb1802366e70423 Mon Sep 17 00:00:00 2001 From: mkaratarakis Date: Mon, 28 Jul 2025 18:19:53 +0200 Subject: [PATCH 03/15] fix lint --- PhysLean.lean | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/PhysLean.lean b/PhysLean.lean index 5ad208927..cbd7e65bb 100644 --- a/PhysLean.lean +++ b/PhysLean.lean @@ -335,3 +335,14 @@ import PhysLean.StringTheory.FTheory.SU5U1.Quanta.IsViable.Elems import PhysLean.StringTheory.FTheory.SU5U1.Quanta.ToList import PhysLean.StringTheory.FTheory.SU5U1.Quanta.YukawaRegeneration import PhysLean.Thermodynamics.Basic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Asym +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Core +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Markov +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.aux +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.test From 0f7edeff8d98dfc58df90c1886a064fabd7dd265 Mon Sep 17 00:00:00 2001 From: mkaratarakis Date: Fri, 1 Aug 2025 22:39:52 +0200 Subject: [PATCH 04/15] refactor --- .../SpinGlasses/HopfieldNetwork/Asym.lean | 5 +- .../BoltzmannMachine/Markov.lean | 13 +- .../HopfieldNetwork/Stochastic.lean | 735 ++++++++---------- .../HopfieldNetwork/StochasticAux.lean | 12 +- 4 files changed, 320 insertions(+), 445 deletions(-) diff 
--git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Asym.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Asym.lean index 4f71fbd4c..da85669cc 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Asym.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Asym.lean @@ -32,6 +32,7 @@ This decomposition allows us to analyze the dynamics and convergence properties of Hopfield networks with asymmetric weights. -/ + open Finset Matrix NeuralNetwork State variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] @@ -58,8 +59,8 @@ The network has: - Asymmetric weights (w_ij can differ from w_ji) - Weights that can be decomposed into antisymmetric and positive definite components -/ -abbrev AsymmetricHopfieldNetwork (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] - [Nonempty U] [Fintype U] [StarRing R] : NeuralNetwork R U where +abbrev AsymmetricHopfieldNetwork (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Nonempty U] [Fintype U] [StarRing R] : NeuralNetwork R U where /- The adjacency relation between neurons `u` and `v`, defined as `u ≠ v`. -/ Adj u v := u ≠ v /- The set of input neurons, defined as the universal set. -/ diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Markov.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Markov.lean index 90361556e..65ca9735f 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Markov.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine/Markov.lean @@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE. 
Authors: Matteo Cipollina -/ import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Core -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov open Finset Matrix NeuralNetwork State ENNReal Real open PMF MeasureTheory ProbabilityTheory.Kernel Set @@ -101,8 +100,7 @@ noncomputable def partitionFunctionBM (p : ParamsBM R U) : ENNReal := /-- The partition function is positive and finite, provided T > 0. -/ -lemma partitionFunctionBM_pos_finite (p : ParamsBM R U) - [Nonempty (StateBM R U)] : +lemma partitionFunctionBM_pos_finite (p : ParamsBM R U) [Nonempty (StateBM R U)] : 0 < partitionFunctionBM p ∧ partitionFunctionBM p < ⊤ := by constructor · -- Proof of 0 < Z @@ -125,8 +123,7 @@ $\pi(s) = \frac{1}{Z} e^{-E(s)/T}$. Defined as a measure with density `boltzmannDensityFnBM / partitionFunctionBM` with respect to the counting measure on the finite state space. -/ -noncomputable def boltzmannDistributionBM (p : ParamsBM R U) - [ Nonempty (StateBM R U)] : +noncomputable def boltzmannDistributionBM (p : ParamsBM R U) [Nonempty (StateBM R U)] : Measure (StateBM R U) := let density := fun s => boltzmannDensityFnBM p s / partitionFunctionBM p let Z_pos_finite := partitionFunctionBM_pos_finite p @@ -141,7 +138,8 @@ noncomputable def boltzmannDistributionBM (p : ParamsBM R U) -- Cleaner definition relying on the proof that Z is good noncomputable def boltzmannDistributionBM' (p : ParamsBM R U) : Measure (StateBM R U) := - @Measure.withDensity (StateBM R U) _ Measure.count (fun s => boltzmannDensityFnBM p s / partitionFunctionBM p) + @Measure.withDensity (StateBM R U) _ Measure.count + (fun s => boltzmannDensityFnBM p s / partitionFunctionBM p) -- Prove it's a probability measure instance isProbabilityMeasure_boltzmannDistributionBM @@ -170,4 +168,5 @@ instance isProbabilityMeasure_boltzmannDistributionBM -- The numerator sum is exactly the definition of the partition function rw [← partitionFunctionBM] -- So we get 
Z/Z = 1 - exact ENNReal.div_self (partitionFunctionBM_pos_finite p).1.ne' (partitionFunctionBM_pos_finite p).2.ne + exact ENNReal.div_self (partitionFunctionBM_pos_finite p).1.ne' + (partitionFunctionBM_pos_finite p).2.ne diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean index 2235c1991..073428551 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean @@ -6,10 +6,14 @@ Authors: Matteo Cipollina import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux +import PhysLean.StatisticalMechanics.Temperature import Mathlib.Analysis.RCLike.Basic import Mathlib.LinearAlgebra.AffineSpace.AffineMap import Mathlib.LinearAlgebra.Dual.Lemmas +set_option linter.unusedSectionVars false +set_option linter.unusedVariables false + /- # Stochastic Hopfield Network Implementation @@ -31,76 +35,95 @@ and stochastic updates using both Gibbs sampling and Metropolis-Hastings algorit verification of adjacency (`all_nodes_adjacent`), total variation distance (`total_variation_distance`), partition function (`partitionFunction`), and more. -/ -open Finset Matrix NeuralNetwork State +open Finset Matrix NeuralNetwork State ENNReal Real variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] - [DecidableEq U] [Fintype U] [Nonempty U] (wθ : Params (HopfieldNetwork R U)) (s : (HopfieldNetwork R U).State) - [Coe R ℝ] (T : ℝ) + [DecidableEq U] [Fintype U] [Nonempty U] (wθ : Params (HopfieldNetwork R U)) + (s : (HopfieldNetwork R U).State) [Coe R ℝ] (T : ℝ) /-- Performs a Gibbs update on a single neuron `u` of the state `s`. The update probability depends on the energy change associated with flipping the neuron's state, parameterized by the temperature `T`. 
-/ -noncomputable def NN.State.gibbsUpdateNeuron [Coe R ℝ] (T : ℝ) (u : U) : PMF ((HopfieldNetwork R U).State) := - let h_u := s.net wθ u - let ΔE := 2 * h_u * s.act u - let p_flip := ENNReal.ofReal (Real.exp (-(↑ΔE) / T)) / (1 + ENNReal.ofReal (Real.exp (-(↑ΔE) / T))) +noncomputable def NN.State.gibbsUpdateNeuron (u : U) : + PMF ((HopfieldNetwork R U).State) := + let hᵤ := s.net wθ u + let ΔE := 2 * hᵤ * s.act u + let p_flip := ENNReal.ofReal (exp (-(↑ΔE) / T)) / (1 + ENNReal.ofReal (exp (-(↑ΔE) / T))) let p_flip_le_one : p_flip ≤ 1 := by - simp only [p_flip] - let a := ENNReal.ofReal (Real.exp (-(↑ΔE) / T)) - have h_a_nonneg : 0 ≤ a := zero_le a - have h_denom_ne_zero : 1 + a ≠ 0 := by - intro h - have h1 : 0 ≤ 1 + a := zero_le (1 + a) - have h2 : 1 + a = 0 := h - simp_all only [zero_le, add_eq_zero, one_ne_zero, ENNReal.ofReal_eq_zero, false_and, a, ΔE, h_u, p_flip] - have h_sum_ne_top : (1 + a) ≠ ⊤ := by - apply ENNReal.add_ne_top.2 - constructor - · exact ENNReal.one_ne_top - · apply ENNReal.ofReal_ne_top - rw [ENNReal.div_le_iff h_denom_ne_zero h_sum_ne_top] - simp only [one_mul, h_u, ΔE, a, p_flip] - exact le_add_self - PMF.bind (PMF.bernoulli p_flip p_flip_le_one) $ λ should_flip => + let a := ENNReal.ofReal (exp (-(↑ΔE) / T)) + have h_sum_ne_top : (1 + a) ≠ ⊤ := add_ne_top.2 ⟨one_ne_top, ofReal_ne_top⟩ + rw [ENNReal.div_le_iff _ h_sum_ne_top, one_mul] + · exact le_add_self + · intro H; rw [add_eq_zero] at H; simp only [one_ne_zero] at H; exact H.1 + PMF.bind (PMF.bernoulli p_flip p_flip_le_one) $ fun should_flip => PMF.pure $ if should_flip then s.Up wθ u else s + -- Calculate probabilities based on Boltzmann distribution +noncomputable def probs (u : U) (local_field : R) : Bool → ENNReal := fun b => + let new_act_val := if b then 1 else -1 + ENNReal.ofReal (exp (local_field * new_act_val / T)) + +noncomputable def total (u : U) (local_field : R) : ENNReal := + probs T u local_field true + probs T u local_field false + +noncomputable def norm_probs (u : U) 
(local_field : R) : Bool → ENNReal := fun b => + probs T u local_field b / total T u local_field + +noncomputable def Z (local_field : R) := + ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp (-local_field / T)) + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] in +lemma h_total_eq_Z (local_field : R) : total T u local_field = (Z T local_field) := by + simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, + probs, Z] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] in +lemma h_total_ne_zero (u : U) (local_field : R) : total T u local_field ≠ 0 := by + simp only [total, probs, ne_eq, add_eq_zero] + intro h + have h_exp_pos : ENNReal.ofReal (exp (local_field * 1 / T)) > 0 := by + apply ENNReal.ofReal_pos.mpr; apply exp_pos + exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h + +lemma h_sum (u : U) (local_field : R) : ∑ b : Bool, (norm_probs T u local_field b ) = + (probs T u local_field true + probs T u local_field false) / total T u local_field := by + simp only [Fintype.univ_bool, mem_singleton, Bool.true_eq_false, + not_false_eq_true,sum_insert, sum_singleton, total, probs, Z] + exact ENNReal.div_add_div_same + +lemma h_total_ne_top (u : U) (local_field : R) : total T u local_field ≠ ⊤ := by simp [total, probs] + /-- Update a single neuron according to Gibbs sampling rule -/ noncomputable def NN.State.gibbsUpdateSingleNeuron (u : U) : PMF ((HopfieldNetwork R U).State) := -- Calculate local field for the neuron let local_field := s.net wθ u - -- Calculate probabilities based on Boltzmann distribution - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - -- Create PMF with normalized probabilities - let total : ENNReal := probs true + probs false - let norm_probs : Bool → ENNReal := λ b => probs b / total -- Convert Bool to State - (PMF.map (λ b => if b then + (PMF.map (fun b => if b then 
NN.State.updateNeuron s u 1 (mul_self_eq_mul_self_iff.mp rfl) else NN.State.updateNeuron s u (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl)) - (PMF.ofFintype norm_probs (by - have h_total : total ≠ 0 := by { - simp [probs] - refine ENNReal.inv_ne_top.mp ?_ - have h_exp_pos := Real.exp_pos (local_field * 1 / T) - have h := ENNReal.ofReal_pos.mpr h_exp_pos - simp_all only [mul_one, ENNReal.ofReal_pos, mul_ite, mul_neg, ↓reduceIte, Bool.false_eq_true, ne_eq, - ENNReal.inv_eq_top, add_eq_zero, ENNReal.ofReal_eq_zero, not_and, not_le, isEmpty_Prop, IsEmpty.forall_iff, - local_field, total, probs]} - have h_total_ne_top : total ≠ ⊤ := by {simp [probs, total]} - have h_sum : Finset.sum Finset.univ norm_probs = 1 := by - calc Finset.sum Finset.univ norm_probs - = (probs true)/total + (probs false)/total := Fintype.sum_bool fun b ↦ probs b / total - _ = (probs true + probs false)/total := ENNReal.div_add_div_same - _ = total/total := by rfl - _ = 1 := ENNReal.div_self h_total h_total_ne_top + (PMF.ofFintype (norm_probs T u local_field) (by + have h_sum : Finset.sum Finset.univ (norm_probs T u local_field) = 1 := by + calc Finset.sum Finset.univ (norm_probs T u local_field) + = (probs T u local_field true)/total T u local_field + + (probs T u local_field false)/total T u local_field := + Fintype.sum_bool fun b ↦ probs T u local_field b / total T u local_field + _ = (probs T u local_field true + probs T u local_field false)/total T u local_field := + ENNReal.div_add_div_same + _ = total T u local_field /total T u local_field := by rfl + _ = 1 := ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field) exact h_sum))) @[inherit_doc] scoped[ENNReal] notation "ℝ≥0∞" => ENNReal +open Fintype + +theorem NN.State.gibbsSamplingStep.extracted_1 {U : Type} [inst : Fintype U] [inst_1 : Nonempty U] : + ∑ a : U, (fun _ => 1 / ((Fintype.card U) : ENNReal)) a = 1 := by { + exact uniform_neuron_selection_prob_valid + } + /-- Given a Hopfield Network's parameters, 
temperature, and current state, performs a single step of Gibbs sampling by: 1. Uniformly selecting a random neuron @@ -109,28 +132,18 @@ of Gibbs sampling by: noncomputable def NN.State.gibbsSamplingStep : PMF ((HopfieldNetwork R U).State) := -- Uniform random selection of neuron let neuron_pmf : PMF U := - PMF.ofFintype (λ _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) - (by - rw [Finset.sum_const, Finset.card_univ] - rw [ENNReal.div_eq_inv_mul] - simp only [mul_one] - have h : (Fintype.card U : ENNReal) ≠ 0 := by - simp [Fintype.card_pos_iff.mpr inferInstance] - have h_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.coe_ne_top - rw [← ENNReal.mul_inv_cancel h h_top] - simp_all only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, not_false_eq_true, ENNReal.natCast_ne_top, - nsmul_eq_mul]) + PMF.ofFintype (fun _ => (1 : ENNReal) / (card U : ENNReal)) + (NN.State.gibbsSamplingStep.extracted_1 (U:=U)) -- Bind neuron selection with conditional update - PMF.bind neuron_pmf $ λ u => NN.State.gibbsUpdateSingleNeuron wθ s T u + PMF.bind neuron_pmf $ fun u => NN.State.gibbsUpdateSingleNeuron wθ s T u instance : Coe ℝ ℝ := ⟨id⟩ /-- Perform a stochastic update on a Pattern representation -/ -noncomputable def patternStochasticUpdate - {n : ℕ} [Nonempty (Fin n)] (weights : Fin n → Fin n → ℝ) (h_diag_zero : ∀ i : Fin n, weights i i = 0) - (h_sym : ∀ i j : Fin n, weights i j = weights j i) (T : ℝ) - (pattern : NeuralNetwork.State (HopfieldNetwork ℝ (Fin n))) (i : Fin n) : - PMF (NeuralNetwork.State (HopfieldNetwork ℝ (Fin n))) := +noncomputable def patternStochasticUpdate {n : ℕ} [Nonempty (Fin n)] (weights : Fin n → Fin n → ℝ) + (h_diag_zero : ∀ i : Fin n, weights i i = 0) (h_sym : ∀ i j : Fin n, weights i j = weights j i) + (pattern : State (HopfieldNetwork ℝ (Fin n))) (i : Fin n) : + PMF (State (HopfieldNetwork ℝ (Fin n))) := let wθ : Params (HopfieldNetwork ℝ (Fin n)) := { w := weights, hw := fun u v h => by @@ -138,12 +151,8 @@ noncomputable def patternStochasticUpdate 
rw [h_eq] exact h_diag_zero v else - have h_adj : (HopfieldNetwork ℝ (Fin n)).Adj u v := by - simp only [HopfieldNetwork]; simp only [ne_eq] - exact h_eq contradiction hw' := by - unfold NeuralNetwork.pw exact IsSymm.ext_iff.mpr fun i j ↦ h_sym j i σ := fun u => Vector.mk (Array.replicate ((HopfieldNetwork ℝ (Fin n)).κ1 u) (0 : ℝ)) rfl, @@ -160,12 +169,14 @@ noncomputable def patternStochasticUpdate noncomputable def NN.State.gibbsSamplingSteps (steps : ℕ) : PMF ((HopfieldNetwork R U).State) := match steps with | 0 => PMF.pure s - | steps+1 => PMF.bind (gibbsSamplingSteps steps) $ λ s' => + | steps + 1 => PMF.bind (gibbsSamplingSteps steps) $ fun s' => NN.State.gibbsSamplingStep wθ s' T /-- Temperature schedule for simulated annealing that decreases exponentially with each step. -/ noncomputable def temperatureSchedule (initial_temp : ℝ) (cooling_rate : ℝ) (step : ℕ) : ℝ := - initial_temp * Real.exp (-cooling_rate * step) + initial_temp * exp (-cooling_rate * step) + + --initial_temp * exp (-cooling_rate * step) /-- Recursively applies Gibbs sampling steps with decreasing temperature according to the cooling schedule, terminating when the step count reaches the target number of steps. -/ @@ -178,27 +189,25 @@ noncomputable def applyAnnealingSteps (temp_schedule : ℕ → ℝ) (steps : ℕ (applyAnnealingSteps temp_schedule steps (step + 1)) termination_by steps - step decreasing_by { - have : step < steps := not_le.mp h - have : steps - (step + 1) < steps - step := by rw [Nat.sub_succ] - simp_all only [ge_iff_le, not_le, Nat.pred_eq_sub_one, tsub_lt_self_iff, tsub_pos_iff_lt, Nat.lt_one_iff, - pos_of_gt, and_self] - exact this -} - -/-- `NN.State.simulatedAnnealing` implements the simulated annealing optimization algorithm for a Hopfield Network. 
-This function performs simulated annealing by starting from an initial state and gradually reducing -the temperature according to an exponential cooling schedule, allowing the system to explore the -state space and eventually settle into a low-energy configuration. + simp only [Nat.pred_eq_sub_one, tsub_lt_self_iff, tsub_pos_iff_lt, Nat.lt_one_iff] + rw [and_true] + exact not_le.mp h} + +/-- `NN.State.simulatedAnnealing` implements the simulated annealing optimization +algorithm for a Hopfield Network. This function performs simulated annealing by starting +from an initial state and gradually reducing the temperature according to an exponential +cooling schedule, allowing the system to explore the state space and eventually settle into a +low-energy configuration. -/ -noncomputable def NN.State.simulatedAnnealing - (initial_temp : ℝ) (cooling_rate : ℝ) (steps : ℕ) +noncomputable def NN.State.simulatedAnnealing (initial_temp : ℝ) (cooling_rate : ℝ) (steps : ℕ) (initial_state : (HopfieldNetwork R U).State) : PMF ((HopfieldNetwork R U).State) := - let temp_schedule := temperatureSchedule initial_temp cooling_rate + let temp_schedule := temperatureSchedule initial_temp cooling_rate applyAnnealingSteps wθ temp_schedule steps 0 initial_state -/-- Given a HopfieldNetwork with parameters `wθ` and temperature `T`, computes the acceptance probability -for transitioning from a `current` state to a `proposed` state according to the Metropolis-Hastings algorithm. +/-- Given a HopfieldNetwork with parameters `wθ` and temperature `T`, computes the +acceptance probability for transitioning from a `current` state to a `proposed` state according +to the Metropolis-Hastings algorithm. 
* If the energy difference (ΔE) is negative or zero, returns 1.0 (always accepts the transition) * If the energy difference is positive, returns exp(-ΔE/T) following the Boltzmann distribution @@ -209,31 +218,22 @@ noncomputable def NN.State.acceptanceProbability if energy_diff ≤ 0 then 1.0 -- Always accept if energy decreases else - Real.exp (-energy_diff / T) -- Accept with probability e^(-ΔE/T) if energy increases + exp (-energy_diff / T) -- Accept with probability e^(-ΔE/T) if energy increases /-- The partition function for a Hopfield network, defined as the sum over all possible states of the Boltzmann factor `exp(-E/T)`. -/ noncomputable def NN.State.partitionFunction : ℝ := - ∑ s : (HopfieldNetwork R U).State, Real.exp (-s.E wθ / T) + ∑ s : (HopfieldNetwork R U).State, exp (-s.E wθ / T) /-- Metropolis-Hastings single step for Hopfield networks -/ noncomputable def NN.State.metropolisHastingsStep : PMF ((HopfieldNetwork R U).State) := -- Uniform random selection of neuron let neuron_pmf : PMF U := - PMF.ofFintype (λ _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) - (by - rw [Finset.sum_const, Finset.card_univ] - rw [ENNReal.div_eq_inv_mul] - simp only [mul_one] - have h : (Fintype.card U : ENNReal) ≠ 0 := by - simp [Fintype.card_pos_iff.mpr inferInstance] - have h_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.coe_ne_top - rw [← ENNReal.mul_inv_cancel h h_top] - simp_all only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, not_false_eq_true, ENNReal.natCast_ne_top, - nsmul_eq_mul]) + PMF.ofFintype (fun _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) + (gibbsSamplingStep.extracted_1) -- Create proposed state by flipping a randomly selected neuron - let propose : U → PMF ((HopfieldNetwork R U).State) := λ u => + let propose : U → PMF ((HopfieldNetwork R U).State) := fun u => let flipped_state := if s.act u = 1 then -- Assuming 1 and -1 as valid activation values NN.State.updateNeuron s u (-1) (Or.inr rfl) @@ -241,7 +241,7 @@ noncomputable def 
NN.State.metropolisHastingsStep : PMF ((HopfieldNetwork R U).S NN.State.updateNeuron s u 1 (Or.inl rfl) let p := NN.State.acceptanceProbability wθ T s flipped_state -- Make acceptance decision - PMF.bind (NN.State.metropolisDecision p) (λ (accept : Bool) => + PMF.bind (NN.State.metropolisDecision p) (fun (accept : Bool) => if accept then PMF.pure flipped_state else PMF.pure s) -- Combine neuron selection with state proposal PMF.bind neuron_pmf propose @@ -251,12 +251,12 @@ noncomputable def NN.State.metropolisHastingsSteps (steps : ℕ) : PMF ((HopfieldNetwork R U).State) := match steps with | 0 => PMF.pure s - | steps+1 => PMF.bind (metropolisHastingsSteps steps) $ λ s' => + | steps+1 => PMF.bind (metropolisHastingsSteps steps) $ fun s' => NN.State.metropolisHastingsStep wθ s' T /-- The Boltzmann (Gibbs) distribution over neural network states -/ noncomputable def boltzmannDistribution : ((HopfieldNetwork R U).State → ℝ) := - λ s => Real.exp (-s.E wθ / T) / NN.State.partitionFunction wθ T + fun s => exp (-s.E wθ / T) / NN.State.partitionFunction wθ T /-- The transition probability matrix for Gibbs sampling -/ noncomputable def gibbsTransitionProb (s s' : (HopfieldNetwork R U).State) : ℝ := @@ -271,71 +271,52 @@ noncomputable def total_variation_distance (μ ν : (HopfieldNetwork R U).State → ℝ) : ℝ := (1/2) * ∑ s : (HopfieldNetwork R U).State, |μ s - ν s| -/-- For Gibbs updates, given the normalization and probabilities, the sum of normalized probabilities equals 1 -/ -lemma gibbs_probs_sum_one - (v : U) : +/-- For Gibbs updates, given the normalization and probabilities, the sum of + normalized probabilities equals 1 -/ +lemma gibbs_probs_sum_one (v : U) : let local_field := s.net wθ v - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - let norm_probs := λ b => probs b / total - ∑ b : Bool, norm_probs b = 1 := by - intro local_field 
probs total norm_probs - have h_sum : ∑ b : Bool, probs b / total = (probs true + probs false) / total := by - simp only [Fintype.sum_bool] - exact ENNReal.div_add_div_same + let norm_probs := fun b => probs T v local_field b / total T v local_field + ∑ b : Bool, norm_probs b = 1 := by + intro local_field norm_probs + have h_sum : ∑ b : Bool, probs T v local_field b / total T v local_field = + (probs T v local_field true + probs T v local_field false) / total T v local_field := by + rw [Fintype.sum_bool, ENNReal.div_add_div_same] rw [h_sum] - have h_total_eq : probs true + probs false = total := by rfl + have h_total_eq : probs T v local_field true + + probs T v local_field false = total T v local_field := by rfl rw [h_total_eq] - have h_total_ne_zero : total ≠ 0 := by - simp only [total, probs, ne_eq] - intro h_zero - have h1 : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr - apply Real.exp_pos - have h_sum_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) + - ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := h_zero - exact h1.ne' (add_eq_zero.mp h_sum_zero).1 - have h_total_ne_top : total ≠ ⊤ := by simp [total, probs] - exact ENNReal.div_self h_total_ne_zero h_total_ne_top + exact ENNReal.div_self (h_total_ne_zero T v local_field) (h_total_ne_top T v local_field) /-- The function that maps boolean values to states in Gibbs sampling -/ def gibbs_bool_to_state_map (s : (HopfieldNetwork R U).State) (v : U) : Bool → (HopfieldNetwork R U).State := - λ b => if b then + fun b => if b then NN.State.updateNeuron s v 1 (mul_self_eq_mul_self_iff.mp rfl) else NN.State.updateNeuron s v (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) /-- The total normalization constant for Gibbs sampling is positive -/ -lemma gibbs_total_positive - (local_field : ℝ) (T : ℝ) : +lemma gibbs_total_positive (local_field : ℝ) (T : ℝ) : let probs : Bool → ENNReal := fun b => let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp 
(local_field * new_act_val / T)) + ENNReal.ofReal (exp (local_field * new_act_val / T)) probs true + probs false ≠ 0 := by - intro probs - simp only [ne_eq] - intro h_zero - have h1 : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + intro probs h_zero + have h1 : ENNReal.ofReal (exp (local_field * 1 / T)) > 0 := by apply ENNReal.ofReal_pos.mpr - apply Real.exp_pos - have h_sum_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) + - ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := h_zero - have h_both_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) = 0 ∧ - ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := + apply exp_pos + have h_sum_zero : ENNReal.ofReal (exp (local_field * 1 / T)) + + ENNReal.ofReal (exp (local_field * (-1) / T)) = 0 := by { + exact h_zero + } + have h_both_zero : ENNReal.ofReal (exp (local_field * 1 / T)) = 0 ∧ + ENNReal.ofReal (exp (local_field * (-1) / T)) = 0 := add_eq_zero.mp h_sum_zero exact h1.ne' h_both_zero.1 /-- The total normalization constant for Gibbs sampling is not infinity -/ -lemma gibbs_total_not_top - (local_field : ℝ) (T : ℝ) : - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - probs true + probs false ≠ ⊤ := by - intro probs +lemma gibbs_total_not_top (local_field : ℝ) (T : ℝ) : + probs T u local_field true + probs T u local_field false ≠ ⊤ := by simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, not_false_eq_true, probs] @@ -358,20 +339,19 @@ lemma gibbsUpdate_exists_bool (v : U) (s_next : (HopfieldNetwork R U).State) : unfold NN.State.gibbsUpdateSingleNeuron at h_prob_pos let local_field_R := s.net wθ v let local_field : ℝ := ↑local_field_R - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false 
- let norm_probs : Bool → ENNReal := λ b => probs b / total + --let total := probs T v local_field true + probs T v local_field false + let norm_probs : Bool → ENNReal := fun b => probs T v local_field b / total T v local_field let map_fn : Bool → (HopfieldNetwork R U).State := gibbs_bool_to_state_map s v have h_sum_eq_1 : ∑ b : Bool, norm_probs b = 1 := by - have h_total_ne_zero : total ≠ 0 := gibbs_total_positive local_field T - have h_total_ne_top : total ≠ ⊤ := gibbs_total_not_top local_field T + have h_total_ne_zero : total T v local_field ≠ 0 := gibbs_total_positive local_field T + have h_total_ne_top : total T v local_field ≠ ⊤ := gibbs_total_not_top local_field T calc Finset.sum Finset.univ norm_probs - = (probs true)/total + (probs false)/total := - Fintype.sum_bool fun b ↦ probs b / total - _ = (probs true + probs false)/total := ENNReal.div_add_div_same - _ = total/total := by rfl + = (probs T v local_field true) /total T v local_field + + (probs T v local_field false)/total T v local_field := + Fintype.sum_bool fun b ↦ probs T v local_field b / total T v local_field + _ = (probs T v local_field true + + probs T v local_field false)/total T v local_field:= ENNReal.div_add_div_same + _ = total T v local_field /total T v local_field := by rfl _ = 1 := ENNReal.div_self h_total_ne_zero h_total_ne_top let base_pmf := PMF.ofFintype norm_probs h_sum_eq_1 have ⟨b, _, h_map_eq⟩ := pmf_map_pos_implies_preimage h_sum_eq_1 map_fn s_next h_prob_pos @@ -384,8 +364,7 @@ lemma gibbsUpdate_exists_bool (v : U) (s_next : (HopfieldNetwork R U).State) : lemma gibbsUpdate_possible_states (v : U) (s_next : (HopfieldNetwork R U).State) : (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next > 0 → s_next = NN.State.updateNeuron s v 1 (mul_self_eq_mul_self_iff.mp rfl) ∨ - s_next = NN.State.updateNeuron s v (-1) - (AffineMap.lineMap_eq_lineMap_iff.mp rfl) := by + s_next = NN.State.updateNeuron s v (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) := by intro h_prob_pos obtain ⟨b, h_eq⟩ 
:= gibbsUpdate_exists_bool wθ s T v s_next h_prob_pos cases b with @@ -420,27 +399,25 @@ lemma gibbsUpdate_preserves_other_neurons /-- The probability mass function for a binary choice (true/false) has sum 1 when properly normalized -/ lemma pmf_binary_norm_sum_one (local_field : ℝ) (T : ℝ) : - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - let norm_probs := λ b => probs b / total + let total := probs T u local_field true + probs T u local_field false + let norm_probs := fun b => probs T u local_field b / total ∑ b : Bool, norm_probs b = 1 := by - intro probs total norm_probs - have h_sum : ∑ b : Bool, probs b / total = (probs true + probs false) / total := by + intro total norm_probs + have h_sum : ∑ b : Bool, probs T u local_field b / total = + (probs T u local_field true + probs T u local_field false) / total := by simp only [Fintype.sum_bool] exact ENNReal.div_add_div_same rw [h_sum] have h_total_ne_zero : total ≠ 0 := by simp only [total, probs, ne_eq] intro h_zero - have h1 : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by + have h1 : ENNReal.ofReal (exp (local_field * 1 / T)) > 0 := by apply ENNReal.ofReal_pos.mpr - apply Real.exp_pos - have h_sum_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) + - ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := h_zero - have h_both_zero : ENNReal.ofReal (Real.exp (local_field * 1 / T)) = 0 ∧ - ENNReal.ofReal (Real.exp (local_field * (-1) / T)) = 0 := by + apply exp_pos + have h_sum_zero : ENNReal.ofReal (exp (local_field * 1 / T)) + + ENNReal.ofReal (exp (local_field * (-1) / T)) = 0 := h_zero + have h_both_zero : ENNReal.ofReal (exp (local_field * 1 / T)) = 0 ∧ + ENNReal.ofReal (exp (local_field * (-1) / T)) = 0 := by exact add_eq_zero.mp h_sum_zero exact h1.ne' h_both_zero.1 have h_total_ne_top : total ≠ ⊤ := by @@ -449,161 +426,139 @@ lemma 
pmf_binary_norm_sum_one (local_field : ℝ) (T : ℝ) : /-- The normalization factor in Gibbs sampling is the sum of Boltzmann factors for both possible states -/ -lemma gibbs_normalization_factor - (local_field : ℝ) (T : ℝ) : - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - total = ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal - (Real.exp (-local_field / T)) := by - intro probs total +lemma gibbs_normalization_factor (local_field : ℝ) (T : ℝ) : + let total := probs T u local_field true + probs T u local_field false + total = ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal + (exp (-local_field / T)) := by + intro total simp only [probs, total] simp only [↓reduceIte, mul_one, Bool.false_eq_true, mul_neg, total, probs] + rfl /-- The probability mass assigned to true when using Gibbs sampling -/ -lemma gibbs_prob_true - (local_field : ℝ) (T : ℝ) : - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - let norm_probs := λ b => probs b / total - norm_probs true = ENNReal.ofReal (Real.exp (local_field / T)) / - (ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal - (Real.exp (-local_field / T))) := by - intro probs total norm_probs +lemma gibbs_prob_true (local_field : ℝ) (T : ℝ) : + norm_probs T u local_field true = ENNReal.ofReal (exp (local_field / T)) / + (ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal + (exp (-local_field / T))) := by + --intro total --norm_probs simp only [norm_probs, probs] - have h_total : total = ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T)) := by + have h_total : total T u local_field = ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T)) := by simp only [mul_ite, mul_one, 
mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs, norm_probs] + rfl rw [h_total] congr simp only [↓reduceIte, mul_one, total, norm_probs, probs] + rfl /-- The probability mass assigned to false when using Gibbs sampling -/ -lemma gibbs_prob_false - (local_field : ℝ) (T : ℝ) : - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - let norm_probs := λ b => probs b / total - norm_probs false = ENNReal.ofReal (Real.exp (-local_field / T)) / - (ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T))) := by - intro probs total norm_probs +lemma gibbs_prob_false (local_field : ℝ) (T : ℝ) : + norm_probs T u local_field false = ENNReal.ofReal (exp (-local_field / T)) / + (ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp (-local_field / T))) := by simp only [norm_probs, probs] - have h_total : total = ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T)) := by + have h_total : total T u local_field = ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T)) := by simp [total, probs] + rfl rw [h_total] congr simp only [Bool.false_eq_true, ↓reduceIte, mul_neg, mul_one, norm_probs, probs, total] - + rfl /-- Converts the ratio of Boltzmann factors to ENNReal sigmoid form. 
-/ @[simp] lemma ENNReal_exp_ratio_to_sigmoid (x : ℝ) : - ENNReal.ofReal (Real.exp x) / - (ENNReal.ofReal (Real.exp x) + ENNReal.ofReal (Real.exp (-x))) = - ENNReal.ofReal (1 / (1 + Real.exp (-2 * x))) := by - have num_pos : 0 ≤ Real.exp x := le_of_lt (Real.exp_pos x) - have denom_pos : 0 < Real.exp x + Real.exp (-x) := by + ENNReal.ofReal (exp x) / (ENNReal.ofReal (exp x) + ENNReal.ofReal (exp (-x))) = + ENNReal.ofReal (1 / (1 + exp (-2 * x))) := by + have num_pos : 0 ≤ exp x := le_of_lt (exp_pos x) + have denom_pos : 0 < exp x + exp (-x) := by apply add_pos - · exact Real.exp_pos x - · exact Real.exp_pos (-x) - have h1 : ENNReal.ofReal (Real.exp x) / - (ENNReal.ofReal (Real.exp x) + ENNReal.ofReal (Real.exp (-x))) = - ENNReal.ofReal (Real.exp x / (Real.exp x + Real.exp (-x))) := by - have h_sum : ENNReal.ofReal (Real.exp x) + ENNReal.ofReal (Real.exp (-x)) = - ENNReal.ofReal (Real.exp x + Real.exp (-x)) := by - have exp_neg_pos : 0 ≤ Real.exp (-x) := le_of_lt (Real.exp_pos (-x)) - exact Eq.symm (ENNReal.ofReal_add num_pos exp_neg_pos) + · exact exp_pos x + · exact exp_pos (-x) + have h1 : ENNReal.ofReal (exp x) / + (ENNReal.ofReal (exp x) + ENNReal.ofReal (exp (-x))) = + ENNReal.ofReal (exp x / (exp x + exp (-x))) := by + have h_sum : ENNReal.ofReal (exp x) + ENNReal.ofReal (exp (-x)) = + ENNReal.ofReal (exp x + exp (-x)) := by + have exp_neg_pos : 0 ≤ exp (-x) := le_of_lt (exp_pos (-x)) + exact Eq.symm (ofReal_add num_pos exp_neg_pos) rw [h_sum] - exact Eq.symm (ENNReal.ofReal_div_of_pos denom_pos) - have h2 : Real.exp x / (Real.exp x + Real.exp (-x)) = 1 / (1 + Real.exp (-2 * x)) := by - have h_denom : Real.exp x + Real.exp (-x) = Real.exp x * (1 + Real.exp (-2 * x)) := by - have h_exp_diff : Real.exp (-x) = Real.exp x * Real.exp (-2 * x) := by - rw [← Real.exp_add]; congr; ring - calc Real.exp x + Real.exp (-x) - = Real.exp x + Real.exp x * Real.exp (-2 * x) := by rw [h_exp_diff] - _ = Real.exp x * (1 + Real.exp (-2 * x)) := by rw [mul_add, mul_one] + exact 
Eq.symm (ofReal_div_of_pos denom_pos) + have h2 : exp x / (exp x + exp (-x)) = 1 / (1 + exp (-2 * x)) := by + have h_denom : exp x + exp (-x) = exp x * (1 + exp (-2 * x)) := by + have h_exp_diff : exp (-x) = exp x * exp (-2 * x) := by + rw [← exp_add]; congr; ring + calc exp x + exp (-x) + = exp x + exp x * exp (-2 * x) := by rw [h_exp_diff] + _ = exp x * (1 + exp (-2 * x)) := by rw [mul_add, mul_one] rw [h_denom, div_mul_eq_div_div] - have h_exp_ne_zero : Real.exp x ≠ 0 := ne_of_gt (Real.exp_pos x) + have h_exp_ne_zero : exp x ≠ 0 := ne_of_gt (exp_pos x) field_simp rw [h1, h2] @[simp] -lemma ENNReal.div_ne_top' {a b : ENNReal} (ha : a ≠ ⊤) (hb : b ≠ 0) : - a / b ≠ ⊤ := by +lemma ENNReal.div_ne_top' {a b : ENNReal} (ha : a ≠ ⊤) (hb : b ≠ 0) : a / b ≠ ⊤ := by intro h_top - rw [ENNReal.div_eq_top] at h_top + rw [div_eq_top] at h_top rcases h_top with (⟨_, h_right⟩ | ⟨h_left, _⟩); exact hb h_right; exact ha h_left -lemma gibbs_prob_positive - (local_field : ℝ) (T : ℝ) : - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - ENNReal.ofReal (Real.exp (local_field / T)) / total = - ENNReal.ofReal (1 / (1 + Real.exp (-2 * local_field / T))) := by - intro probs total - have h_total : total = ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T)) := by +lemma gibbs_prob_positive (local_field : ℝ) (T : ℝ) : + let total := probs T u local_field true + probs T u local_field false + ENNReal.ofReal (exp (local_field / T)) / total = + ENNReal.ofReal (1 / (1 + exp (-2 * local_field / T))) := by + intro total + have h_total : total = ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T)) := by simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs] + rfl rw [h_total] - have h_temp : ∀ x, Real.exp (x / T) = Real.exp (x * (1/T)) := by + have h_temp : ∀ x, exp (x / 
T) = exp (x * (1/T)) := by intro x; congr; field_simp rw [h_temp local_field, h_temp (-local_field)] have h_direct : - ENNReal.ofReal (Real.exp (local_field * (1 / T))) / - (ENNReal.ofReal (Real.exp (local_field * (1 / T))) + - ENNReal.ofReal (Real.exp (-local_field * (1 / T)))) = - ENNReal.ofReal (1 / (1 + Real.exp (-2 * local_field / T))) := by + ENNReal.ofReal (exp (local_field * (1 / T))) / + (ENNReal.ofReal (exp (local_field * (1 / T))) + + ENNReal.ofReal (exp (-local_field * (1 / T)))) = + ENNReal.ofReal (1 / (1 + exp (-2 * local_field / T))) := by have h := ENNReal_exp_ratio_to_sigmoid (local_field * (1 / T)) have h_rhs : -2 * (local_field * (1 / T)) = -2 * local_field / T := by field_simp rw [h_rhs] at h - have neg_equiv : ENNReal.ofReal (Real.exp (-(local_field * (1 / T)))) = - ENNReal.ofReal (Real.exp (-local_field * (1 / T))) := by + have neg_equiv : ENNReal.ofReal (exp (-(local_field * (1 / T)))) = + ENNReal.ofReal (exp (-local_field * (1 / T))) := by congr; ring rw [neg_equiv] at h exact h exact h_direct /-- The probability of setting a neuron to -1 under Gibbs sampling -/ -lemma gibbs_prob_negative - (local_field : ℝ) (T : ℝ) : - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - ENNReal.ofReal (Real.exp (-local_field / T)) / total = - ENNReal.ofReal (1 / (1 + Real.exp (2 * local_field / T))) := by - intro probs total - have h_total : total = ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T)) := by +lemma gibbs_prob_negative (local_field : ℝ) (T : ℝ) : + let total := probs T u local_field true + probs T u local_field false + ENNReal.ofReal (exp (-local_field / T)) / total = + ENNReal.ofReal (1 / (1 + exp (2 * local_field / T))) := by + intro total + have h_total : total = ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T)) := by simp only [mul_ite, 
mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs] + rfl rw [h_total] have h_neg2_neg : -2 * (-local_field / T) = 2 * local_field / T := by ring have h_neg_neg : -(-local_field / T) = local_field / T := by ring - have h_ratio_final : ENNReal.ofReal (Real.exp (-local_field / T)) / - (ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T))) = - ENNReal.ofReal (1 / (1 + Real.exp (2 * local_field / T))) := by + have h_ratio_final : ENNReal.ofReal (exp (-local_field / T)) / + (ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T))) = + ENNReal.ofReal (1 / (1 + exp (2 * local_field / T))) := by have h := ENNReal_exp_ratio_to_sigmoid (-local_field / T) - have h_exp_neg_neg : ENNReal.ofReal (Real.exp (-(-local_field / T))) = - ENNReal.ofReal (Real.exp (local_field / T)) := by congr + have h_exp_neg_neg : ENNReal.ofReal (exp (-(-local_field / T))) = + ENNReal.ofReal (exp (local_field / T)) := by congr rw [h_exp_neg_neg] at h - have h_comm : ENNReal.ofReal (Real.exp (-local_field / T)) + - ENNReal.ofReal (Real.exp (local_field / T)) = - ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T)) := by + have h_comm : ENNReal.ofReal (exp (-local_field / T)) + + ENNReal.ofReal (exp (local_field / T)) = + ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T)) := by rw [add_comm] - rw [h_neg2_neg] at h - rw [h_comm] at h + rw [h_neg2_neg, h_comm] at h exact h exact h_ratio_final @@ -611,35 +566,35 @@ lemma gibbs_prob_negative lemma gibbs_prob_positive_case (u : U) : let local_field := s.net wθ u - let Z := ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T)) - let norm_probs := λ b => if b then - ENNReal.ofReal (Real.exp (local_field / T)) / Z + let Z := ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp (-local_field / T)) + let norm_probs := fun b => if b then + ENNReal.ofReal (exp 
(local_field / T)) / Z else - ENNReal.ofReal (Real.exp (-local_field / T)) / Z + ENNReal.ofReal (exp (-local_field / T)) / Z (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by - exact Fintype.sum_bool (λ b => norm_probs b) + exact Fintype.sum_bool (fun b => norm_probs b) rw [h_sum] simp only [norm_probs] - have h_ratio_sum : ENNReal.ofReal (Real.exp (local_field / T)) / Z + - ENNReal.ofReal (Real.exp (-local_field / T)) / Z = - (ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T))) / Z := by + have h_ratio_sum : ENNReal.ofReal (exp (local_field / T)) / Z + + ENNReal.ofReal (exp (-local_field / T)) / Z = + (ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T))) / Z := by exact ENNReal.div_add_div_same simp only [Bool.false_eq_true] - have h_if_true : (if True then ENNReal.ofReal (Real.exp (local_field / T)) / Z - else ENNReal.ofReal (Real.exp (-local_field / T)) / Z) = - ENNReal.ofReal (Real.exp (local_field / T)) / Z := by simp + have h_if_true : (if True then ENNReal.ofReal (exp (local_field / T)) / Z + else ENNReal.ofReal (exp (-local_field / T)) / Z) = + ENNReal.ofReal (exp (local_field / T)) / Z := by simp - have h_if_false : (if False then ENNReal.ofReal (Real.exp (local_field / T)) / Z - else ENNReal.ofReal (Real.exp (-local_field / T)) / Z) = - ENNReal.ofReal (Real.exp (-local_field / T)) / Z := by simp + have h_if_false : (if False then ENNReal.ofReal (exp (local_field / T)) / Z + else ENNReal.ofReal (exp (-local_field / T)) / Z) = + ENNReal.ofReal (exp (-local_field / T)) / Z := by simp rw [h_if_true, h_if_false] rw [h_ratio_sum] have h_Z_ne_zero : Z ≠ 0 := by simp only [ne_eq, add_eq_zero, ENNReal.ofReal_eq_zero, not_and, not_le, Z, norm_probs] intros - exact Real.exp_pos (-Coe.coe local_field / T) + exact exp_pos (-Coe.coe local_field / T) have h_Z_ne_top : Z ≠ ⊤ := by simp [Z] exact 
ENNReal.div_self h_Z_ne_zero h_Z_ne_top ))) (NN.State.updateNeuron s u 1 (Or.inl rfl)) = norm_probs true := by @@ -647,24 +602,23 @@ lemma gibbs_prob_positive_case apply pmf_map_update_one -- Lemma for the probability calculation in the negative case -lemma gibbs_prob_negative_case - (u : U) : +lemma gibbs_prob_negative_case (u : U) : let local_field := s.net wθ u - let Z := ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T)) - let norm_probs := λ b => if b then - ENNReal.ofReal (Real.exp (local_field / T)) / Z + let Z := ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T)) + let norm_probs := fun b => if b then + ENNReal.ofReal (exp (local_field / T)) / Z else - ENNReal.ofReal (Real.exp (-local_field / T)) / Z + ENNReal.ofReal (exp (-local_field / T)) / Z (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by - exact Fintype.sum_bool (λ b => norm_probs b) + exact Fintype.sum_bool (fun b => norm_probs b) rw [h_sum] simp only [norm_probs] - have h_ratio_sum : ENNReal.ofReal (Real.exp (local_field / T)) / Z + - ENNReal.ofReal (Real.exp (-local_field / T)) / Z = - (ENNReal.ofReal (Real.exp (local_field / T)) + - ENNReal.ofReal (Real.exp (-local_field / T))) / Z := by + have h_ratio_sum : ENNReal.ofReal (exp (local_field / T)) / Z + + ENNReal.ofReal (exp (-local_field / T)) / Z = + (ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T))) / Z := by exact ENNReal.div_add_div_same simp only [Bool.false_eq_true] simp only [↓reduceIte, norm_probs] @@ -672,9 +626,9 @@ lemma gibbs_prob_negative_case have h_Z_ne_zero : Z ≠ 0 := by simp only [Z, ne_eq, add_eq_zero] intro h - have h_exp_pos : ENNReal.ofReal (Real.exp (local_field / T)) > 0 := by + have h_exp_pos : ENNReal.ofReal (exp (local_field / T)) > 0 := by apply ENNReal.ofReal_pos.mpr - apply Real.exp_pos + apply exp_pos exact 
(not_and_or.mpr (Or.inl h_exp_pos.ne')) h have h_Z_ne_top : Z ≠ ⊤ := by simp only [ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, not_false_eq_true, Z, @@ -688,125 +642,64 @@ lemma gibbs_prob_negative_case lemma gibbsUpdate_pmf_structure (u : U) : let local_field := s.net wθ u - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - let norm_probs := λ b => probs b / total + let total := probs T u local_field true + probs T u local_field false + let norm_probs := fun b => probs T u local_field b / total ∀ b : Bool, (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by - exact Fintype.sum_bool (λ b => norm_probs b) + exact Fintype.sum_bool (fun b => norm_probs b) rw [h_sum] - have h_ratio_sum : probs true / total + probs false / total = - (probs true + probs false) / total := by + have h_ratio_sum : probs T u local_field true / total + probs T u local_field false / total = + (probs T u local_field true + probs T u local_field false) / total := by exact ENNReal.div_add_div_same rw [h_ratio_sum] - have h_total_ne_zero : total ≠ 0 := by - simp only [total, probs, ne_eq, add_eq_zero] - intro h - have h_exp_pos : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr - apply Real.exp_pos - exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h - have h_total_ne_top : total ≠ ⊤ := by simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, - Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, - not_false_eq_true, total, probs] - exact ENNReal.div_self h_total_ne_zero h_total_ne_top + exact ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field) ))) (gibbs_bool_to_state_map s u b) = norm_probs b := by - intro local_field probs total norm_probs b_bool + intro 
local_field total norm_probs b_bool exact pmf_map_binary_state s u b_bool (fun b => norm_probs b) (by have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by - exact Fintype.sum_bool (λ b => norm_probs b) + exact Fintype.sum_bool (fun b => norm_probs b) rw [h_sum] - have h_ratio_sum : probs true / total + probs false / total = - (probs true + probs false) / total := by + have h_ratio_sum : probs T u local_field true / total + probs T u local_field false / total = + (probs T u local_field true + probs T u local_field false) / total := by exact ENNReal.div_add_div_same rw [h_ratio_sum] - have h_total_ne_zero : total ≠ 0 := by - simp only [total, probs, ne_eq, add_eq_zero] - intro h - have h_exp_pos : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr - apply Real.exp_pos - exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h - have h_total_ne_top : total ≠ ⊤ := by simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, - Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, - not_false_eq_true, total, probs] - exact ENNReal.div_self h_total_ne_zero h_total_ne_top) + exact ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field)) + +def h_result_update_one (u : U) (local_field : R) := + pmf_map_update_one s u (norm_probs T u local_field ) (by + rw [h_sum] + exact ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field)) + +def h_result_neg_one (u : U) (local_field : R) := + pmf_map_update_neg_one s u (norm_probs T u local_field ) (by + rw [h_sum] + exact ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field)) /-- The probability of updating a neuron to 1 using Gibbs sampling -/ -lemma gibbsUpdate_prob_positive - (u : U) : +lemma gibbsUpdate_prob_positive (u : U) : let local_field := s.net wθ u - let Z := ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T)) + --let Z := 
ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp (-local_field / T)) (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u 1 (Or.inl rfl)) = - ENNReal.ofReal (Real.exp (local_field / T)) / Z := by - intro local_field Z + ENNReal.ofReal (exp (local_field / T)) / (Z T local_field) := by + intro local_field --Z unfold NN.State.gibbsUpdateSingleNeuron - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - have h_total_eq_Z : total = Z := by - simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs, Z] - have h_result := pmf_map_update_one s u (fun b => probs b / total) (by - have h_sum : ∑ b : Bool, probs b / total = (probs true + probs false) / total := by - simp only [Fintype.univ_bool, mem_singleton, Bool.true_eq_false, - not_false_eq_true,sum_insert, sum_singleton, total, probs, Z] - exact - ENNReal.div_add_div_same - rw [h_sum] - have h_total_ne_zero : total ≠ 0 := by - simp only [total, probs, ne_eq, add_eq_zero] - intro h - have h_exp_pos : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr - apply Real.exp_pos - exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h - have h_total_ne_top : total ≠ ⊤ := by simp [total, probs] - exact ENNReal.div_self h_total_ne_zero h_total_ne_top) - rw [h_result] - simp only [probs, mul_one_div] + rw [h_result_update_one] + simp only [probs, mul_one_div, norm_probs] rw [h_total_eq_Z] - simp only [if_true, mul_one] + simp only [if_true, mul_one,local_field] /-- The probability of updating a neuron to -1 using Gibbs sampling -/ -lemma gibbsUpdate_prob_negative - (u : U) : +lemma gibbsUpdate_prob_negative (u : U) : let local_field := s.net wθ u - let Z := ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T)) + --let Z := ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp 
(-local_field / T)) (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u (-1) (Or.inr rfl)) = - ENNReal.ofReal (Real.exp (-local_field / T)) / Z := by - intro local_field Z + ENNReal.ofReal (exp (-local_field / T)) / (Z T local_field) := by + intro local_field unfold NN.State.gibbsUpdateSingleNeuron - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (Real.exp (local_field * new_act_val / T)) - let total := probs true + probs false - have h_total_eq_Z : total = Z := by - simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs, Z] - have h_result := pmf_map_update_neg_one s u (fun b => probs b / total) (by - have h_sum : ∑ b : Bool, probs b / total = (probs true + probs false) / total := by - simp only [Fintype.univ_bool, mem_singleton, Bool.true_eq_false, - not_false_eq_true, sum_insert, sum_singleton, total, probs, Z] - exact ENNReal.div_add_div_same - rw [h_sum] - have h_total_ne_zero : total ≠ 0 := by - simp only [total, probs, ne_eq, add_eq_zero] - intro h - have h_exp_pos : ENNReal.ofReal (Real.exp (local_field * 1 / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr - apply Real.exp_pos - exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h - have h_total_ne_top : total ≠ ⊤ := by - simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, - Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, - not_false_eq_true, total, probs, Z] - exact ENNReal.div_self h_total_ne_zero h_total_ne_top) - rw [h_result] - simp only [probs, one_div_neg_one_eq_neg_one, one_div_neg_one_eq_neg_one] + rw [h_result_neg_one] + simp only [probs, one_div_neg_one_eq_neg_one, one_div_neg_one_eq_neg_one, norm_probs] rw [h_total_eq_Z] - simp only [Bool.false_eq_true, ↓reduceIte, mul_neg, mul_one, probs, Z, total] + simp only [Bool.false_eq_true, ↓reduceIte, mul_neg, mul_one, probs, Z, total, local_field] /-- Computes the probability of updating a neuron to a specific value using Gibbs 
sampling. - If new_val = 1: probability = exp(local_field/T)/Z @@ -817,12 +710,13 @@ where Z is the normalization constant (partition function). lemma gibbs_update_single_neuron_prob (u : U) (new_val : R) (hval : (HopfieldNetwork R U).pact new_val) : let local_field := s.net wθ u - let Z := ENNReal.ofReal (Real.exp (local_field / T)) + ENNReal.ofReal (Real.exp (-local_field / T)) + let Z := ENNReal.ofReal (exp (local_field / T)) + + ENNReal.ofReal (exp (-local_field / T)) (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u new_val hval) = if new_val = 1 then - ENNReal.ofReal (Real.exp (local_field / T)) / Z + ENNReal.ofReal (exp (local_field / T)) / Z else - ENNReal.ofReal (Real.exp (-local_field / T)) / Z := by + ENNReal.ofReal (exp (-local_field / T)) / Z := by intro local_field Z by_cases h_val : new_val = 1 · rw [if_pos h_val] @@ -877,14 +771,8 @@ lemma gibbs_update_zero_other_sites (s s' : (HopfieldNetwork R U).State) lemma gibbs_transition_sum_simplification (s s' : (HopfieldNetwork R U).State) (u : U) (h : ∀ v : U, v ≠ u → s.act v = s'.act v) (h_diff : s.act u ≠ s'.act u) : let neuron_pmf : PMF U := PMF.ofFintype - (λ _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) - (by - simp only [one_div, sum_const, card_univ, nsmul_eq_mul] - have h_card_ne_zero : (Fintype.card U : ENNReal) ≠ 0 := by - simp only [ne_eq, Nat.cast_eq_zero] - exact Fintype.card_ne_zero - have h_card_ne_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top (Fintype.card U) - rw [← ENNReal.mul_inv_cancel h_card_ne_zero h_card_ne_top]) + (fun _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) + (NN.State.gibbsSamplingStep.extracted_1) let update_prob (v : U) : ENNReal := (NN.State.gibbsUpdateSingleNeuron wθ s T v) s' ∑ v ∈ Finset.univ, neuron_pmf v * update_prob v = neuron_pmf u * update_prob u := by intro neuron_pmf update_prob @@ -916,25 +804,12 @@ lemma gibbs_update_preserves_other_sites (v u : U) (hvu : v ≠ u) : -- Case s_next = updateNeuron s v (-1) rw [h_neg] 
exact - updateNeuron_preserves s v u (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) (id (Ne.symm hvu)) + updateNeuron_preserves s v u (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) (id (Ne.symm hvu)) @[simp] lemma uniform_neuron_prob {U : Type} [Fintype U] [Nonempty U] (u : U) : (1 : ENNReal) / (Fintype.card U : ENNReal) = - PMF.ofFintype (λ _ : U => (1 : ENNReal) / (Fintype.card U : ENNReal)) - (by - rw [Finset.sum_const, Finset.card_univ] - simp only [nsmul_eq_mul] - have h_card_ne_zero : (Fintype.card U : ENNReal) ≠ 0 := by - simp only [ne_eq, Nat.cast_eq_zero] - exact Fintype.card_ne_zero - have h_card_ne_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top _ - rw [ENNReal.div_eq_inv_mul] - rw [mul_comm] - rw [← ENNReal.mul_inv_cancel h_card_ne_zero h_card_ne_top] - rw [ENNReal.inv_mul_cancel_left h_card_ne_zero h_card_ne_top] - simp_all only [ne_eq, Nat.cast_eq_zero, Fintype.card_ne_zero, - not_false_eq_true, ENNReal.natCast_ne_top] - rw [mul_comm] + PMF.ofFintype (fun _ : U => (1 : ENNReal) / (Fintype.card U : ENNReal)) + (by exact NN.State.gibbsSamplingStep.extracted_1 ) u := by simp only [one_div, PMF.ofFintype_apply] diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean index e570c868c..57f800743 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean @@ -39,8 +39,7 @@ lemma uniform_neuron_selection_prob_valid : simp only [ne_eq, Nat.cast_eq_zero] exact ne_of_gt h_card_pos have h_card_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top (Fintype.card U) - rw [ENNReal.div_eq_inv_mul] - rw [nsmul_eq_mul] + rw [ENNReal.div_eq_inv_mul, nsmul_eq_mul] simp only [mul_one] rw [ENNReal.mul_inv_cancel h_card_ne_zero h_card_top] @@ -277,7 +276,7 @@ lemma pmf_filter_update_neuron ext b simp only [mem_filter, mem_univ, 
true_and] rw [@bool_update_eq_iff] - simp only [h1, and_false, h2, or_self, ↓reduceIte, not_mem_empty] + simp only [h1, and_false, h2, or_self, ↓reduceIte, Finset.notMem_empty] /-- For a PMF over binary values mapped to states, the probability of a specific state equals the probability of its corresponding binary value -/ @@ -352,9 +351,10 @@ lemma pmf_map_binary_state simp only [ne_eq, Bool.true_eq_false, not_false_eq_true] simp only [h_false_neq, if_true, if_false, zero_add] -/-- A specialized version of the previous lemma for the case where the state is an update with new_val = 1 -/ -lemma pmf_map_update_one - (s : (HopfieldNetwork R U).State) (u : U) (p : Bool → ENNReal) (h_sum : ∑ b, p b = 1) : +/-- A specialized version of the previous lemma for the case where the state + is an update with new_val = 1 -/ +lemma pmf_map_update_one (s : (HopfieldNetwork R U).State) (u : U) + (p : Bool → ENNReal) (h_sum : ∑ b, p b = 1) : let f : Bool → (HopfieldNetwork R U).State := λ b => if b then NN.State.updateNeuron s u 1 (Or.inl rfl) else NN.State.updateNeuron s u (-1) (Or.inr rfl) From 9499eddf8e1dc52f28baf051ac54e2538f8b0452 Mon Sep 17 00:00:00 2001 From: mkaratarakis Date: Fri, 1 Aug 2025 22:49:49 +0200 Subject: [PATCH 05/15] added 1982 HN paper files --- .../Hopfield82/ContentAddressableMemory.lean | 389 +++++++ .../Hopfield82/EnergyConvergence.lean | 53 + .../HopfieldNetwork/Hopfield82/Example.lean | 246 +++++ .../Hopfield82/FaultTolerance.lean | 952 ++++++++++++++++++ .../Hopfield82/MemoryConfusion.lean | 78 ++ .../Hopfield82/MemoryStorage.lean | 38 + .../Hopfield82/PhaseSpaceFlow.lean | 190 ++++ .../HopfieldNetwork/Stochastic.lean | 3 +- 8 files changed, 1947 insertions(+), 2 deletions(-) create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/ContentAddressableMemory.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/EnergyConvergence.lean create mode 100644 
PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/Example.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/FaultTolerance.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/MemoryConfusion.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/MemoryStorage.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/PhaseSpaceFlow.lean diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/ContentAddressableMemory.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/ContentAddressableMemory.lean new file mode 100644 index 000000000..6192fa0ac --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/ContentAddressableMemory.lean @@ -0,0 +1,389 @@ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Hopfield82.MemoryConfusion +import Mathlib.Combinatorics.Enumerative.Bell +import Mathlib.Combinatorics.SimpleGraph.Finite + +namespace Hopfield82 + +open NeuralNetwork State Matrix Finset Real + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] [Inhabited U] + +/-! ### Content-Addressable Memory -/ + +/-- +The `retrievalDistance` function measures how far from a pattern we can initialize +the network and still have it converge to that pattern. + +From the paper (p.2556): "For distance ≤ 5, the nearest state was reached more than +90% of the time. Beyond that distance, the probability fell off smoothly." 
+-/ +noncomputable def retrievalDistance (wθ : Params (HopfieldNetwork R U)) (p : PhaseSpacePoint R U) + (useq : ℕ → U) (hf : fair useq) : ℕ := + sSup {d : ℕ | ∀ s : PhaseSpacePoint R U, hammingDistance s p ≤ d → s ∈ BasinOfAttraction wθ p useq hf} +where + hammingDistance (s₁ s₂ : PhaseSpacePoint R U) : ℕ := + card {i | s₁.act i ≠ s₂.act i} + + +/-! +# Content-Addressable Memory Properties of Hopfield Networks + +## Main Components + +* `ContentAddressableMemory`: Formal definition of content-addressable memory properties +* `BasinOfAttraction`: Analysis of attractor basins around stored patterns +* `PatternCompletion`: Completion of partial patterns and error correction +* `FamiliarityRecognition`: Recognition of familiar vs. unfamiliar inputs +* `CategoryFormation`: Emergence of categories and generalization + +## References + +* Hopfield, J.J. (1982). Neural networks and physical systems with emergent collective + computational abilities. Proceedings of the National Academy of Sciences, 79(8), 2554-2558. +* Hertz, J., Krogh, A., & Palmer, R.G. (1991). Introduction to the theory of neural + computation. Addison-Wesley. +-/ + +open NeuralNetwork State Matrix Finset Real +open BigOperators Order MeasureTheory Set + +-- variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] -- Already declared above + +-- Define a single instance and make sure it's properly placed before it's used +instance : Inhabited (PhaseSpacePoint R U) := + ⟨{ act := fun _ => 1, hp := fun _ => Or.inl rfl }⟩ + +-- Derive Nonempty from Inhabited automatically +instance : Nonempty (PhaseSpacePoint R U) := inferInstance + +/-! ### Content-Addressable Memory Formalization -/ + +/-- +The `HammingDistance` between two states measures the number of neurons with different activations. +This provides a metric for measuring the similarity between patterns. 
+-/ +def HammingDistance (s₁ s₂ : PhaseSpacePoint R U) : ℕ := + card {i : U | s₁.act i ≠ s₂.act i} + +/-- +`PatternCompletionThreshold` is an empirically suggested maximum Hamming distance from a stored pattern +such that the network can still reliably retrieve the complete pattern. +The actual threshold for a specific CAM instance is part of the `ContentAddressableMemory` structure. + +The paper suggests (p.2556): "For distance ≤ 5, the nearest state was reached more than +90% of the time" for a network of 30 neurons. This definition approximates that. +-/ +noncomputable def PatternCompletionThreshold (N : ℕ) : ℕ := + Nat.floor (0.15 * (N : Real)) + +/-- +A `ContentAddressableMemory` is a system that can retrieve a complete pattern +from a partial or corrupted version. + +This formalizes the central concept from the paper (p.2554): +"A general content-addressable memory would be capable of retrieving this entire +memory item on the basis of sufficient partial information." +-/ +structure ContentAddressableMemory (R U : Type) + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] where + /-- The Hopfield network parameters. -/ + params : Params (HopfieldNetwork R U) + /-- The set of stored patterns. -/ + patterns : Finset (PhaseSpacePoint R U) + /-- Proof that patterns is non-empty. -/ + patterns_nonempty : patterns.Nonempty + /-- The maximum Hamming distance for reliable pattern completion. + An empirical suggestion for this value is `PatternCompletionThreshold (Fintype.card U)`. -/ + threshold : ℕ + /-- Proof that all patterns are stable states of the network. -/ + patterns_stable : ∀ p ∈ patterns, FixedPoint params p + /-- Proof that corrupted patterns within threshold are retrieved correctly. -/ + completion_guarantee : ∀ p ∈ patterns, ∀ s : PhaseSpacePoint R U, + HammingDistance s p ≤ threshold → + ∃ useq : ℕ → U, ∃ (hf : fair useq), + HopfieldNet_stabilize params s useq hf = p + +/-! 
### Pattern Completion and Error Correction -/ + +/-- A generic function type representing how a metric (like completion probability or familiarity) +decays or changes based on a distance-like value and network size. +Parameters: +- `metric_value`: The input value to the decay function (e.g., distance from threshold, closest distance). +- `network_size`: The size of the network (e.g., `Fintype.card U`). +- `R`: The type for the output (e.g., probabilities). +Returns a value in `R`. -/ +def MetricDecayFunction (R : Type) := (metric_value : ℕ) → (network_size : ℕ) → R + +/-- +A specific exponential decay model, often used to model probabilities or familiarity scores. +This corresponds to `exp(-value / (N/C))` where C is a constant (e.g., 10). +-/ +noncomputable def ExponentialDecayMetric [Field R] [LinearOrder R] [IsStrictOrderedRing R] [HDiv R R ℝ] [Coe ℝ R] : MetricDecayFunction R := + fun value network_size => + ((Real.exp (-((value : R) / ((network_size : R) / 10)))) : R) + +/-- +The `AbstractCompletionProbability` measures the likelihood of correctly completing a pattern +as a function of the Hamming distance `d` from the stored pattern, using a provided `decay_func`. + +If `d` is within `cam.threshold`, probability is 1. Beyond that, it's determined by `decay_func` +applied to `d - cam.threshold`. +This formalizes the empirical finding from the paper (p.2556): +"For distance ≤ 5, the nearest state was reached more than 90% of the time. +Beyond that distance, the probability fell off smoothly." +The `ExponentialDecayMetric` can be used as `decay_func` to model this smooth fall-off. 
+-/ +def AbstractCompletionProbability [Field R] [LinearOrder R] [IsStrictOrderedRing R] + (cam : ContentAddressableMemory R U) (p : PhaseSpacePoint R U) (d : ℕ) + (decay_func : MetricDecayFunction R) : R := + if p ∈ cam.patterns then + if d ≤ cam.threshold then 1 + else -- d > cam.threshold + -- Since d > cam.threshold, d - cam.threshold is a well-defined natural number representing + -- the distance beyond the threshold. + let distance_beyond_threshold := d - cam.threshold + decay_func distance_beyond_threshold (Fintype.card U) + else 0 + +/-- `ErrorCorrection` quantifies the network's ability to correct errors in the input pattern. +It's measured as the reduction in Hamming distance to the closest stored pattern after convergence. -/ +def ErrorCorrection (cam : ContentAddressableMemory R U) + (s : PhaseSpacePoint R U) (useq : ℕ → U) (hf : fair useq) : ℕ := + let s' := HopfieldNet_stabilize cam.params s useq hf; + let original_errors := Finset.min' (cam.patterns.image (fun p => HammingDistance s p)) (cam.patterns_nonempty.image (fun p => HammingDistance s p)); + let final_errors := Finset.min' (cam.patterns.image (fun p => HammingDistance s' p)) (cam.patterns_nonempty.image (fun p => HammingDistance s' p)); + original_errors - final_errors + +omit [Inhabited U] in +/-- +The `error_correction_guarantee` theorem establishes that Hopfield networks +can correct a substantial fraction of errors in the input pattern. + +This formalizes a key capability of content-addressable memories. +-/ +theorem error_correction_guarantee (cam : ContentAddressableMemory R U) + (p : PhaseSpacePoint R U) (hp : p ∈ cam.patterns) (s : PhaseSpacePoint R U) + (h_dist : HammingDistance s p ≤ cam.threshold) : + ∃ useq : ℕ → U, ∃ (hf : fair useq), HopfieldNet_stabilize cam.params s useq hf = p := by + exact cam.completion_guarantee p hp s h_dist + +/-! ### Basin of Attraction Analysis -/ + +/-- +The `BasinOfAttraction'` of a pattern is the set of all states that converge to it. 
+This concept is central to understanding the storage and retrieval properties of Hopfield networks. + +From the paper (p.2554): "Then, if the system is started sufficiently near any Xa, +as at X = Xa + Δ, it will proceed in time until X ≈ Xa." +-/ +def BasinOfAttraction' (cam : ContentAddressableMemory R U) (p : PhaseSpacePoint R U) + (useq : ℕ → U) (hf : fair useq) : Set (PhaseSpacePoint R U) := + {s | HopfieldNet_stabilize cam.params s useq hf = p} + +/-- +The `BasinVolume` is the "size" of the basin of attraction, measured as the +fraction of the state space that converges to a given pattern. + +This quantifies the robustness of memory retrieval. +-/ +def BasinVolume [DecidableEq (PhaseSpacePoint R U)] (cam : ContentAddressableMemory R U) (p : PhaseSpacePoint R U) + (useq : ℕ → U) (hf : fair useq) : R := + let total_states := (2 : R) ^ (Fintype.card U); + -- Use Finset.filter with a decidable predicate instead of trying to compute cardinality of a Set + let all_states := @Finset.univ (PhaseSpacePoint R U) (inferInstance : Fintype (PhaseSpacePoint R U)); + let basin_states := (Finset.filter (fun s : PhaseSpacePoint R U => + HopfieldNet_stabilize cam.params s useq hf = p) all_states).card; + (basin_states : R) / total_states + +/-- +The `basin_volume_bound` theorem establishes that the basin volume decreases +exponentially with the number of stored patterns. + +This formalizes how memory capacity affects retrieval robustness. +-/ +theorem basin_volume_bound (cam : ContentAddressableMemory R U) (p : PhaseSpacePoint R U) (hp : p ∈ cam.patterns) + (useq : ℕ → U) (hf : fair useq) : + BasinVolume cam p useq hf ≥ (1 : R) / ((2 : R)^cam.patterns.card) := by + sorry -- Requires statistical analysis of basin volumes, likely combinatorial arguments about pattern distribution and overlap. + +/-! 
### Familiarity Recognition -/ + +/-- +`AbstractFamiliarityMeasure` quantifies how familiar a given state `s` is to the network, +based on its proximity (closest Hamming distance) to stored patterns, using a provided `decay_func`. + +The paper discusses (p.2557): "The state 00000... is always stable. For a threshold of 0, this +stable state is much higher in energy than the stored memory states and very seldom occurs." +A high familiarity measure (close to 1) indicates `s` is similar to a stored pattern. +The `ExponentialDecayMetric` can be used as `decay_func`. +-/ +def AbstractFamiliarityMeasure [Field R] [LinearOrder R] [IsStrictOrderedRing R] + (cam : ContentAddressableMemory R U) (s : PhaseSpacePoint R U) + (decay_func : MetricDecayFunction R) : R := + let distances := cam.patterns.image (fun p_img => HammingDistance s p_img); -- Renamed p to p_img + let closest_distance := Finset.min' distances (by + rcases cam.patterns_nonempty with ⟨p_exist, hp_exist⟩ + use HammingDistance s p_exist + apply Finset.mem_image.mpr + exists p_exist + ); + decay_func closest_distance (Fintype.card U) + +/-- +`IsFamiliar` determines whether a pattern should be recognized as familiar +based on a threshold familiarity measure, using a specific `decay_func` for the measure. +-/ +def IsFamiliar [Field R] [LinearOrder R] [IsStrictOrderedRing R] + (cam : ContentAddressableMemory R U) (s : PhaseSpacePoint R U) (threshold_val : R) -- Renamed threshold to threshold_val + (decay_func : MetricDecayFunction R) : Prop := + AbstractFamiliarityMeasure cam s decay_func ≥ threshold_val + +/-- +The `familiarity_recognition` theorem establishes that the network can distinguish +between familiar and unfamiliar inputs. A pattern close to a stored one should be familiar. 
+ +This formalizes the paper's discussion (p.2557): "Adding a uniform threshold in +the algorithm is equivalent to raising the effective energy of the +stored memories compared to the 0000 state, and 0000 also +becomes a likely stable state. The 0000 state is then generated +by any initial state that does not resemble adequately closely +one of the assigned memories and represents positive recognition +that the starting state is not familiar." +-/ +theorem familiarity_recognition [Field R] [LinearOrder R] [IsStrictOrderedRing R] + (cam : ContentAddressableMemory R U) + (p : PhaseSpacePoint R U) (s : PhaseSpacePoint R U) (threshold_val : R) -- Renamed threshold to threshold_val + (decay_func : MetricDecayFunction R) -- Added decay_func parameter + (hp : p ∈ cam.patterns) : + HammingDistance s p ≤ cam.threshold → IsFamiliar cam s threshold_val decay_func := by + sorry -- Requires analysis of AbstractFamiliarityMeasure's properties with the given decay_func. + +/-! ### Category Formation and Generalization -/ + +/-- +`CategoryRepresentative` identifies a pattern that represents a category +of similar patterns. + +This formalizes the emergence of categories in Hopfield networks. 
+-/ +noncomputable def CategoryRepresentative [DecidableEq (PhaseSpacePoint R U)] + (cam : ContentAddressableMemory R U) + (category : Finset (PhaseSpacePoint R U)) + (useq : ℕ → U) (hf : fair useq) : PhaseSpacePoint R U := + let representatives := cam.patterns.filter (fun p => + ∀ s ∈ category, HopfieldNet_stabilize cam.params s useq hf = p); + match representatives.toList with + | [] => default -- Using the Inhabited instance defined earlier + | (p :: _) => p -- Take the first element from the list + +-- Helper function that turns equality of PhaseSpacePoints into a Bool +private def eqbPhaseSpacePoint {R U : Type} + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] + [DecidableEq (PhaseSpacePoint R U)] (x y : PhaseSpacePoint R U) : Bool := + if x = y then true else false + +/-- +`GeneralizationCapability` measures how many of the given categories +converge to a single pattern (i.e., a "correct generalization") using a +fixed update sequence `useq`. We count how many categories meet that +criterion, then divide by the total number of categories. + +This avoids needing a `DecidablePred (λ c, ∃ p, ... )` by explicitly iterating +over finite lists and checking the condition constructively. +-/ +noncomputable def GeneralizationCapability + (cam : ContentAddressableMemory R U) + (categories : Finset (Finset (PhaseSpacePoint R U))) + (useq : ℕ → U) (hf : fair useq) : R := + let num_categories := categories.card + let catList := categories.toList + let correctCount := catList.foldl + (fun acc c => + let cList := c.toList + if cam.patterns.toList.any (fun p => + cList.all (fun s => + eqbPhaseSpacePoint (HopfieldNet_stabilize cam.params s useq hf) p)) + then acc + 1 + else acc) + 0 + (correctCount : R) / (num_categories : R) + +/-- A simple measure of how closely two patterns match, returning a real number. 
-/ +def PatternSimilarity (p q : PhaseSpacePoint R U) : R := + (Fintype.card U - HammingDistance p q : ℕ) + +/-- +The `category_formation` theorem establishes that Hopfield networks naturally +form categories when presented with similar patterns. + +This formalizes the emergent categorization property discussed in the paper. +-/ +theorem category_formation (cam : ContentAddressableMemory R U) : + ∀ ε > 0, ∃ threshold_sim : R, ∀ p q : PhaseSpacePoint R U, -- Renamed threshold to threshold_sim + PatternSimilarity p q / (Fintype.card U : R) > threshold_sim → + ∃ r, FixedPoint cam.params r ∧ + PatternSimilarity p r / (Fintype.card U : R) > 1 - ε ∧ + PatternSimilarity q r / (Fintype.card U : R) > 1 - ε := by + sorry -- Requires analysis of category formation dynamics, potentially how Hebbian learning on similar patterns shapes the energy landscape. + +/-! ### Temporal Sequence Memory -/ + +/-- +`SequentialAttractor` represents a cyclic sequence of states that the network +can store and retrieve. + +The paper briefly discusses (p.2557-2558): "Additional nonsymmetric terms which could be easily +generated by a minor modification of Hebb synapses ... were added to Tij. When A was judiciously adjusted, the +system would spend a while near Vs and then leave and go to a point +near Vs+1." +-/ +structure SequentialAttractor (R U : Type) + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] where + /-- The parameters of the Hopfield network. These might include non-symmetric weights for sequence memory. -/ + params : Params (HopfieldNetwork R U) + /-- The sequence of states forming the attractor. -/ + states : List (PhaseSpacePoint R U) + /-- Proof that the list of states is not empty. -/ + h_nonempty : states ≠ [] + /-- + Proof that the sequence forms a cycle, meaning: + 1) For each index i < states.length - 1, the state at i+1 follows from the state at i after N steps. + 2) The last state in `states` leads to the first one after N steps. 
+ -/ + cycle : + (∀ (i : Fin states.length) (h : i.1 < states.length - 1), + ∃ (useq : ℕ → U) (_ : fair useq) (N_steps : ℕ), -- Renamed N to N_steps + seqStates params (states.get i) useq N_steps = + states.get ⟨i.1 + 1, by + simpa [Nat.sub_add_cancel (List.length_pos_of_ne_nil h_nonempty)] using + Nat.succ_lt_succ h⟩ + ) + ∧ + ∃ (useq : ℕ → U) (_ : fair useq) (N_steps : ℕ), -- Renamed N to N_steps + seqStates params + (states.get + ⟨states.length - 1, + Nat.sub_lt (List.length_pos_of_ne_nil h_nonempty) (Nat.zero_lt_succ 0)⟩) + useq N_steps + = states.get ⟨0, List.length_pos_of_ne_nil h_nonempty⟩ + +/-- +`SequenceRetrieval` tests whether the network can recall a sequence when given +one of its elements.-/ +def SequenceRetrieval (seq_attr : SequentialAttractor R U) (s : PhaseSpacePoint R U) : Prop := -- Renamed seq to seq_attr + ∃ i, seq_attr.states[i]? = some s ∧ + ∀ j, ∃ useq : ℕ → U, ∃ _ : fair useq, ∃ N_steps : ℕ, -- Renamed N to N_steps + seqStates seq_attr.params s useq N_steps = seq_attr.states[(i + j) % seq_attr.states.length]? + +/-- +The `sequence_memory_limit` theorem establishes that Hopfield networks with +standard Hebbian learning (or simple modifications for sequences) have limited capacity for storing sequential patterns. + +The paper notes (p.2558): "But sequences longer than four states proved +impossible to generate, and even these were not faithfully +followed."-/ +theorem sequence_memory_limit (seq_attr : SequentialAttractor R U) : -- Renamed seq to seq_attr + seq_attr.states.length ≤ 4 := by + sorry -- Requires analysis of sequential memory capacity, likely specific to the (non-symmetric) learning rule used to store sequences. 
+ +end Hopfield82 diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/EnergyConvergence.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/EnergyConvergence.lean new file mode 100644 index 000000000..65b660c3a --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/EnergyConvergence.lean @@ -0,0 +1,53 @@ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Hopfield82.PhaseSpaceFlow +namespace Hopfield82 + +open NeuralNetwork State Matrix Finset Real + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] + +/-! ### Connection to Physical Systems -/ + +/-- +The `spin_glass_analogy` formalizes the connection between Hopfield networks and +physical spin glass systems, as discussed in the paper. + +From the paper (p.2555): "This case is isomorphic with an Ising model." +-/ +def spin_glass_analogy (wθ : Params (HopfieldNetwork R U)) : Prop := + wθ.w.IsSymm ∧ ∀ i, wθ.w i i = 0 + +/-- +The `energy_convergence` theorem formalizes the connection between energy minimization +and the convergence to fixed points. + +From the paper (p.2555): "State changes will continue until a least (local) E is reached." 
+-/ +theorem energy_convergence (wθ : Params (HopfieldNetwork R U)) (s : PhaseSpacePoint R U) + (useq : ℕ → U) (hf : fair useq) : + ∃ (p : PhaseSpacePoint R U), FixedPoint wθ p ∧ + (∀ q : PhaseSpacePoint R U, q ∈ BasinOfAttraction wθ p useq hf → p.E wθ ≤ q.E wθ) := by + obtain ⟨N, hN⟩ := HopfieldNet_convergence_fair s useq hf + let p := seqStates wθ s useq N + have p_fixed : FixedPoint wθ p := hN + use p, p_fixed + intro q hq + obtain ⟨n, hn, _⟩ := hq + have energy_decreases_along_sequence : + ∀ (q : PhaseSpacePoint R U) (m : ℕ), + (seqStates wθ q useq m).E wθ ≤ q.E wθ := by + intro q' m + induction' m with m ih + · simp only [seqStates, energy_decomposition, le_refl] + · have : seqStates wθ q' useq (m+1) = + (seqStates wθ q' useq m).Up wθ (useq m) := by + simp only [seqStates] + rw [this] + have energy_step : ((seqStates wθ q' useq m).Up wθ (useq m)).E wθ ≤ + (seqStates wθ q' useq m).E wθ := + energy_decrease wθ (seqStates wθ q' useq m) (useq m) + exact le_trans energy_step ih + have p_energy_leq_q : p.E wθ ≤ q.E wθ := by + rw [← hn] + exact energy_decreases_along_sequence q n + + exact p_energy_leq_q diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/Example.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/Example.lean new file mode 100644 index 000000000..1f46804c1 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/Example.lean @@ -0,0 +1,246 @@ +/- +Copyright (c) 2025. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Matteo Cipollina +-/ + +import HopfieldNet.Papers.Hopfield82.ContentAddressableMemory +import HopfieldNet.Papers.Hopfield82.FaultTolerance +import HopfieldNet.Papers.Hopfield82.EnergyConvergence +import Mathlib.Data.Fintype.Basic + +/-! +# Example Implementation of Hopfield's 1982 Model + +This module provides concrete examples that demonstrate the key concepts from Hopfield's 1982 paper. 
+It implements a small Hopfield network and shows how it functions as a content-addressable memory. + +## Main Examples + +* `smallNetwork`: A 4-neuron Hopfield network storing 2 patterns +* `patternRetrieval`: Demonstration of retrieving patterns from partial information +* `patternCapacity`: Analysis of storage capacity in relation to network size +* `faultToleranceDemo`: Demonstration of robustness to component failures + +## References + +* Hopfield, J.J. (1982). Neural networks and physical systems with emergent collective + computational abilities. Proceedings of the National Academy of Sciences, 79(8), 2554-2558. +-/ + +namespace HopfieldExample + +open Hopfield82 NeuralNetwork State Matrix Finset + +/-! ### Small Example Network -/ + +/-- +`smallNetwork` creates a small Hopfield network with 4 neurons that stores 2 patterns. +The patterns are `[1, 1, -1, -1]` and `[-1, 1, -1, 1]`, which are orthogonal to each other. +-/ +def smallNetwork : Params (HopfieldNetwork ℚ (Fin 4)) := + let patterns : Fin 2 → (HopfieldNetwork ℚ (Fin 4)).State := + ![{ act := ![1, 1, -1, -1], hp := by decide }, + { act := ![-1, 1, -1, 1], hp := by decide }]; + + Hebbian patterns + +/-- +`hammingDistance` calculates the number of bits that differ between two states. +-/ +def hammingDistance {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Nonempty U] [DecidableEq R] [DecidableEq U] [Fintype U] + (s₁ s₂ : (HopfieldNetwork R U).State) : ℕ := + card {i | s₁.act i ≠ s₂.act i} + +/-- +`updateWithNoise` creates a corrupted version of a pattern by flipping a specified +number of bits randomly. + +For simplicity in this example, we deterministically flip the first `numBits` bits. 
+-/ +def updateWithNoise {R : Type} {n : ℕ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq (Fin n)] [Nonempty (Fin n)] + (s : (HopfieldNetwork R (Fin n)).State) (numBits : ℕ) : (HopfieldNetwork R (Fin n)).State := + if numBits = 0 then s else + { act := λ i => if i.val < numBits then -s.act i else s.act i, + hp := by + intro i + by_cases h_cond : i.val < numBits + · simp only [h_cond, if_true] + rcases s.hp i with h_s_act_eq_1 | h_s_act_eq_neg_1 + · rw [h_s_act_eq_1]; simp + · rw [h_s_act_eq_neg_1]; simp + · simp only [h_cond, if_false] + exact s.hp i + } +/-- +`pattern1` is the first pattern stored in the small network: `[1, 1, -1, -1]`. +-/ +def pattern1 : (HopfieldNetwork ℚ (Fin 4)).State := + { act := ![1, 1, -1, -1], hp := by decide } + +/-- +`pattern2` is the second pattern stored in the small network: `[-1, 1, -1, 1]`. +-/ +def pattern2 : (HopfieldNetwork ℚ (Fin 4)).State := + { act := ![-1, 1, -1, 1], hp := by decide } + +/-- +`useq_cyclic_example` is a cyclic update sequence for the 4-neuron network. +It cycles through neurons 0, 1, 2, 3 in order. +-/ +def useq_cyclic_example : ℕ → Fin 4 := + λ i => ⟨i % 4, by apply Nat.mod_lt; exact Nat.zero_lt_succ 3⟩ + +/-- +Proof that the update sequence is cyclic. +-/ +lemma useq_cyclic_example_is_cyclic : cyclic useq_cyclic_example := by + unfold cyclic + constructor + · intro u + use u.val + simp only [useq_cyclic_example] + have h : u.val % 4 = u.val := by + apply Nat.mod_eq_of_lt + exact u.2 + simp [h] + · intros i j h_mod + simp only [Fintype.card_fin] at h_mod + exact Fin.eq_of_val_eq h_mod + +/-- +`useq_cyclic_example_is_fair` proves that the cyclic update sequence is fair. +-/ +lemma useq_cyclic_example_is_fair : fair useq_cyclic_example := + cycl_Fair useq_cyclic_example useq_cyclic_example_is_cyclic + +/-! ### Pattern Retrieval Example -/ + +/-- +`demonstrateRetrieval` shows how the Hopfield network retrieves a stored pattern +from a corrupted version with noise. 
+ +This example explicitly demonstrates the content-addressable memory property +described in the paper. +-/ +def demonstrateRetrieval (pattern : (HopfieldNetwork ℚ (Fin 4)).State) (numBitsFlipped : ℕ) : + (HopfieldNetwork ℚ (Fin 4)).State := + let corruptedPattern := updateWithNoise pattern numBitsFlipped; + HopfieldNet_stabilize smallNetwork corruptedPattern useq_cyclic_example useq_cyclic_example_is_fair + +/-! ### Pattern Orthogonality -/ + +/-- +`patternsAreOrthogonal` proves that the two patterns stored in our small network +are orthogonal to each other (their dot product is zero), which is a key requirement +for reliable storage and retrieval in Hopfield networks, as discussed in the paper. +-/ +theorem patternsAreOrthogonal : dotProduct pattern1.act pattern2.act = 0 := by + unfold dotProduct pattern1 pattern2 + + -- Expand the sum into individual terms + have sum_expansion : ∑ i : Fin 4, pattern1.act i * pattern2.act i = + pattern1.act 0 * pattern2.act 0 + + pattern1.act 1 * pattern2.act 1 + + pattern1.act 2 * pattern2.act 2 + + pattern1.act 3 * pattern2.act 3 := by + exact Fin.sum_univ_four fun i ↦ pattern1.act i * pattern2.act i + rw [@Fin.sum_univ_four] + + -- Evaluate each term manually + have p1_0 : pattern1.act 0 = 1 := rfl + have p1_1 : pattern1.act 1 = 1 := rfl + have p1_2 : pattern1.act 2 = -1 := rfl + have p1_3 : pattern1.act 3 = -1 := rfl + + have p2_0 : pattern2.act 0 = -1 := rfl + have p2_1 : pattern2.act 1 = 1 := rfl + have p2_2 : pattern2.act 2 = -1 := rfl + have p2_3 : pattern2.act 3 = 1 := rfl + + rw [Rat.add_assoc] + + -- Simplify the arithmetic + norm_num + +/-! ### Demonstration of Energy Function -/ + +/-- +`energyDecreases` demonstrates that the energy of the Hopfield network decreases +(or remains the same) with each update, as proven in the paper. + +This function returns a list of energy values as the network evolves from a +corrupted pattern to the stored pattern. 
+-/ +def energyDecreases (pattern : (HopfieldNetwork ℚ (Fin 4)).State) (numBitsFlipped : ℕ) : + List ℚ := + let corruptedPattern := updateWithNoise pattern numBitsFlipped; + let maxSteps := 10; -- Limit to avoid infinite loops + + let rec computeEnergySequence (s : (HopfieldNetwork ℚ (Fin 4)).State) (step : ℕ) : List ℚ := + if step = 0 then [s.E smallNetwork] + else + let nextS := s.Up smallNetwork (useq_cyclic_example step); + s.E smallNetwork :: computeEnergySequence nextS (step - 1); + + computeEnergySequence corruptedPattern maxSteps + +/-! ### Analysis of Pattern Capacity -/ + +/-- +`theoreticalCapacity` calculates the theoretical maximum number of patterns +that can be stored in a Hopfield network with N neurons. + +From the paper (p.2556): "About 0.15 N states can be simultaneously remembered +before error in recall is severe." +-/ +noncomputable def theoreticalCapacity (N : ℕ) : ℕ := + Nat.floor (0.15 * (N : ℝ)) + +/-! ### Fault Tolerance Demonstration -/ + +/-- +`demonstrateFaultTolerance` shows how the Hopfield network maintains functionality +even when some neurons are removed, as discussed in the paper. + +This function simulates the removal of a specified number of neurons and tests +whether the network can still correctly retrieve the stored patterns. +-/ +noncomputable def demonstrateFaultTolerance (numNeuronsRemoved : ℕ) (_ : numNeuronsRemoved ≤ 4) : Bool := + -- Use h to ensure we don't exceed the number of neurons in our network + let neuronSet : Finset (Fin 4) := Finset.filter (fun i => i.val < min numNeuronsRemoved (by exact + numNeuronsRemoved)) Finset.univ; + + -- Create a modified network with some neurons removed + let modifiedNetwork := DeleteNeurons neuronSet.toList smallNetwork; + + -- Test if the patterns are still stable in the modified network + let isPattern1Stable := isStable modifiedNetwork pattern1; + let isPattern2Stable := isStable modifiedNetwork pattern2; + + isPattern1Stable && isPattern2Stable + +/-! 
### Demonstration of Content-Addressable Memory -/ + +/-- +`contentAddressableMemoryDemo` creates a Hopfield network that functions as a +content-addressable memory, as described in the paper. + +It shows that the network can retrieve complete patterns from partial information. +-/ +def contentAddressableMemoryDemo : Bool := + -- Create two patterns with sufficient Hamming distance + let p1 := pattern1; + let p2 := pattern2; + + -- Create Hopfield network that stores these patterns + let network := smallNetwork; + + -- Test retrieval with 1 bit flipped for each pattern + let test1 := HopfieldNet_stabilize network (updateWithNoise p1 1) useq_cyclic_example useq_cyclic_example_is_fair; + let test2 := HopfieldNet_stabilize network (updateWithNoise p2 1) useq_cyclic_example useq_cyclic_example_is_fair; + + -- Check if the network correctly retrieves the original patterns + (hammingDistance test1 p1 = 0) && (hammingDistance test2 p2 = 0) + +end HopfieldExample diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/FaultTolerance.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/FaultTolerance.lean new file mode 100644 index 000000000..dc03aff2a --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/FaultTolerance.lean @@ -0,0 +1,952 @@ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Hopfield82.ContentAddressableMemory +import Mathlib.Algebra.BigOperators.Group.Finset.Sigma + +namespace Hopfield82 + +open NeuralNetwork State Matrix Finset Real + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] + +/-! ### Fault Tolerance -/ + +/-- +The `DeleteNeuron` function simulates the failure of a neuron by removing its connections. +This corresponds to setting weights to/from that neuron to zero. 
+ +The paper discusses (p.2558): "The collective properties are only weakly sensitive to +details of the modeling or the failure of individual devices." +-/ +noncomputable def DeleteNeuron [DecidableEq U] (i : U) + (wθ : Params (HopfieldNetwork R U)) : Params (HopfieldNetwork R U) where + w := λ j k => if j = i ∨ k = i then 0 else wθ.w j k + θ := wθ.θ + σ := wθ.σ + hw := by + intros u v huv + by_cases h : u = i ∨ v = i + · cases h with + | inl h_eq => rw [h_eq]; simp only [true_or, ↓reduceIte] + | inr h_eq => rw [h_eq]; simp only [or_true, ↓reduceIte] + · push_neg at h + simp only [h.1, h.2, or_self, ↓reduceIte] + exact wθ.hw u v huv + hw' := by + apply IsSymm.ext_iff.2 + intros j k + by_cases h : j = i ∨ k = i + · cases h with + | inl h_left => + rw [h_left] + simp only [or_true, ↓reduceIte, true_or] + | inr h_right => + rw [h_right] + simp only [true_or, ↓reduceIte, or_true] + · push_neg at h + simp only [h.1, h.2] + have hsymm := IsSymm.ext_iff.1 wθ.hw' + specialize hsymm j k + exact hsymm + +/-- Apply DeleteNeuron to a list of neurons sequentially -/ +noncomputable def DeleteNeurons [DecidableEq U] (neurons : List U) +(wθ : Params (HopfieldNetwork R U)) : Params (HopfieldNetwork R U) := + List.foldl (fun acc neuron => DeleteNeuron neuron acc) wθ neurons + + +-- The `FaultTolerance` of a Hopfield network is its ability to maintain function +-- despite the failure of some components. The paper notes that these networks are +-- inherently robust to component failures. + +/-- +Defines fault tolerance for a Hopfield network. + +Given a set of patterns `ps` and a specific pattern `ps k` from that set, +this property ensures that even after removing up to `f` neurons from the network, +the pattern `ps k` remains a fixed point under the network dynamics. 
+ +Specifically: +- We calculate the weights and thresholds using Hebbian learning on all patterns +- For any subset of neurons with cardinality at most `f`, removing these neurons + still allows `ps k` to be a fixed point of the resulting network. + +This captures the fault-tolerance property described in Hopfield's 1982 paper on +content-addressable memory. + +### Parameters +- `ps` : A collection of patterns (states) in the Hopfield network +- `k` : Index of the pattern we want to check for fault tolerance +- `f` : Maximum number of neurons that can be removed while preserving the fixed point +-/ +def FaultTolerance {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (k : Fin m) (f : ℕ) : Prop := + let wθ := Hebbian ps; + ∀ neurons_to_delete : Finset U, card neurons_to_delete ≤ f → + let wθ' := DeleteNeurons neurons_to_delete.toList wθ; + -- Check stability only for neurons *not* in the set of deleted neurons + ∀ u_check ∈ (Finset.univ : Finset U) \ neurons_to_delete, + ((ps k).Up wθ' u_check).act u_check = (ps k).act u_check + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +/-- When m is at most a tenth of total neurons, each pattern is fixed point in the undamaged network -/ +@[simp] +lemma pattern_stability_in_hebbian {m : ℕ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (horth : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) + (hm : m ≤ Fintype.card U / 10) (k : Fin m) : + FixedPoint (Hebbian ps) (ps k) := by + have hcard : Fintype.card U / 10 < Fintype.card U := by + have : Fintype.card U > 0 := Fintype.card_pos + apply Nat.div_lt_self + · exact this + · norm_num + have : m < Fintype.card U := lt_of_le_of_lt hm hcard + exact Hebbian_stable this ps k horth + +omit [Nonempty U] in +@[simp] +lemma filter_eq_singleton [DecidableEq U] (i : U) : filter (fun j => j = i) univ = {i} := by + apply Finset.ext + intro j + simp only 
[mem_filter, mem_univ, true_and] + exact Iff.symm mem_singleton + +omit [Fintype U] [Nonempty U] in +@[simp] +lemma sum_over_singleton (i : U) (f : U → R) : + ∑ j ∈ {i}, f j = f i := by + simp only [sum_singleton] + +omit [Nonempty U] in +@[simp] +lemma sum_filter_eq [DecidableEq U] (i : U) (f : U → R) : + ∑ j ∈ filter (fun j => j = i) univ, f j = f i := by + rw [filter_eq_singleton] + exact sum_over_singleton i f + +omit [Nonempty U] in +@[simp] +lemma sum_split_by_value [DecidableEq U] (i : U) (f : U → R) : + ∑ j, f j = f i + ∑ j ∈ filter (fun j => j ≠ i) univ, f j := by + have h_split : ∑ j, f j = + ∑ j, if j = i then f i else f j := by + apply Finset.sum_congr + · rfl + · intro j _ + by_cases h : j = i + · rw [h]; simp only [↓reduceIte] + · simp only [if_neg h] + rw [h_split, sum_ite] + congr + · exact sum_filter_eq i (fun _ => f i) + +omit [Fintype U] [Nonempty U] in +@[simp] +lemma condition_simplify_when_neq (i u j : U) (hu : u ≠ i) : + (j = i ∨ u = i) ↔ j = i := by + exact or_iff_left hu + +omit [Nonempty U] in +@[simp] +lemma sum_filter_eq_singleton [DecidableEq U] (i : U) (f : U → R) : + ∑ j ∈ filter (fun j => j = i) univ, f j = f i := by + rw [filter_eq_singleton] + exact sum_over_singleton i f + +omit [Nonempty U] in +@[simp] +lemma sum_split_by_eq [DecidableEq U](i : U) (f : U → R) : + ∑ j, f j = f i + ∑ j ∈ filter (fun j => j ≠ i) univ, f j := by + exact sum_split_by_value i f + +omit [Nonempty U] in +@[simp] +lemma sum_filter_neq_as_subtraction [DecidableEq U] (i : U) (f : U → R) : + ∑ j ∈ filter (fun j => j ≠ i) univ, f j = ∑ j, f j - f i := by + rw [sum_split_by_eq] + simp only [ne_eq, filter_erase_equiv, mem_univ, sum_erase_eq_sub, add_sub_cancel] + exact i + +@[simp] +lemma Finset.erase_eq_filter {α : Type*} [DecidableEq α] (s : Finset α) (a : α) : + s.erase a = s.filter (fun x => x ≠ a) := by + ext x + simp only [mem_erase, mem_filter] + constructor + · intro h + exact id (And.symm h) + · intro h + exact id (And.symm h) + +omit [Nonempty U] in 
+@[simp] +lemma sum_if_or_condition [DecidableEq U] (i u : U) (hu : u ≠ i) (f : U → R) : + ∑ j : U, (if j = i ∨ u = i then 0 else f j) = ∑ j : U, (if j = i then 0 else f j) := by + apply Finset.sum_congr + · rfl + · intro j _ + have h_equiv : (j = i ∨ u = i) ↔ j = i := by + constructor + · intro h + cases h with + | inl h_j => exact h_j + | inr h_u => exfalso; exact hu h_u + · intro h_j + exact Or.inl h_j + simp only [h_equiv] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] in +lemma sum_if_eq_to_sub [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] (i : U) (f : U → R) : + ∑ j : U, (if j = i then 0 else f j) = ∑ j : U, f j - f i := by + apply eq_sub_of_add_eq + rw [sum_split_by_eq i f] + rw [add_comm, add_left_cancel_iff] + rw [sum_ite] + rw [sum_const_zero, zero_add] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Nonempty U] in +lemma deleted_connection_sum [DecidableEq U] [Field R] [LinearOrder R] [IsStrictOrderedRing R] + (i u : U) (hu : u ≠ i) (w : Matrix U U R) (act : U → R) : + (∑ j : U, if j = i ∨ u = i then 0 else w u j * act j) = ∑ j : U, w u j * act j - w u i * act i := by + have h_condition_simp : ∀ j : U, (j = i ∨ u = i) ↔ j = i := by + intro j + constructor + · intro h + cases h with + | inl h_j => exact h_j + | inr h_u => exfalso; exact hu h_u + · intro h_j + exact Or.inl h_j + simp_rw [h_condition_simp] + apply sum_if_eq_to_sub i (fun j => w u j * act j) + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +@[simp] +lemma delete_one_neuron_effect_general +[Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (wθ : Params (HopfieldNetwork R U)) + (i : U) (u : U) (hu : u ≠ i) (act : U → R) : + (DeleteNeuron i wθ).w.mulVec act u = + wθ.w.mulVec act u - wθ.w u i * act i := by + unfold DeleteNeuron + simp only [mulVec, dotProduct] + have h_cond_simp : ∀ x : U, (u = i ∨ x = i) ↔ x = i := by + intro x + constructor + · intro h + cases h 
with + | inl h_u => exfalso; exact hu h_u + | inr h_x => exact h_x + · intro h_x + exact Or.inr h_x + simp_rw [h_cond_simp] + rw [Finset.sum_congr rfl fun x _ => by rw [ite_mul, zero_mul]] + apply sum_if_eq_to_sub i (fun x => wθ.w u x * act x) + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U]in +@[simp] +lemma delete_one_neuron_effect {m : ℕ} +[Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (i : U) (u : U) (hu : u ≠ i) (k : Fin m) : + (DeleteNeuron i (Hebbian ps)).w.mulVec (ps k).act u = + (Hebbian ps).w.mulVec (ps k).act u - (Hebbian ps).w u i * (ps k).act i := by + exact delete_one_neuron_effect_general (Hebbian ps) i u hu (ps k).act + +lemma delete_neuron_preserves_other_weights [DecidableEq U] (wθ : Params (HopfieldNetwork R U)) (i u v : U) + (hu : u ≠ i) (hv : v ≠ i) : + (DeleteNeuron i wθ).w u v = wθ.w u v := by + unfold DeleteNeuron + simp [hu, hv] + +lemma delete_neurons_order_independent [DecidableEq U] (wθ : Params (HopfieldNetwork R U)) (i j : U) : + DeleteNeuron i (DeleteNeuron j wθ) = DeleteNeuron j (DeleteNeuron i wθ) := by + have w_eq : (DeleteNeuron i (DeleteNeuron j wθ)).w = (DeleteNeuron j (DeleteNeuron i wθ)).w := by + ext u v + unfold DeleteNeuron + by_cases hu : u = i ∨ u = j + · by_cases hv : v = i ∨ v = j + · simp only [hu, hv, true_or, ↓reduceIte]; cases hu with + | inl h => + cases hv with + | inl h_1 => + subst h_1 h + simp_all only [or_self, ↓reduceIte, ite_self] + | inr h_2 => + subst h_2 h + simp_all only [true_or, ↓reduceIte, or_true] + | inr h_1 => + cases hv with + | inl h => + subst h_1 h + simp_all only [or_true, ↓reduceIte, true_or] + | inr h_2 => + subst h_1 h_2 + simp_all only [or_self, ↓reduceIte, ite_self] + · simp only [hu, hv, true_or, ↓reduceIte] + cases hu with + | inl h_ui => + simp only [h_ui, true_or, if_true] + simp only [ite_self] + | inr h_uj => + rw [h_uj] + simp only [eq_self_iff_true, true_or] + simp only [if_true] + 
exact ite_self 0 + · by_cases hv : v = i ∨ v = j + · simp only [hu, hv, true_or, ↓reduceIte] + cases hv with + | inl h_vi => + simp only [h_vi, true_or, if_true, if_true] + simp only [or_true, ↓reduceIte, ite_self] + | inr h_vj => + simp only [h_vj, or_true, if_true, if_true] + simp only [ite_self] + · simp only at hu hv + have h_ui : u ≠ i := by + intro h_eq + apply hu + exact Or.inl h_eq + have h_uj : u ≠ j := by + intro h_eq + apply hu + exact Or.inr h_eq + have h_vi : v ≠ i := by + intro h_eq + apply hv + exact Or.inl h_eq + have h_vj : v ≠ j := by + intro h_eq + apply hv + exact Or.inr h_eq + unfold DeleteNeuron at * + simp only [h_ui, h_uj, h_vi, h_vj, or_false, false_or, ↓reduceIte] + have θ_eq : (DeleteNeuron i (DeleteNeuron j wθ)).θ = (DeleteNeuron j (DeleteNeuron i wθ)).θ := by + unfold DeleteNeuron; rfl + have σ_eq : (DeleteNeuron i (DeleteNeuron j wθ)).σ = (DeleteNeuron j (DeleteNeuron i wθ)).σ := by + unfold DeleteNeuron; rfl + unfold DeleteNeuron + congr + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +/-- When deleting a single neuron from a network, the resulting weighted sum for a neuron u + that's not the deleted neuron equals the original weighted sum minus the contribution + from the deleted neuron. 
-/ +lemma delete_single_neuron_step {m : ℕ} + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (neuron : U) (u : U) (hu : u ≠ neuron) (k : Fin m) : + (DeleteNeuron neuron (Hebbian ps)).w.mulVec (ps k).act u = + (Hebbian ps).w.mulVec (ps k).act u - (Hebbian ps).w u neuron * (ps k).act neuron := by + rw [delete_one_neuron_effect_general (Hebbian ps) neuron u hu (ps k).act] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +/-- When deleting neurons from an empty list, the result is the original network -/ +lemma delete_empty_neurons_step {m : ℕ} + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (wθ : Params (HopfieldNetwork R U)) (u : U) (k : Fin m) : + (List.foldl (fun acc neuron => DeleteNeuron neuron acc) wθ []).w.mulVec (ps k).act u = + wθ.w.mulVec (ps k).act u := by + simp only [List.foldl_nil] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +/-- When deleting a list of neurons with a new neuron added at the front, the effect + on the weighted sum equals the effect of deleting the first neuron and then + deleting the rest of the list. 
-/ +lemma delete_cons_neuron_step {m : ℕ} + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (head : U) (tail : List U) (u : U) (_ : u ≠ head) (_ : u ∉ tail) (k : Fin m) : + (List.foldl (fun acc neuron => DeleteNeuron neuron acc) (Hebbian ps) (head :: tail)).w.mulVec (ps k).act u = + (List.foldl (fun acc neuron => DeleteNeuron neuron acc) (DeleteNeuron head (Hebbian ps)) tail).w.mulVec (ps k).act u := by + rw [List.foldl_cons] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +/-- For a singleton list, the effect matches the single neuron deletion case -/ +lemma delete_singleton_neuron_step {m : ℕ} + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (neuron : U) (u : U) (hu : u ≠ neuron) (k : Fin m) : + (List.foldl (fun acc n => DeleteNeuron n acc) (Hebbian ps) [neuron]).w.mulVec (ps k).act u = + (Hebbian ps).w.mulVec (ps k).act u - (Hebbian ps).w u neuron * (ps k).act neuron := by + rw [← delete_single_neuron_step ps neuron u hu k] + exact rfl + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +/-- The effect of deleting a neuron from an already deleted network on a neuron u that + is not in the deleted set -/ +lemma delete_neuron_from_deleted_network {m : ℕ} + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (prev_deleted : List U) (neuron : U) (u : U) (hu : u ≠ neuron) (_ : u ∉ prev_deleted) (k : Fin m) : + (DeleteNeuron neuron (List.foldl (fun acc n => DeleteNeuron n acc) (Hebbian ps) prev_deleted)).w.mulVec (ps k).act u = + (List.foldl (fun acc n => DeleteNeuron n acc) (Hebbian ps) prev_deleted).w.mulVec (ps k).act u - + (List.foldl (fun acc n => DeleteNeuron n acc) (Hebbian ps) prev_deleted).w u neuron * (ps k).act neuron := by + apply delete_one_neuron_effect_general (List.foldl (fun acc n 
=> DeleteNeuron n acc) (Hebbian ps) prev_deleted) + neuron u hu (ps k).act + +/-- Helper lemma: DeleteNeuron commutes with foldl of DeleteNeuron if the neuron is not in the list -/ +lemma commute_delete_foldl [DecidableEq U] + (i : U) (base : Params (HopfieldNetwork R U)) (l : List U) (h_nodup_l : l.Nodup) (hi_notin_l : i ∉ l) : + DeleteNeuron i (List.foldl (fun acc neuron => DeleteNeuron neuron acc) base l) = + List.foldl (fun acc neuron => DeleteNeuron neuron acc) (DeleteNeuron i base) l := by + induction l generalizing base with + | nil => simp [List.foldl_nil] + | cons hd tl ih_commute => + have h_nodup_cons : (hd :: tl).Nodup := h_nodup_l + replace h_nodup_l : tl.Nodup := (List.nodup_cons.mp h_nodup_cons).2 + have h_hd_notin_tl : hd ∉ tl := (List.nodup_cons.mp h_nodup_cons).1 + have hi_notin_tl : i ∉ tl := fun h => hi_notin_l (List.mem_cons_of_mem _ h) + have hi_neq_hd : i ≠ hd := fun h => hi_notin_l (h ▸ List.mem_cons_self) + rw [List.foldl_cons, List.foldl_cons] + rw [ih_commute (DeleteNeuron hd base) h_nodup_l hi_notin_tl] + rw [delete_neurons_order_independent base i hd] + +/-- Helper lemma: Weights are preserved by foldl if indices are not in the list-/ +lemma foldl_delete_preserves_weights [DecidableEq U] + (base : Params (HopfieldNetwork R U)) + (l : List U) (h_nodup_l : l.Nodup) (v w : U) (hv_notin : v ∉ l) (hw_notin : w ∉ l) : + (List.foldl (fun acc neuron => DeleteNeuron neuron acc) base l).w v w = base.w v w := by + induction l generalizing base with + | nil => simp + | cons hd tl ih_w => + have h_nodup_cons : (hd :: tl).Nodup := h_nodup_l + replace h_nodup_l : tl.Nodup := (List.nodup_cons.mp h_nodup_cons).2 + have hv_notin_tl : v ∉ tl := fun h => hv_notin (List.mem_cons_of_mem _ h) + have hw_notin_tl : w ∉ tl := fun h => hw_notin (List.mem_cons_of_mem _ h) + have hv_neq_hd : v ≠ hd := fun h => hv_notin (h ▸ List.mem_cons_self) + have hw_neq_hd : w ≠ hd := fun h => hw_notin (h ▸ List.mem_cons_self) + rw [List.foldl_cons] + rw [ih_w (DeleteNeuron hd 
base) h_nodup_l hv_notin_tl hw_notin_tl] + rw [delete_neuron_preserves_other_weights base hd v w hv_neq_hd hw_neq_hd] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +/-- Helper lemma: Deleting a list of neurons recursively subtracts their contributions. + Requires that the list of neurons has no duplicates. -/ +lemma delete_neurons_recursive {m : ℕ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (neurons : List U) (h_nodup : neurons.Nodup) (u : U) (hu : u ∉ neurons) (k : Fin m) : + (List.foldl (fun acc neuron => DeleteNeuron neuron acc) (Hebbian ps) neurons).w.mulVec (ps k).act u = + (Hebbian ps).w.mulVec (ps k).act u - ∑ j ∈ neurons.toFinset, (Hebbian ps).w u j * (ps k).act j := by + induction neurons with + | nil => + simp [List.toFinset] + | cons head tail ih => + have h_nodup_cons : (head :: tail).Nodup := h_nodup + replace h_nodup : tail.Nodup := (List.nodup_cons.mp h_nodup_cons).2 + have h_head_notin_tail : head ∉ tail := (List.nodup_cons.mp h_nodup_cons).1 + + have hu_head : u ≠ head := by + intro h; apply hu; rw [h]; exact List.mem_cons_self + have hu_tail : u ∉ tail := by + intro h; apply hu; exact List.mem_cons_of_mem _ h + rw [List.foldl_cons] + let base := Hebbian ps + let act_k := (ps k).act + let L₀ := List.foldl (fun acc neuron => DeleteNeuron neuron acc) base tail + have h_commute : DeleteNeuron head L₀ = + List.foldl (fun acc neuron => DeleteNeuron neuron acc) (DeleteNeuron head base) tail := by + apply commute_delete_foldl + exact h_nodup + exact h_head_notin_tail + rw [← h_commute] + rw [delete_one_neuron_effect_general L₀ head u hu_head act_k] + rw [ih h_nodup hu_tail] + have L0_u_head_eq : L₀.w u head = base.w u head := by + apply foldl_delete_preserves_weights base tail h_nodup u head hu_tail h_head_notin_tail + rw [L0_u_head_eq] + have head_notin_tail_finset : head ∉ tail.toFinset := by + rw [List.mem_toFinset] + exact h_head_notin_tail + 
rw [List.toFinset_cons, Finset.sum_insert head_notin_tail_finset] + rw [sub_add_eq_sub_sub] + ring + +lemma deleted_neuron_weight_contribution [DecidableEq U] (wθ : Params (HopfieldNetwork R U)) + (i u : U) (_ : u ≠ i) (s : U → R) : + wθ.w u i * s i = ∑ j ∈ {i}, wθ.w u j * s j := by + rw [sum_singleton] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] in +/-- DeleteNeurons removes weights connected to deleted neurons -/ +lemma deleted_neurons_field_effect {m : ℕ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (deleted_neurons : Finset U) (u : U) (hu : u ∉ deleted_neurons) (k : Fin m) : + (DeleteNeurons deleted_neurons.toList (Hebbian ps)).w.mulVec (ps k).act u = + (Hebbian ps).w.mulVec (ps k).act u - ∑ v ∈ deleted_neurons, (Hebbian ps).w u v * (ps k).act v := by + rw [DeleteNeurons] + have hu_list : u ∉ deleted_neurons.toList := by + simp only [Finset.mem_toList] + exact hu + have h_nodup : deleted_neurons.toList.Nodup := Finset.nodup_toList deleted_neurons + rw [delete_neurons_recursive ps deleted_neurons.toList h_nodup u hu_list k] + congr 1 + rw [@toList_toFinset] + +/-! ### Lemmas for Bounding Field Reduction due to Neuron Deletion -/ + +variable {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (deleted_neurons : Finset U) + (k : Fin m) (u : U) (hu : u ∉ deleted_neurons) + (horth : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) + +/-- +Calculates the contribution to the deleted field sum from the target pattern `k` itself +in the Hebbian weight definition. 
+-/
+lemma hebbian_weight_deleted_neurons_l_eq_k_term :
+    ∑ v ∈ deleted_neurons, (ps k).act u * (ps k).act v * (ps k).act v =
+    (card deleted_neurons : R) * (ps k).act u := by
+  have h_act_sq : ∀ v, (ps k).act v * (ps k).act v = 1 :=
+    fun v => mul_self_eq_one_iff.mpr ((ps k).hp v)
+  simp_rw [mul_assoc, h_act_sq, mul_one]
+  rw [Finset.sum_const]
+  simp only [nsmul_eq_mul, Nat.cast_id]
+
+/--
+Defines the cross-talk contribution to the deleted field sum.
+This term arises from the interaction of the target pattern `k` with other stored patterns `l ≠ k`
+over the set of deleted neurons.
+-/
+noncomputable def hebbian_weight_deleted_neurons_cross_talk_term : R :=
+  ∑ v ∈ deleted_neurons, (∑ l ∈ {l' | l' ≠ k}, (ps l).act u * (ps l).act v) * (ps k).act v
+
+/--
+Assumed bound (proof currently deferred with `sorry`) on the absolute value of
+the cross-talk term. It encapsulates the statistical argument from Hopfield's
+paper that, for random-like patterns, the sum of interfering terms is bounded.
+-/
+lemma cross_talk_term_abs_bound_assumption {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] [DecidableEq U] {m : ℕ}
+    (ps : Fin m → (HopfieldNetwork R U).State) (deleted_neurons : Finset U)
+    (k : Fin m) (u : U) (_hu : u ∉ deleted_neurons)
+    (_horth : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) -- Orthogonality is a simplifying assumption often made.
+    (hcard : card deleted_neurons ≤ Fintype.card U / 10)
+    (hm : m ≤ Fintype.card U / 10) :
+    |hebbian_weight_deleted_neurons_cross_talk_term ps deleted_neurons k u| ≤ (Fintype.card U / 10 : R) := by sorry
+
+/--
+Every stored pattern `ps j` is a fixed point of the Hebbian network: with
+pairwise-orthogonal patterns and `m < Fintype.card U`, updating any neuron
+leaves `ps j` unchanged.
+**Hopfield Assumption:** Assumes the standard Hebbian learning rule where `Tᵢᵢ = 0`.
+The `Hebbian` definition in `HN.lean` implements this by subtracting `m • 1`.
+-/ +lemma Hebbian_stable (hm : m < Fintype.card U) (ps : Fin m → (HopfieldNetwork R U).State) (j : Fin m) + (horth : ∀ {i j : Fin m} (_ : i ≠ j), dotProduct (ps i).act (ps j).act = 0): + isStable (Hebbian ps) (ps j) := by + have h_mulVec_eq : ((Hebbian ps).w).mulVec (ps j).act = ((Fintype.card U : R) - (m : R)) • (ps j).act := + patterns_pairwise_orthogonal ps horth j + have h_pos_diff : 0 < (Fintype.card U : R) - m := by + rw [sub_pos, Nat.cast_lt] + exact hm + apply stateisStablecondition ps (ps j) ((Fintype.card U : R) - m) h_pos_diff + intro u + exact congrFun h_mulVec_eq u + +omit [Fintype U] [Nonempty U] in +lemma hebbian_weight_diagonal_term_zero_if_neq (u v : U) (huv : u ≠ v) {m : ℕ} : + (↑m * if u = v then (1 : R) else 0) = 0 := by + simp [if_neg huv] + +omit [DecidableEq U] in +lemma sum_hebbian_weight_times_act_remove_diagonal {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) (deleted_neurons : Finset U) + (k : Fin m) (u : U) (hu : u ∉ deleted_neurons) : + ∑ v ∈ deleted_neurons, (Hebbian ps).w u v * (ps k).act v = + ∑ v ∈ deleted_neurons, (∑ l, outerProduct (ps l) (ps l) u v) * (ps k).act v := by + unfold Hebbian + simp only [sub_apply, smul_apply, smul_eq_mul, Matrix.one_apply] + have h_diag_sum_zero : + ∑ v ∈ deleted_neurons, (↑m * if u = v then 1 else 0) * (ps k).act v = 0 := by + apply Finset.sum_eq_zero + intro v hv + have huv : u ≠ v := fun eqv => hu (eqv ▸ hv) + rw [hebbian_weight_diagonal_term_zero_if_neq u v huv, zero_mul] + simp_rw [sub_mul] + rw [Finset.sum_sub_distrib] + rw [h_diag_sum_zero, sub_zero] + apply Finset.sum_congr rfl + intro v _ + rw [Matrix.sum_apply] + +omit [DecidableEq U] in +lemma sum_outer_product_times_act_swap_summation {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) (deleted_neurons : Finset U) + (k : Fin m) (u : U) : + ∑ v ∈ deleted_neurons, (∑ l, outerProduct (ps l) (ps l) u v) * (ps k).act v = + ∑ l, ∑ v ∈ deleted_neurons, (ps l).act u * (ps l).act v * (ps k).act v := by + simp 
only [outerProduct] + rw [Finset.sum_comm] + apply Finset.sum_congr rfl + intro l _ + rw [Finset.sum_mul] + +omit [DecidableEq U] in +lemma sum_outer_product_times_act_l_eq_k {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) (deleted_neurons : Finset U) + (k : Fin m) (u : U) : + ∑ v ∈ deleted_neurons, (ps k).act u * (ps k).act v * (ps k).act v = + (card deleted_neurons : R) * (ps k).act u := by + have act_sq_one : ∀ v, (ps k).act v * (ps k).act v = 1 := + fun v => mul_self_eq_one_iff.mpr ((ps k).hp v) + simp_rw [mul_assoc] + rw [← Finset.mul_sum] + have h_sum : ∑ v ∈ deleted_neurons, (ps k).act v * (ps k).act v = ↑(deleted_neurons.card) := by + rw [Finset.sum_congr rfl fun v _ => act_sq_one v] + simp only [Finset.sum_const, nsmul_one, Nat.cast_id] + rw [h_sum] + rw [mul_comm] + +omit [DecidableEq U] in +lemma sum_outer_product_times_act_l_neq_k {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) (deleted_neurons : Finset U) + (k : Fin m) (u : U) : + ∑ l ∈ Finset.univ.erase k, ∑ v ∈ deleted_neurons, (ps l).act u * (ps l).act v * (ps k).act v = + hebbian_weight_deleted_neurons_cross_talk_term ps deleted_neurons k u := by + unfold hebbian_weight_deleted_neurons_cross_talk_term + dsimp only [ne_eq] + rw [Finset.sum_comm] + apply Finset.sum_congr rfl + intro v _ + rw [← Finset.sum_mul] + apply congr_arg (· * (ps k).act v) + apply Finset.sum_congr + · simp only [Finset.filter_ne', Finset.mem_univ, true_and] + · intros _ _ ; rfl + +omit [DecidableEq U] in +lemma deleted_field_sum_decomposition {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) (deleted_neurons : Finset U) + (k : Fin m) (u : U) (hu : u ∉ deleted_neurons) : + ∑ v ∈ deleted_neurons, (Hebbian ps).w u v * (ps k).act v = + (card deleted_neurons : R) * (ps k).act u + hebbian_weight_deleted_neurons_cross_talk_term ps deleted_neurons k u := by + rw [sum_hebbian_weight_times_act_remove_diagonal ps deleted_neurons k u hu] + rw 
[sum_outer_product_times_act_swap_summation ps deleted_neurons k u] + rw [Finset.sum_eq_add_sum_diff_singleton (Finset.mem_univ k)] + rw [sum_outer_product_times_act_l_eq_k ps deleted_neurons k u] + simp_rw [Finset.sdiff_singleton_eq_erase] + rw [sum_outer_product_times_act_l_neq_k ps deleted_neurons k u] + +/-- +Placeholder lemma for bounding the cross-talk term. +Proving a tight bound likely requires assumptions beyond simple orthogonality, +such as patterns being random and uncorrelated, analyzed in the limit of large N. +The bound likely depends on `m` (number of patterns) and `N` (number of neurons). + +**Hopfield Assumption:** Implicitly assumes patterns behave statistically like random vectors. +-/ +lemma bound_cross_talk_term (hcard : card deleted_neurons ≤ Fintype.card U / 10) (hm : m ≤ Fintype.card U / 10) : + hebbian_weight_deleted_neurons_cross_talk_term ps deleted_neurons k u ≤ + ((Fintype.card U / 10 : R) - (card deleted_neurons : R)) * (ps k).act u := by + -- The proof of this bound is complex and depends on statistical assumptions + -- about the patterns {ps l} and potentially the size of the network N = card U. + -- Simple orthogonality is insufficient to guarantee this bound in general. + -- This lemma is stronger than the absolute value bound and may require different/stronger assumptions. + sorry + +/-- +The field reduction from deleting neurons has a bounded effect. +This version uses the decomposition and the (unproven) cross-talk bound. + +**Hopfield Assumptions:** +1. Standard Hebbian learning (`Tᵢᵢ = 0`). +2. Patterns `ps` are orthogonal (`horth`). +3. Patterns `ps` behave statistically like random vectors (implicit in `bound_cross_talk_term`). +4. Number of stored patterns `m` is small relative to network size `N` (`hm`). +5. Number of deleted neurons is small relative to network size `N` (`hcard`). 
+-/ +lemma deleted_field_bound (hu : u ∉ deleted_neurons) + (_ : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) (hcard : card deleted_neurons ≤ Fintype.card U / 10) (hm : m ≤ Fintype.card U / 10) : + ∑ v ∈ deleted_neurons, (Hebbian ps).w u v * (ps k).act v ≤ + (Fintype.card U / 10 : R) * (ps k).act u := by + rw [deleted_field_sum_decomposition ps deleted_neurons k u hu] + let C := hebbian_weight_deleted_neurons_cross_talk_term ps deleted_neurons k u + let lk_term := (card deleted_neurons : R) * (ps k).act u + have h_bound_C : C ≤ ((Fintype.card U / 10 : R) - (card deleted_neurons : R)) * (ps k).act u := + by exact bound_cross_talk_term ps deleted_neurons k u hcard hm + calc + lk_term + C ≤ lk_term + ((Fintype.card U / 10 : R) - (card deleted_neurons : R)) * (ps k).act u := add_le_add_left h_bound_C _ + _ = (card deleted_neurons : R) * (ps k).act u + (Fintype.card U / 10 : R) * (ps k).act u - (card deleted_neurons : R) * (ps k).act u := by ring + _ = (Fintype.card U / 10 : R) * (ps k).act u := by + rw [add_sub_cancel_left] + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] [DecidableEq U] in +/-- With constrained m and limited deleted neurons, the field remains strong enough -/ +lemma field_remains_sufficient {m : ℕ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] + (hm_cond : m ≤ Fintype.card U / 10) : + (Fintype.card U : R) - (m : R) - (Fintype.card U / 10 : R) > 0 := by + let N_R : R := (Fintype.card U : R) + let m_R : R := (m : R) + let N_div_10_R : R := (Fintype.card U / 10 : R) + let ten_R : R := 10 + have N_pos_nat : 0 < Fintype.card U := Fintype.card_pos + have N_R_pos : N_R > 0 := Nat.cast_pos.mpr N_pos_nat + have ten_R_pos : ten_R > 0 := by simp [Nat.cast_ofNat]; exact Nat.cast_pos.mpr (by norm_num : 0 < 10) + have hm_R_le_N_div_10 : m_R ≤ N_R / ten_R := by + have h10 : (10 : ℕ) ≠ 0 := by norm_num + have : m * 10 ≤ Fintype.card U := by + have h_mul_div : Fintype.card U / 10 * 10 ≤ 
Fintype.card U := Nat.div_mul_le_self (Fintype.card U) 10 + exact le_trans (Nat.mul_le_mul_right 10 hm_cond) h_mul_div + have cast_this : (↑(m * 10) : R) ≤ (↑(Fintype.card U) : R) := Nat.cast_le.mpr this + rw [Nat.cast_mul] at cast_this + have : (↑10 : R) = ten_R := by rfl + have : ↑m * ten_R ≤ N_R := by + rw [← this] + exact cast_this + exact (le_div_iff₀ ten_R_pos).mpr cast_this + calc + N_R - m_R - N_R / ten_R ≥ N_R - (N_R / ten_R) - (N_R / ten_R) := by linarith [hm_R_le_N_div_10] + _ = N_R - 2 * (N_R / ten_R) := by ring + _ = N_R * (1 - 2 / ten_R) := by + field_simp [ten_R_pos.ne.symm] + ring + _ = N_R * (8 / ten_R) := by + congr 2 + field_simp [ten_R_pos.ne.symm] + norm_num + _ > 0 := by + apply mul_pos + · exact N_R_pos + · have eight_R_pos : (8 : R) > 0 := by simp [Nat.cast_ofNat]; + exact div_pos eight_R_pos ten_R_pos + +omit [Fintype U] [Nonempty U] in +/-- For random orthogonal patterns, the cross-talk term has a bounded absolute value. + This is a fundamental assumption from Hopfield's paper about how patterns interact. 
-/ +lemma bound_cross_talk_term_abs [Fintype U] [Nonempty U] + (ps : Fin m → (HopfieldNetwork R U).State) (deleted_neurons : Finset U) + (k : Fin m) (u : U) (hu : u ∉ deleted_neurons) + (horth : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) + (hcard : card deleted_neurons ≤ Fintype.card U / 10) + (hm : m ≤ Fintype.card U / 10) : + |hebbian_weight_deleted_neurons_cross_talk_term ps deleted_neurons k u| ≤ (Fintype.card U / 10 : R) := by + exact cross_talk_term_abs_bound_assumption ps deleted_neurons k u hu horth hcard hm + +lemma deleted_field_product_bound [Fintype U] [Nonempty U] + (ps : Fin m → (HopfieldNetwork R U).State) (deleted_neurons : Finset U) + (k_pat : Fin m) (u_check : U) (hu_check_not_deleted : u_check ∉ deleted_neurons) + (horth : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) + (hcard_deleted_cond : card deleted_neurons ≤ Fintype.card U / 10) + (hm_cond : m ≤ Fintype.card U / 10) : + (∑ v ∈ deleted_neurons, (Hebbian ps).w u_check v * (ps k_pat).act v) * (ps k_pat).act u_check ≤ (Fintype.card U / 5 : R) := by + -- We apply the decomposition to separate signal and cross-talk terms + have decomp := deleted_field_sum_decomposition ps deleted_neurons k_pat u_check hu_check_not_deleted + let C_del_R := (card deleted_neurons : R) + let X_talk := hebbian_weight_deleted_neurons_cross_talk_term ps deleted_neurons k_pat u_check + let N_R := (Fintype.card U : R) + let ten_R : R := 10 + let N_div_10_R := N_R / ten_R + let act_u_k := (ps k_pat).act u_check + have sum_field_eq : ∑ v ∈ deleted_neurons, (Hebbian ps).w u_check v * (ps k_pat).act v = + C_del_R * act_u_k + X_talk := decomp + -- Multiply both sides by act_u_k + calc + (∑ v ∈ deleted_neurons, (Hebbian ps).w u_check v * (ps k_pat).act v) * act_u_k = (C_del_R * act_u_k + X_talk) * act_u_k := by rw [sum_field_eq] + _ = C_del_R * act_u_k * act_u_k + X_talk * act_u_k := by rw [add_mul] + _ = C_del_R + X_talk * act_u_k := by + rw [mul_assoc, mul_self_eq_one_iff.mpr ((ps 
k_pat).hp u_check), mul_one] + _ ≤ N_div_10_R + N_div_10_R := by + apply add_le_add + · -- First bound: C_del_R ≤ N_div_10_R + have ten_R_pos : ten_R > 0 := Nat.cast_pos.mpr (by norm_num : 0 < 10) + have : 10 * card deleted_neurons ≤ Fintype.card U := by + have h10 : 10 > 0 := by norm_num + apply Nat.mul_le_of_le_div + · sorry + have cast_this : (↑(10 * card deleted_neurons) : R) ≤ N_R := Nat.cast_le.mpr this + rw [Nat.cast_mul] at cast_this + simp only [Nat.cast_ofNat] at cast_this + exact (le_div_iff₀' ten_R_pos).mpr cast_this + · -- Second bound: X_talk * act_u_k ≤ N_div_10_R + apply le_trans + · -- X_talk * act_u_k ≤ |X_talk| + calc + X_talk * act_u_k ≤ |X_talk * act_u_k| := le_abs_self _ + _ = |X_talk| * |act_u_k| := by rw [abs_mul] + _ = |X_talk| * 1 := sorry + _ = |X_talk| := by rw [mul_one] + · -- |X_talk| ≤ N_div_10_R + exact bound_cross_talk_term_abs ps deleted_neurons k_pat u_check hu_check_not_deleted horth hcard_deleted_cond hm_cond + _ = N_R / 5 := by + have five_R : R := 5 + have ten_R_pos : ten_R > 0 := Nat.cast_pos.mpr (by norm_num : 0 < 10) + have five_R_pos : five_R > 0 := sorry + field_simp + sorry + +/-- With bounded numbers of patterns and deleted neurons, the field remains strong enough + to maintain the pattern stability, adjusted for N/5 bound. 
-/ +lemma field_remains_sufficient_for_N_div_5 (R : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] + (hm_cond : m ≤ Fintype.card U / 10) : + (Fintype.card U : R) - (m : R) - (Fintype.card U / 5 : R) > 0 := by + let N_R : R := Fintype.card U + let m_R : R := m + let five_R : R := 5 + let ten_R : R := 10 + have N_R_pos : N_R > 0 := Nat.cast_pos.mpr Fintype.card_pos + have five_R_pos : five_R > 0 := Nat.cast_pos.mpr (by norm_num : 0 < 5) + have ten_R_pos : ten_R > 0 := Nat.cast_pos.mpr (by norm_num : 0 < 10) + have hm_R_le_N_div_10 : m_R ≤ N_R / ten_R := by + sorry + calc + N_R - m_R - N_R / five_R ≥ N_R - (N_R / ten_R) - (N_R / five_R) := by linarith [hm_R_le_N_div_10] + _ = N_R * (1 - 1/ten_R - 1/five_R) := by + field_simp [ten_R_pos.ne.symm, five_R_pos.ne.symm] + ring + _ = N_R * (1 - 1/ten_R - 2/ten_R) := by + apply congr_arg (fun x => N_R * x) + have h_frac : 1/five_R = 2/ten_R := by + field_simp [five_R_pos.ne.symm, ten_R_pos.ne.symm] + have h_five_eq : five_R = 5 := rfl + have h_ten_eq : ten_R = 10 := rfl + rw [h_five_eq, h_ten_eq] + norm_num + rw [h_frac] + _ = N_R * ( (10-1-2) / ten_R) := by + field_simp [ten_R_pos.ne.symm] + ring_nf + sorry + _ = N_R * (7 / ten_R) := by norm_num + _ > 0 := by + apply mul_pos N_R_pos + have seven_R_pos : (7 : R) > 0 := Nat.cast_pos.mpr (by norm_num : 0 < 7) + exact div_pos seven_R_pos ten_R_pos + +lemma hebbian_deleted_threshold_is_zero {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (neurons_to_delete : Finset U) (u_check : U) : + θ' ((DeleteNeurons neurons_to_delete.toList (Hebbian ps)).θ u_check) = 0 := by + simp only [DeleteNeurons, DeleteNeuron, Hebbian, θ'] + induction neurons_to_delete.toList with + | nil => simp [List.foldl_nil]; exact rfl + | cons head tail ih => + rw [List.foldl_cons] + sorry + +lemma net_input_at_non_deleted_neuron {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) (k_pat : Fin m) + (neurons_to_delete : Finset U) (u_check : 
U) (hu_check_not_deleted : u_check ∉ neurons_to_delete) + (wθ_orig wθ_deleted : Params (HopfieldNetwork R U)) (hwθ_orig_eq : wθ_orig = Hebbian ps) (hwθ_del_eq : wθ_deleted = DeleteNeurons neurons_to_delete.toList wθ_orig) : + (ps k_pat).net wθ_deleted u_check = + wθ_orig.w.mulVec (ps k_pat).act u_check - (∑ v ∈ neurons_to_delete, wθ_orig.w u_check v * (ps k_pat).act v) := by + subst hwθ_orig_eq hwθ_del_eq + have h_diag_w_deleted_zero : (DeleteNeurons neurons_to_delete.toList (Hebbian ps)).w u_check u_check = 0 := by + exact (DeleteNeurons neurons_to_delete.toList (Hebbian ps)).hw u_check u_check fun a ↦ a rfl + unfold NeuralNetwork.State.net HopfieldNetwork NeuralNetwork.fnet HNfnet + have sum_eq_net : ∑ v ∈ {v | v ≠ u_check}, ((DeleteNeurons neurons_to_delete.toList (Hebbian ps)).w u_check v) * (ps k_pat).act v = + (ps k_pat).net (DeleteNeurons neurons_to_delete.toList (Hebbian ps)) u_check := by rfl + have h_sum_split : ∑ v ∈ {v | v ≠ u_check}, ((DeleteNeurons neurons_to_delete.toList (Hebbian ps)).w u_check v) * (ps k_pat).act v = + ∑ v, ((DeleteNeurons neurons_to_delete.toList (Hebbian ps)).w u_check v) * (ps k_pat).act v := by + rw [← + HNfnet_eq u_check ((DeleteNeurons neurons_to_delete.toList (Hebbian ps)).w u_check) + (ps k_pat).act h_diag_w_deleted_zero] + have sum_eq_mulVec : ∑ v, ((DeleteNeurons neurons_to_delete.toList (Hebbian ps)).w u_check v) * (ps k_pat).act v = + ((DeleteNeurons neurons_to_delete.toList (Hebbian ps)).w).mulVec (ps k_pat).act u_check := by + rw [mulVec] + simp only [dotProduct] + have field_effect := deleted_neurons_field_effect ps neurons_to_delete u_check hu_check_not_deleted k_pat + sorry + +omit [DecidableEq U] in +lemma product_net_input_activation_at_non_deleted_neuron {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) (k_pat : Fin m) + (neurons_to_delete : Finset U) (u_check : U) (hu_check_not_deleted : u_check ∉ neurons_to_delete) + (horth : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) + 
(wθ_orig wθ_deleted : Params (HopfieldNetwork R U)) (hwθ_orig_eq : wθ_orig = Hebbian ps) (hwθ_del_eq : wθ_deleted = DeleteNeurons neurons_to_delete.toList wθ_orig) : + ((ps k_pat).net wθ_deleted u_check) * (ps k_pat).act u_check = + (((Fintype.card U : R) - m) - (∑ v ∈ neurons_to_delete, wθ_orig.w u_check v * (ps k_pat).act v) * (ps k_pat).act u_check) := by + subst wθ_orig + rw [net_input_at_non_deleted_neuron ps k_pat neurons_to_delete u_check hu_check_not_deleted (Hebbian ps) wθ_deleted rfl hwθ_del_eq] + rw [sub_mul] + have h_orig_field_term_mul_act : (Hebbian ps).w.mulVec (ps k_pat).act u_check * (ps k_pat).act u_check = ((Fintype.card U : R) - m) := by + have h_orig_field_eq : (Hebbian ps).w.mulVec (ps k_pat).act u_check = ((Fintype.card U : R) - m) * (ps k_pat).act u_check := + congr_fun (patterns_pairwise_orthogonal ps horth k_pat) u_check + rw [h_orig_field_eq, mul_assoc] + rw [mul_self_eq_one_iff.mpr ((ps k_pat).hp u_check), mul_one] + rw [h_orig_field_term_mul_act] + +lemma non_deleted_neuron_maintains_sign_of_activation {m : ℕ} [Nonempty U] + (ps : Fin m → (HopfieldNetwork R U).State) (k_pat : Fin m) + (neurons_to_delete : Finset U) (u_check : U) (hu_check_not_deleted : u_check ∉ neurons_to_delete) + (horth : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) + (hm_cond : m ≤ Fintype.card U / 10) (hcard_deleted_cond : neurons_to_delete.card ≤ Fintype.card U / 10) + (wθ_orig wθ_deleted : Params (HopfieldNetwork R U)) (hwθ_orig_eq : wθ_orig = Hebbian ps) (hwθ_del_eq : wθ_deleted = DeleteNeurons neurons_to_delete.toList wθ_orig) : + ((ps k_pat).net wθ_deleted u_check) * (ps k_pat).act u_check > 0 := by + subst wθ_orig + rw [product_net_input_activation_at_non_deleted_neuron ps k_pat neurons_to_delete u_check hu_check_not_deleted horth (Hebbian ps) wθ_deleted rfl hwθ_del_eq] + let signal_term := (Fintype.card U : R) - (m : R) + let reduction_term_times_act := (∑ v ∈ neurons_to_delete, (Hebbian ps).w u_check v * (ps k_pat).act v) * (ps k_pat).act 
u_check + have h_bound_reduction_term : reduction_term_times_act ≤ (Fintype.card U / 5 : R) := + deleted_field_product_bound ps neurons_to_delete k_pat u_check hu_check_not_deleted horth hcard_deleted_cond hm_cond + have h_signal_bound : signal_term - (Fintype.card U / 5 : R) > 0 := by + unfold signal_term + simp only [gt_iff_lt, sub_pos] + sorry + calc + signal_term - reduction_term_times_act ≥ signal_term - (Fintype.card U / 5 : R) := by + apply sub_le_sub_left h_bound_reduction_term _ + _ > 0 := h_signal_bound + +omit [DecidableEq U] in +/-- When deleting neurons from a Finset, we can use Finset.toList to convert the Finset to a List. + This matches the API needed by DeleteNeurons. -/ +lemma DeleteNeurons_with_Finset [DecidableEq U] (deleted_neurons : Finset U) (wθ : Params (HopfieldNetwork R U)) : + DeleteNeurons (Finset.toList deleted_neurons) wθ = + DeleteNeurons deleted_neurons.toList wθ := by + rfl + +omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] [DecidableEq U] in +/-- A Hopfield network can tolerate the failure of up to 10% of its neurons + while maintaining all stored patterns as fixed points, provided: + 1) The stored patterns are orthogonal + 2) The number of patterns is at most 10% of the network size -/ +theorem fault_tolerance_bound {m : ℕ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] [Nonempty U] + (ps : Fin m → (HopfieldNetwork R U).State) + (horth : ∀ i j : Fin m, i ≠ j → dotProduct (ps i).act (ps j).act = 0) : + m ≤ Fintype.card U / 10 → ∀ k_pat : Fin m, FaultTolerance ps k_pat (Fintype.card U / 10) := by + intro hm_cond k_pat + dsimp only [FaultTolerance] + intro deleted_neurons_finset hcard_deleted_cond + let wθ_orig := Hebbian ps + let wθ' := DeleteNeurons (Finset.toList deleted_neurons_finset) wθ_orig + intro u_check_neuron hu_check_mem_sdiff + let act_k_u_check := (ps k_pat).act u_check_neuron + let net_u_check_new := net wθ' (ps k_pat) u_check_neuron + let 
threshold_u_check_new := θ' (wθ'.θ u_check_neuron) + have hu_check_not_in_deleted : u_check_neuron ∉ deleted_neurons_finset := + (Finset.mem_sdiff.mp hu_check_mem_sdiff).right + have h_net_sign_correct : net_u_check_new * act_k_u_check > 0 := + non_deleted_neuron_maintains_sign_of_activation ps k_pat deleted_neurons_finset + u_check_neuron hu_check_not_in_deleted + (fun {i j} h => horth i j h) hm_cond hcard_deleted_cond wθ_orig wθ' rfl rfl + have h_threshold_is_zero : threshold_u_check_new = 0 := + hebbian_deleted_threshold_is_zero ps deleted_neurons_finset u_check_neuron + rw [act_up_def] + cases ((ps k_pat).hp u_check_neuron) with + | inl h_act_eq_one => + simp only [act_k_u_check, h_act_eq_one] + apply if_pos + dsimp only [h_threshold_is_zero] + have net_pos : net_u_check_new > 0 := by + have h_temp := h_net_sign_correct + simp only [act_k_u_check, h_act_eq_one] at h_temp + simp only [mul_one] at h_temp + exact h_temp + have threshold_zero : (wθ'.θ u_check_neuron).get 0 = 0 := by dsimp only [h_threshold_is_zero]; exact + h_threshold_is_zero + rw [threshold_zero] + exact le_of_lt net_pos + | inr h_act_eq_neg_one => + simp only [act_k_u_check, h_act_eq_neg_one] + apply if_neg + dsimp only [h_threshold_is_zero] + have net_neg : net wθ' (ps k_pat) u_check_neuron < 0 := by + have h_temp := h_net_sign_correct + simp only [act_k_u_check, h_act_eq_neg_one] at h_temp + simp only [mul_neg_one, neg_pos] at h_temp + exact h_temp + exact Mathlib.Tactic.IntervalCases.of_lt_right net_neg (id (Eq.symm h_threshold_is_zero)) + +end Hopfield82 diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/MemoryConfusion.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/MemoryConfusion.lean new file mode 100644 index 000000000..19a3a2840 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/MemoryConfusion.lean @@ -0,0 +1,78 @@ +import 
PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Hopfield82.MemoryStorage +import Mathlib.Data.Real.StarOrdered +import Mathlib.MeasureTheory.Integral.Bochner.Basic +import Mathlib.MeasureTheory.Measure.Haar.OfBasis + +namespace Hopfield82 + +open NeuralNetwork State Matrix Finset Real + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] + +/-! ### Memory Capacity -/ + +/-- +The `StorageCapacity` of a Hopfield network is the maximum number of patterns +that can be stored and reliably retrieved. The paper suggests this is around 0.15N. + +From the paper (p.2556): "About 0.15 N states can be simultaneously remembered before +error in recall is severe." +-/ +def StorageCapacity : ℝ := 0.15 * Fintype.card U + +/-- +The error function `erf(x)` is defined as: + erf(x) = (2/√π) ∫₀ˣ e^(-t²) dt + +This function is central in probability theory, especially for normal distributions. +-/ +noncomputable def Real.erf (x : ℝ) : ℝ := + (2 / Real.sqrt π) * ∫ (t : ℝ) in Set.Icc 0 x, Real.exp (-(t^2)) + +/-- +The `PatternRetrievalError` function computes the probability of an error in pattern retrieval +for a network storing m patterns. This increases as m approaches and exceeds 0.15N. + +This corresponds to the error probability P discussed in Equation 10 of the paper: +P = ∫_σ^∞ (1/√(2π)) e^(-x²/2) dx = (1/2)(1 - erf(σ/√2)) + +where σ = N/(2√(nN)) and n is the number of patterns. +-/ +noncomputable def PatternRetrievalError (m : ℕ) : ℝ := + let N := Fintype.card U + let σ := N / (2 * Real.sqrt (m * N : ℝ)) + (1/2) * (1 - Real.erf (σ / Real.sqrt 2)) + +/-- +The result from the paper that a Hopfield network can store approximately 0.15N patterns +with high reliability, where N is the number of neurons. + +This theorem formalizes the key result about storage capacity from the paper, +utilizing the Hebbian_stable theorem from the existing codebase. 
+-/ +theorem storage_capacity_bound {m : ℕ} [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) + (horth : ∀ {i j : Fin m}, i ≠ j → dotProduct (ps i).act (ps j).act = 0) : + let wθ := Hebbian ps; + (m : ℝ) ≤ ((Fintype.card U : ℝ) * (15 : ℝ)) / (100 : ℝ) → ∀ k : Fin m, FixedPoint wθ (ps k) := by + intros _ hm k + unfold FixedPoint + apply Hebbian_stable + · have h1 : (m : ℝ) ≤ ((Fintype.card U : ℝ) * 15) / 100 := hm + have aux : (15 : ℝ) / 100 < 1 := by norm_num; + have h2 : ((Fintype.card U : ℝ) * 15) / 100 < (Fintype.card U : ℝ) := + by + have cardU_pos : 0 < (Fintype.card U : ℝ) := by norm_cast; exact Fintype.card_pos + have h_rewrite : ((Fintype.card U : ℝ) * 15) / 100 = (Fintype.card U : ℝ) * (15 / 100) := by ring + have h_frac_lt_one : (15 : ℝ) / 100 < 1 := by norm_num + have h_mul_lt_mul : (Fintype.card U : ℝ) * (15 / 100) < (Fintype.card U : ℝ) * 1 := + mul_lt_mul_of_pos_left h_frac_lt_one cardU_pos + have h_simplify : (Fintype.card U : ℝ) * 1 = (Fintype.card U : ℝ) := by ring + calc + ((Fintype.card U : ℝ) * 15) / 100 = (Fintype.card U : ℝ) * (15 / 100) := h_rewrite + _ < (Fintype.card U : ℝ) * 1 := h_mul_lt_mul + _ = (Fintype.card U : ℝ) := h_simplify + have h3 : (m : ℝ) < (Fintype.card U : ℝ) := lt_of_le_of_lt h1 h2 + have h4 : m < Fintype.card U := by exact_mod_cast h3 + exact h4 + · exact horth diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/MemoryStorage.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/MemoryStorage.lean new file mode 100644 index 000000000..cc45e6559 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/MemoryStorage.lean @@ -0,0 +1,38 @@ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Hopfield82.PhaseSpaceFlow + +namespace Hopfield82 + +open NeuralNetwork State Matrix Finset Real + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] + +/-! 
### Memory Storage Algorithm -/ + +/-- +The `normalizedPattern` converts a neural state to a vector with -1/+1 values, +matching the (2Vᵢ - 1) term from equation 2 in Hopfield's paper. +-/ +def normalizedPattern [DecidableEq R] [DecidableEq U] (s : (HopfieldNetwork R U).State) : U → R := + λ u => s.act u + +/-- +The `hebbian` function computes the weight matrix according to Equation 2 of Hopfield's paper: + Tᵢⱼ = Σₛ (2Vᵢˢ - 1)(2Vⱼˢ - 1) with Tᵢᵢ = 0 + +Note that this is equivalent to the existing `Hebbian` definition in HopfieldNet.HN, +but we make the connection to the paper explicit here. +-/ +def hebbian {m : ℕ} [DecidableEq R] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) : Matrix U U R := + let normPatterns := λ s i => (ps s).act i; + let T := ∑ s : Fin m, fun i j => normPatterns s i * normPatterns s j; + λ i j => if i = j then 0 else T i j + +/-- +The `pseudoOrthogonality` property from Hopfield's paper (Equations 3-4) states: +For random patterns, the dot product between different patterns is approximately 0, +while the dot product of a pattern with itself is approximately N. +-/ +def isPseudoOrthogonal {m : ℕ} [DecidableEq R] [DecidableEq U] + (ps : Fin m → (HopfieldNetwork R U).State) : Prop := + ∀ i j : Fin m, i ≠ j → dotProduct (ps i).act (ps j).act = 0 diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/PhaseSpaceFlow.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/PhaseSpaceFlow.lean new file mode 100644 index 000000000..36c238171 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Hopfield82/PhaseSpaceFlow.lean @@ -0,0 +1,190 @@ +/- +Copyright (c) 2025 Matteo Cipollina, Michail Karatarakis. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. 
+Authors: Matteo Cipollina, Michail Karatarakis +-/ + +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic +import Mathlib.Algebra.Lie.OfAssociative + +/-! +# Hopfield Networks: Formalization of J.J. Hopfield's 1982 Paper + +This module formalizes the key mathematical concepts from J.J. Hopfield's 1982 paper +"Neural networks and physical systems with emergent collective computational abilities", +focusing on aspects that are not already covered in the main HopfieldNet formalization. + +The paper introduces a model of neural networks with binary neurons and studies their collective +computational properties, particularly as content-addressable memories. The key insights include: + +1. The phase space flow and stable states of the network +2. The storage capacity and pattern retrieval capabilities +3. The relationship between energy minimization and memory retrieval +4. Tolerance to noise and component failures + +## Main Components + +* `PhaseSpaceFlow`: Formalization of phase space flow and attractors +* `MemoryCapacity`: Relationship between memory capacity and error rates +* `MemoryConfusion`: Formalization of how similar memories can interfere +* `FaultTolerance`: Analysis of network robustness to component failures + +## References + +* Hopfield, J.J. (1982). Neural networks and physical systems with emergent collective + computational abilities. Proceedings of the National Academy of Sciences, 79(8), 2554-2558. +-/ + +namespace Hopfield82 + +open NeuralNetwork State Matrix Finset Real + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [Nonempty U] + +/-! ### Phase Space Flow -/ + +/-- +A `PhaseSpacePoint` represents a state in the phase space of the Hopfield system. +In the paper, this corresponds to the instantaneous state of all neurons (p.2554): +"A point in state space then represents the instantaneous condition of the system." 
+-/ +abbrev PhaseSpacePoint (R U : Type) + [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] := + (HopfieldNetwork R U).State + +/-- Convert Option to Vector for threshold values --/ +def optionToVector (o : Option R) : Vector R 1 := + let arr := match o with + | some v => #[v] + | none => #[0] + Vector.mk arr (by + cases o <;> rfl) + +/-- Convert Vector to Option for threshold values --/ +def vectorToOption (v : Vector R 1) : Option R := + some (v.get 0) + +/-- Extract threshold value safely --/ +def getThreshold (θ : Option R) : R := + match θ with + | some v => v + | none => 0 + +/-- +The `localField` for neuron i in state s is the weighted sum of inputs from other neurons, +minus the threshold. This corresponds to ∑j Tij Vj - θi in the paper. +-/ + +def localField {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] + (wθ : Params (HopfieldNetwork R U)) (s : PhaseSpacePoint R U) (i : U) : R := + + (∑ j ∈ Finset.univ, wθ.w i j * s.act j) - getThreshold (vectorToOption (wθ.θ i)) +/-- +The `updateRule` defines the neural state update according to the paper's Equation 1: + Vi → 1 if ∑j Tij Vj > Ui + Vi → 0 if ∑j Tij Vj < Ui + +In our formalization, we use -1 instead of 0 for the "not firing" state. +-/ +def updateRule [DecidableEq R] [DecidableEq U] (wθ : Params (HopfieldNetwork R U)) + (s : PhaseSpacePoint R U) (i : U) : R := + if localField wθ s i > 0 then 1 else -1 + +/-- +A `PhaseSpaceFlow` describes how the system state evolves over time. +It maps each point in phase space to its successor state after updating one neuron. + +From the paper (p.2554): "The equations of motion of the system describe a flow in state space." 
+-/ +def PhaseSpaceFlow [DecidableEq R] [DecidableEq U] + (wθ : Params (HopfieldNetwork R U)) (useq : ℕ → U) : PhaseSpacePoint R U → PhaseSpacePoint R U := + λ (s : PhaseSpacePoint R U) => s.Up wθ (useq 0) + +/-- +A `FixedPoint` of the phase space flow is a state that does not change under evolution. +In the paper, these correspond to the locally stable states of the network (p.2554): +"Various classes of flow patterns are possible, but the systems of use for memory +particularly include those that flow toward locally stable points..." +-/ +def FixedPoint [DecidableEq R] [DecidableEq U] (wθ : Params (HopfieldNetwork R U)) + (s : PhaseSpacePoint R U) : Prop := + s.isStable wθ + +/-- +A `BasinOfAttraction` of a fixed point is the set of all states that converge to it. +In the paper (p.2554): "Then, if the system is started sufficiently near any Xa, +as at X = Xa + Δ, it will proceed in time until X ≈ Xa." +-/ +def BasinOfAttraction [DecidableEq R] [DecidableEq U] + (wθ : Params (HopfieldNetwork R U)) (p : PhaseSpacePoint R U) + (useq : ℕ → U) (hf : fair useq) : Set (PhaseSpacePoint R U) := + {s | ∃ n : ℕ, seqStates wθ s useq n = p ∧ FixedPoint wθ p ∧ convergence_is_fair s useq hf} + where + convergence_is_fair (_ : PhaseSpacePoint R U) (useq : ℕ → U) (_ : fair useq) : Prop := fair useq + +/-- +The `EnergyLandscape` of a Hopfield network is the energy function defined over all possible states. 
+In the paper, this is the function E defined in Equation 7: + E = -1/2 ∑∑ Tij Vi Vj +-/ +def EnergyLandscape [DecidableEq R] [DecidableEq U] (wθ : Params (HopfieldNetwork R U)) : + PhaseSpacePoint R U → R := λ (s : PhaseSpacePoint R U) => s.E wθ + +@[simp] +lemma up_act_eq_iff_eq {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] [Fintype U] [Nonempty U] + {wθ : Params (HopfieldNetwork R U)} {s : (HopfieldNetwork R U).State} {u : U} : + (s.Up wθ u).act u = s.act u → s.Up wθ u = s := by + intro h_act_eq + -- Apply state extensionality + apply NeuralNetwork.ext + intro v + -- Case split on whether v equals u + by_cases h_v : v = u + · -- Case v = u: use the hypothesis directly + rw [h_v, h_act_eq] + · -- Case v ≠ u: use act_of_non_up lemma + rw [act_of_non_up wθ h_v] + +/-- +The `EnergyChange` when updating neuron i is always non-positive, +as proven in the paper with Equation 8. + +This theorem formalizes a key result from the paper: the energy function +always decreases (or remains constant) under asynchronous updates. +-/ +theorem energy_decrease [DecidableEq R] [DecidableEq U] + (wθ : Params (HopfieldNetwork R U)) (s : PhaseSpacePoint R U) (i : U) : + (s.Up wθ i).E wθ ≤ s.E wθ := by + have h_stab_or_diff : (s.Up wθ i = s) ∨ (s.Up wθ i).act i ≠ s.act i := by + let s' := s.Up wθ i + by_cases h : s'.act i = s.act i + case pos => + left + exact up_act_eq_iff_eq h + case neg => + right + exact h + cases h_stab_or_diff with + | inl h_same => + rw [h_same] + | inr h_diff => + exact energy_diff_leq_zero wθ h_diff + +/-- +This theorem captures the convergence result from the paper: +"Every initial state flows to a limit point (if synchrony is not assumed)." + +The proof leverages the HopfieldNet_convergence_fair theorem from the existing codebase. 
+-/ +theorem convergence_to_fixed_point [DecidableEq R] [DecidableEq U] + (wθ : Params (HopfieldNetwork R U)) (s : PhaseSpacePoint R U) + (useq : ℕ → U) (hf : fair useq) : + ∃ (p : PhaseSpacePoint R U) (n : ℕ), + seqStates wθ s useq n = p ∧ FixedPoint wθ p := by + obtain ⟨N, hN⟩ := HopfieldNet_convergence_fair s useq hf + use seqStates wθ s useq N, N + constructor + · rfl + · exact hN diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean index 073428551..f1b35ae6d 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean @@ -323,8 +323,7 @@ lemma gibbs_total_not_top (local_field : ℝ) (T : ℝ) : /-- For a positive PMF.map application, there exists a preimage with positive probability -/ lemma pmf_map_pos_implies_preimage {α β : Type} [Fintype α] [DecidableEq β] {p : α → ENNReal} (h_pmf : ∑ a, p a = 1) (f : α → β) (y : β) : - (PMF.map f (PMF.ofFintype p h_pmf)) y > 0 → - ∃ x : α, p x > 0 ∧ f x = y := by + (PMF.map f (PMF.ofFintype p h_pmf)) y > 0 → ∃ x : α, p x > 0 ∧ f x = y := by intro h_pos simp only [PMF.map_apply] at h_pos simp_all only [PMF.ofFintype_apply, tsum_eq_filter_sum, gt_iff_lt, filter_sum_pos_iff_exists_pos, From d80b270164c86918cc77adbcc70ade983fed940a Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Mon, 11 Aug 2025 14:16:30 +0200 Subject: [PATCH 06/15] Update PhysLean.lean --- PhysLean.lean | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/PhysLean.lean b/PhysLean.lean index df292da32..173b2695e 100644 --- a/PhysLean.lean +++ b/PhysLean.lean @@ -355,3 +355,14 @@ import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.aux import 
PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.test +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Asym +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Core +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Markov +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.aux +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.test From d4d85e7c93e71e651459d9d65d862c6408ee775c Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Mon, 11 Aug 2025 14:55:51 +0200 Subject: [PATCH 07/15] Update PhysLean.lean --- PhysLean.lean | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/PhysLean.lean b/PhysLean.lean index 173b2695e..0468d9f13 100644 --- a/PhysLean.lean +++ b/PhysLean.lean @@ -344,6 +344,16 @@ import PhysLean.StringTheory.FTheory.SU5U1.Quanta.IsViable.Elems import PhysLean.StringTheory.FTheory.SU5U1.Quanta.ToList import PhysLean.StringTheory.FTheory.SU5U1.Quanta.YukawaRegeneration import PhysLean.Thermodynamics.Basic +import PhysLean.Thermodynamics.Temperature.Basic +import PhysLean.Thermodynamics.Temperature.TemperatureUnits +import PhysLean.Units.Area +import PhysLean.Units.Basic +import PhysLean.Units.Energy +import PhysLean.Units.Mass +import PhysLean.Units.Momentum.Basic +import PhysLean.Units.Pressure +import PhysLean.Units.Speed +import PhysLean.Units.Velocity import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Asym import 
PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Core import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov From c263dca818bc36491dd724f3f04bb26b6a0a912a Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Tue, 12 Aug 2025 18:41:55 +0200 Subject: [PATCH 08/15] =?UTF-8?q?feat(NN,=20Core):=20Generalized=20Neural?= =?UTF-8?q?=20Network=20activation=20values=20(`=CF=83`)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR enhances the `NeuralNetwork` framework to support polymorphic activation values (`σ`), enabling flexible modeling of diverse neuron types (e.g., `R`, `SignumType`, `Bool`). It introduces a helper map `m : σ → R` for numeric computations and refines state definitions and update mechanisms for clarity and correctness. Key Changes 1. **Polymorphic Activation Values**: `σ` parameter added to `NeuralNetwork`, with `m` for numeric mapping. 2. **State Enhancements**: Added `State.out` for numeric outputs and `onlyUi` for consistent non-input activations. 3. **Improved Update Mechanisms**: Refined `Up`, `workPhase`, and `seqStates` for generalized updates. 4. **Backward Compatibility**: Existing Hopfield networks remain functional with `σ := R` and `m := id`. No breaking changes. --- .../SpinGlasses/HopfieldNetwork/Core.lean | 50 ++++++++++------- .../HopfieldNetwork/NeuralNetwork.lean | 56 ++++++++++--------- .../HopfieldNetwork/Stochastic.lean | 2 +- 3 files changed, 60 insertions(+), 48 deletions(-) diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean index cf3b21f0c..b62411241 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean @@ -47,7 +47,7 @@ abbrev HNfout (act : R) : R := act - `U`: A finite, nonempty set of neurons with decidable equality. 
-/ abbrev HopfieldNetwork (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] - [DecidableEq U] [Nonempty U] [Fintype U] : NeuralNetwork R U where + [DecidableEq U] [Nonempty U] [Fintype U] : NeuralNetwork R U R where /- The adjacency relation between neurons `u` and `v`, defined as `u ≠ v`. -/ Adj u v := u ≠ v /- The set of input neurons, defined as the universal set. -/ @@ -74,9 +74,10 @@ abbrev HopfieldNetwork (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRi fnet u w pred _ := HNfnet u w pred /- The activation function for neuron `u`, given input and threshold `θ`. -/ fact u _ net_input_val θ_vec := HNfact (θ_vec.get 0) net_input_val - -- Ignoring the current_act_val argument - /- The output function, given the activation state `act`. -/ - fout _ act := HNfout act + /- The output function is identity since σ = R here. -/ + fout _ act := act + /- Optional σ → R map; identity since σ = R. -/ + m := id /- A predicate that the activation state `act` is either 1 or -1. 
-/ pact act := act = 1 ∨ act = -1 /- A proof that the activation state of neuron `u` @@ -848,9 +849,12 @@ lemma HopfieldNet_cyclic_converg (wθ : Params (HopfieldNetwork R U)) (s : State (HopfieldNet_stabilize_cyclic s useq hf).isStable wθ := (Nat.find_spec (HopfieldNet_convergence_cyclic s useq hf)).2 -lemma patterns_pairwise_orthogonal (ps : Fin m → (HopfieldNetwork R U).State) +lemma patterns_pairwise_orthogonal {m : ℕ} + (ps : Fin m → (HopfieldNetwork R U).State) (horth : ∀ {i j : Fin m} (_ : i ≠ j), dotProduct (ps i).act (ps j).act = 0) : - ∀ (j : Fin m), ((Hebbian ps).w).mulVec (ps j).act = (card U - m) * (ps j).act := by + ∀ (j : Fin m), + ((Hebbian ps).w).mulVec (ps j).act + = ((card U : R) - (m : R)) • (ps j).act := by intros k ext t unfold Hebbian @@ -860,7 +864,7 @@ lemma patterns_pairwise_orthogonal (ps : Fin m → (HopfieldNetwork R U).State) rw [Finset.sum_apply] simp only [Finset.sum_apply] unfold dotProduct at horth - have : ∀ i j, (dotProduct (ps i).act (ps j).act) = if i ≠ j then 0 else card U := by + have : ∀ i j, (dotProduct (ps i).act (ps j).act) = if i ≠ j then 0 else (card U : R) := by intros i j by_cases h : i ≠ j · specialize horth h @@ -878,17 +882,18 @@ lemma patterns_pairwise_orthogonal (ps : Fin m → (HopfieldNetwork R U).State) have hact1 : ∀ i, ((ps j).act i) * ((ps j).act i) = 1 := fun i => mul_self_eq_one_iff.mpr (hact i) calc _ = ∑ i, (ps j).act i * (ps j).act i := rfl _ = ∑ i, 1 * 1 := by simp_rw [hact1]; rw [mul_one] - _ = card U := by simp only [sum_const, card_univ, Fintype.card_fin, nsmul_eq_mul, - mul_one] + _ = (card U : R) := by + -- sum of 1 over all i in U is card U, cast to R + simp only [sum_const, card_univ, Fintype.card_fin, nsmul_eq_mul, mul_one, Nat.cast_ofNat] simp only [dotProduct, ite_not, Nat.cast_ite, Nat.cast_zero] at this conv => enter [1,2]; ext l; rw [sub_mul]; rw [sum_mul]; conv => enter [1,2]; ext i; rw [mul_assoc] rw [Finset.sum_sub_distrib] nth_rw 1 [sum_comm] calc _= ∑ y : Fin m, (ps y).act t * ∑ x , ((ps 
y).act x * (ps k).act x) - - ∑ x , ↑m * (1 : Matrix U U R) t x * (ps k).act x := ?_ - _= ∑ y : Fin m, (ps y).act t * (if y ≠ k then 0 else card U) - - ∑ x , ↑m * (1 : Matrix U U R) t x * (ps k).act x := ?_ - _ = (card U - ↑m) * (ps k).act t := ?_ + - ∑ x , (m : R) * (1 : Matrix U U R) t x * (ps k).act x := ?_ + _= ∑ y : Fin m, (ps y).act t * (if y ≠ k then 0 else (card U : R)) - + ∑ x , (m : R) * (1 : Matrix U U R) t x * (ps k).act x := ?_ + _ = (((card U : R) - (m : R)) * (ps k).act t) := ?_ · simp only [sub_left_inj]; rw [Finset.sum_congr rfl] exact fun x _ => (mul_sum univ (fun i => (ps x).act i * (ps k).act i) ((ps x).act t)).symm · simp only [sub_left_inj]; rw [Finset.sum_congr rfl]; intros i _ @@ -897,10 +902,11 @@ lemma patterns_pairwise_orthogonal (ps : Fin m → (HopfieldNetwork R U).State) conv => enter [1,2,2]; ext k; rw [mul_assoc] rw [← mul_sum, mul_comm] simp only [one_apply, ite_mul, one_mul, zero_mul, Finset.sum_ite_eq, mem_univ, reduceIte] - exact (sub_mul (card U : R) m ((ps k).act t)).symm + exact (sub_mul (card U : R) (m : R) ((ps k).act t)).symm -lemma stateisStablecondition (ps : Fin m → (HopfieldNetwork R U).State) - (s : (HopfieldNetwork R U).State) c (hc : 0 < c) +lemma stateisStablecondition {m : ℕ} + (ps : Fin m → (HopfieldNetwork R U).State) + (s : (HopfieldNetwork R U).State) (c : R) (hc : 0 < c) (hw : ∀ u, ((Hebbian ps).w).mulVec s.act u = c * s.act u) : s.isStable (Hebbian ps) := by intros u unfold Up out @@ -925,12 +931,14 @@ lemma stateisStablecondition (ps : Fin m → (HopfieldNetwork R U).State) · rfl exact (Hebbian ps).hw u u fun a => a rfl -lemma Hebbian_stable (hm : m < card U) (ps : Fin m → (HopfieldNetwork R U).State) (j : Fin m) +lemma Hebbian_stable {m : ℕ} + (hm : m < card U) (ps : Fin m → (HopfieldNetwork R U).State) (j : Fin m) (horth : ∀ {i j : Fin m} (_ : i ≠ j), dotProduct (ps i).act (ps j).act = 0): isStable (Hebbian ps) (ps j) := by unfold isStable - have := patterns_pairwise_orthogonal ps horth j - have hmn0 : 0 < 
(card U - m : R) := by - simpa only [sub_pos, Nat.cast_lt] - apply stateisStablecondition ps (ps j) (card U - m) hmn0 + have := patterns_pairwise_orthogonal (ps := ps) horth j + have hmn0 : 0 < ((card U : R) - (m : R)) := by + have : (m : R) < (card U : R) := by exact_mod_cast hm + exact sub_pos.mpr this + apply stateisStablecondition (ps := ps) (s := (ps j)) (((card U : R) - (m : R))) hmn0 · intros u; rw [funext_iff] at this; exact this u diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean index 05ae88be9..c2262553d 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean @@ -12,13 +12,14 @@ open Mathlib Finset /- A `NeuralNetwork` models a neural network with: -- `R`: Type for weights and activations. +- `R`: Type for weights and numeric computations. - `U`: Type for neurons. +- `σ`: Type for neuron activation values. - `[Zero R]`: `R` has a zero element. It extends `Digraph U` and includes the network's architecture, activation functions, and constraints. -/ -structure NeuralNetwork (R U : Type) [Zero R] extends Digraph U where +structure NeuralNetwork (R U : Type) (σ : Type) [Zero R] extends Digraph U where /-- Input neurons. -/ (Ui Uo Uh : Set U) /-- There is at least one input neuron. -/ @@ -33,27 +34,31 @@ structure NeuralNetwork (R U : Type) [Zero R] extends Digraph U where (κ1 κ2 : U → ℕ) /-- Computes the net input to a neuron. -/ (fnet : ∀ u : U, (U → R) → (U → R) → Vector R (κ1 u) → R) - /-- Computes the activation of a neuron. -/ - (fact : ∀ u : U, R → R → Vector R (κ2 u) → R) -- R_current_activation, R_net_input, params - /-- Computes the final output of a neuron. -/ - (fout : ∀ _ : U, R → R) - /-- Predicate on activations. -/ - (pact : R → Prop) + /-- Computes the activation of a neuron (polymorphic σ). 
-/ + (fact : ∀ u : U, σ → R → Vector R (κ2 u) → σ) -- current σ, net input, params → σ + /-- Converts an activation value to a numeric output for computation. -/ + (fout : ∀ _ : U, σ → R) + /-- Optional helper map σ → R (can be same as fout u if independent of u). -/ + (m : σ → R) + /-- Predicate on activations (in σ). -/ + (pact : σ → Prop) /-- Predicate on weight matrices. -/ (pw : Matrix U U R → Prop) - /-- If all activations satisfy `pact`, then the activations computed by `fact` also satisfy `pact`. -/ - (hpact : ∀ (w : Matrix U U R) (_ : ∀ u v, ¬ Adj u v → w u v = 0) (_ : pw w) - (σ : (u : U) → Vector R (κ1 u)) (θ : (u : U) → Vector R (κ2 u)) (current_neuron_activations : U → R), - (∀ u_idx : U, pact (current_neuron_activations u_idx)) → -- Precondition on all current activations - (∀ u_target : U, pact (fact u_target (current_neuron_activations u_target) -- Pass current_act of target neuron - (fnet u_target (w u_target) (fun v => fout v (current_neuron_activations v)) (σ u_target)) - (θ u_target)))) - - -variable {R U : Type} [Zero R] + /-- If all activations satisfy `pact`, then the next activations computed by `fact` also satisfy `pact`. -/ + (hpact : + ∀ (w : Matrix U U R) (_ : ∀ u v, ¬ Adj u v → w u v = 0) (_ : pw w) + (σv : (u : U) → Vector R (κ1 u)) (θ : (u : U) → Vector R (κ2 u)) + (current : U → σ), + (∀ u_idx : U, pact (current u_idx)) → + ∀ u_target : U, + pact (fact u_target (current u_target) + (fnet u_target (w u_target) (fun v => fout v (current v)) (σv u_target)) + (θ u_target))) + +variable {R U σ : Type} [Zero R] /-- `Params` is a structure that holds the parameters for a neural network `NN`. 
-/ -structure Params (NN : NeuralNetwork R U) where +structure Params (NN : NeuralNetwork R U σ) where (w : Matrix U U R) (hw : ∀ u v, ¬ NN.Adj u v → w u v = 0) (hw' : NN.pw w) @@ -62,13 +67,13 @@ structure Params (NN : NeuralNetwork R U) where namespace NeuralNetwork -structure State (NN : NeuralNetwork R U) where - act : U → R +structure State (NN : NeuralNetwork R U σ) where + act : U → σ hp : ∀ u : U, NN.pact (act u) /-- Extensionality lemma for neural network states -/ @[ext] -lemma ext {R U : Type} [Zero R] {NN : NeuralNetwork R U} +lemma ext {R U σ : Type} [Zero R] {NN : NeuralNetwork R U σ} {s₁ s₂ : NN.State} : (∀ u, s₁.act u = s₂.act u) → s₁ = s₂ := by intro h cases s₁ @@ -79,15 +84,14 @@ lemma ext {R U : Type} [Zero R] {NN : NeuralNetwork R U} namespace State -variable {NN : NeuralNetwork R U} (wσθ : Params NN) (s : NN.State) +variable {NN : NeuralNetwork R U σ} (wσθ : Params NN) (s : NN.State) def out (u : U) : R := NN.fout u (s.act u) def net (u : U) : R := NN.fnet u (wσθ.w u) (fun v => s.out v) (wσθ.σ u) -def onlyUi : Prop := ∀ u : U, ¬ u ∈ NN.Ui → s.act u = 0 - +def onlyUi : Prop := ∃ σ0 : σ, ∀ u : U, u ∉ NN.Ui → s.act u = σ0 variable [DecidableEq U] -def Up {NN_local : NeuralNetwork R U} (s : NN_local.State) (wσθ : Params NN_local) (u_upd : U) : NN_local.State := +def Up {NN_local : NeuralNetwork R U σ} (s : NN_local.State) (wσθ : Params NN_local) (u_upd : U) : NN_local.State := { act := fun v => if v = u_upd then NN_local.fact u_upd (s.act u_upd) (NN_local.fnet u_upd (wσθ.w u_upd) (fun n => s.out n) (wσθ.σ u_upd)) diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean index f1b35ae6d..3b2d71744 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean @@ -6,7 +6,7 @@ Authors: Matteo Cipollina import 
PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux -import PhysLean.StatisticalMechanics.Temperature +import PhysLean.Thermodynamics.Temperature.Basic import Mathlib.Analysis.RCLike.Basic import Mathlib.LinearAlgebra.AffineSpace.AffineMap import Mathlib.LinearAlgebra.Dual.Lemmas From 3f3af061ee16da00fb40c1176e6b947a855d8262 Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Sat, 16 Aug 2025 11:56:12 +0200 Subject: [PATCH 09/15] feat(MCMC):DetailedBalance Add reversibility (detailed balance) for Markov kernels and show it implies invariance --- .../HopfieldNetwork/DetailedBalanceGen.lean | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceGen.lean diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceGen.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceGen.lean new file mode 100644 index 000000000..2ac50d000 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceGen.lean @@ -0,0 +1,38 @@ +import Mathlib.Probability.Kernel.Invariance + +open MeasureTheory Filter Set + +open scoped ProbabilityTheory + +variable {α : Type*} [MeasurableSpace α] + +namespace ProbabilityTheory.Kernel + +/-- +Reversibility (detailed balance) of a Markov kernel `κ` w.r.t. a (σ-finite) measure `π`: +for all measurable sets `A B`, the mass flowing from `A` to `B` equals that from `B` to `A`. +-/ +def IsReversible (κ : Kernel α α) (π : Measure α) : Prop := + ∀ ⦃A B⦄, MeasurableSet A → MeasurableSet B → + ∫⁻ x in A, κ x B ∂π = ∫⁻ x in B, κ x A ∂π + +/-- +A reversible Markov kernel leaves the measure `π` invariant. +Proof uses detailed balance with `B = univ` and `κ x univ = 1`. 
+-/ +theorem Invariant.of_IsReversible + {κ : Kernel α α} [IsMarkovKernel κ] {π : Measure α} + (h_rev : IsReversible κ π) : Invariant κ π := by + ext s hs + have h' := (h_rev hs MeasurableSet.univ).symm + have h'' : ∫⁻ x, κ x s ∂π = ∫⁻ x in s, κ x Set.univ ∂π := by + simpa [Measure.restrict_univ] using h' + have hConst : ∫⁻ x in s, κ x Set.univ ∂π = π s := by + classical + simp [measure_univ, lintegral_const, hs] + have hπ : ∫⁻ x, κ x s ∂π = π s := h''.trans hConst + calc + (π.bind κ) s = ∫⁻ x, κ x s ∂π := Measure.bind_apply hs (Kernel.aemeasurable _) + _ = π s := hπ + +end ProbabilityTheory.Kernel From 24332b4c82e9e96e5bb5c9a118a2d4cbecb482a3 Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Sun, 17 Aug 2025 00:11:49 +0200 Subject: [PATCH 10/15] feat(Temperature): make Temperature a structure, add convergence lemmas --- .../Thermodynamics/Temperature/Basic.lean | 221 +++++++++++++----- 1 file changed, 169 insertions(+), 52 deletions(-) diff --git a/PhysLean/Thermodynamics/Temperature/Basic.lean b/PhysLean/Thermodynamics/Temperature/Basic.lean index 2c9b106e7..b7f6106fc 100644 --- a/PhysLean/Thermodynamics/Temperature/Basic.lean +++ b/PhysLean/Thermodynamics/Temperature/Basic.lean @@ -1,7 +1,7 @@ /- Copyright (c) 2025 Joseph Tooby-Smith. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Joseph Tooby-Smith +Authors: Joseph Tooby-Smith, Matteo Cipollina -/ import Mathlib.Analysis.Calculus.Deriv.Inv import Mathlib.Analysis.InnerProductSpace.Basic @@ -22,101 +22,218 @@ The choice of units can be made on a case-by-case basis, as long as they are don -/ open NNReal -TODO "IOY4E" "Change the definition of `Temperature` to be a structure rather than a `def`." - /-- The type `Temperature` represents the temperature in a given (but arbitary) set of units - (preserving zero). -/ -def Temperature : Type := ℝ≥0 + (preserving zero). It currently wraps `ℝ≥0`, i.e., absolute temperature in nonnegative reals. 
-/ +structure Temperature where + /-- The nonnegative real value of the temperature. -/ + val : ℝ≥0 namespace Temperature open Constants -noncomputable instance : LinearOrder Temperature := - Subtype.instLinearOrder _ +/-- Coercion to `ℝ≥0`. -/ +instance : Coe Temperature ℝ≥0 := ⟨fun T => T.val⟩ -/-- The underlying real-number associated with the tempature. -/ -noncomputable def toReal (T : Temperature) : ℝ := NNReal.toReal T +/-- The underlying real-number associated with the temperature. -/ +noncomputable def toReal (T : Temperature) : ℝ := NNReal.toReal T.val +/-- Coercion to `ℝ`. -/ noncomputable instance : Coe Temperature ℝ := ⟨toReal⟩ -instance : TopologicalSpace Temperature := inferInstanceAs (TopologicalSpace ℝ≥0) +/-- Topology on `Temperature` induced from `ℝ≥0`. -/ +instance : TopologicalSpace Temperature := + TopologicalSpace.induced (fun T : Temperature => (T.val : ℝ≥0)) inferInstance + +instance : Zero Temperature := ⟨⟨0⟩⟩ -instance : Zero Temperature := ⟨0, Preorder.le_refl 0⟩ +@[ext] lemma ext {T₁ T₂ : Temperature} (h : T₁.val = T₂.val) : T₁ = T₂ := by + cases T₁; cases T₂; cases h; rfl /-- The inverse temperature defined as `1/(kB * T)` in a given, but arbitary set of units. This has dimensions equivalent to `Energy`. -/ -noncomputable def β (T : Temperature) : ℝ≥0 := ⟨1 / (kB * T), by - apply div_nonneg - · exact zero_le_one' ℝ - · apply mul_nonneg - · exact kB_nonneg - · exact T.2⟩ +noncomputable def β (T : Temperature) : ℝ≥0 := + ⟨1 / (kB * (T : ℝ)), by + apply div_nonneg + · exact zero_le_one + · apply mul_nonneg + · exact kB_nonneg + · simp [toReal]⟩ /-- The temperature associated with a given inverse temperature `β`. 
-/ noncomputable def ofβ (β : ℝ≥0) : Temperature := - ⟨1 / (kB * β), by - apply div_nonneg - · exact zero_le_one' ℝ - · apply mul_nonneg - · exact kB_nonneg - · exact β.2⟩ - -lemma ofβ_eq : ofβ = fun β => ⟨1 / (kB * β), by + ⟨⟨1 / (kB * β), by + apply div_nonneg + · exact zero_le_one + · apply mul_nonneg + · exact kB_nonneg + · exact β.2⟩⟩ + +lemma ofβ_eq : ofβ = fun β => ⟨⟨1 / (kB * β), by apply div_nonneg - · exact zero_le_one' ℝ + · exact zero_le_one · apply mul_nonneg · exact kB_nonneg - · exact β.2⟩ := by rfl + · exact β.2⟩⟩ := by + rfl @[simp] lemma β_ofβ (β' : ℝ≥0) : β (ofβ β') = β' := by - simp [β, ofβ] ext - change ((↑β' : ℝ)⁻¹ * (↑kB : ℝ)⁻¹)⁻¹ * (kB)⁻¹ = _ + simp [β, ofβ, toReal] field_simp [kB_neq_zero] @[simp] lemma ofβ_β (T : Temperature) : ofβ (β T) = T := by - simp [ofβ_eq, β] - apply Subtype.ext - simp only [val_eq_coe] - field_simp [kB_neq_zero] - rfl + ext + change ((1 : ℝ) / (kB * ((β T : ℝ)))) = (T : ℝ) + have : (β T : ℝ) = (1 : ℝ) / (kB * (T : ℝ)) := rfl + simpa [this] using + show (1 / (kB * (1 / (kB * (T : ℝ))))) = (T : ℝ) from by + field_simp [kB_neq_zero] + +/-! 
### Regularity of `ofβ` -/ + +open Filter Topology lemma ofβ_continuousOn : ContinuousOn (ofβ : ℝ≥0 → Temperature) (Set.Ioi 0) := by rw [ofβ_eq] refine continuousOn_of_forall_continuousAt ?_ - intro x h - have h1 : ContinuousAt (fun x => 1 / (kB * x)) x.1 := by + intro x hx + have h1 : ContinuousAt (fun t : ℝ => 1 / (kB * t)) x.1 := by refine ContinuousAt.div₀ ?_ ?_ ?_ · fun_prop · fun_prop · simp - apply And.intro + constructor · exact kB_neq_zero - · exact ne_of_gt h - rw [@Metric.continuousAt_iff] at ⊢ h1 - intro ε hε - obtain ⟨δ, hδ, h1⟩ := h1 ε hε - refine ⟨δ, hδ, ?_⟩ - intro x h - exact h1 h + · exact ne_of_gt hx + have hℝ : ContinuousAt (fun b : ℝ≥0 => (1 : ℝ) / (kB * (b : ℝ))) x := + h1.comp (continuous_subtype_val.continuousAt) + have hNN : + ContinuousAt (fun b : ℝ≥0 => + (⟨(1 : ℝ) / (kB * (b : ℝ)), + by + have hb : 0 ≤ kB * (b : ℝ) := + mul_nonneg kB_nonneg (by exact_mod_cast (show 0 ≤ b from b.2)) + exact div_nonneg zero_le_one hb⟩ : ℝ≥0)) x := + hℝ.codRestrict (fun b => by + have hb : 0 ≤ kB * (b : ℝ) := + mul_nonneg kB_nonneg (by exact_mod_cast (show 0 ≤ b from b.2)) + exact div_nonneg zero_le_one hb) + have hind : Topology.IsInducing (fun T : Temperature => (T.val : ℝ≥0)) := ⟨rfl⟩ + have : Tendsto (fun b : ℝ≥0 => ofβ b) (𝓝 x) (𝓝 (ofβ x)) := by + simp [hind.nhds_eq_comap, Function.comp, ofβ_eq] + simp_all only [Set.mem_Ioi, one_div, mul_inv_rev, val_eq_coe] + exact hNN + exact this lemma ofβ_differentiableOn : - DifferentiableOn ℝ (fun (x : ℝ) => (ofβ (Real.toNNReal x)).1) (Set.Ioi 0) := by + DifferentiableOn ℝ (fun (x : ℝ) => ((ofβ (Real.toNNReal x)).val : ℝ)) (Set.Ioi 0) := by refine DifferentiableOn.congr (f := fun x => 1 / (kB * x)) ?_ ?_ · refine DifferentiableOn.fun_div ?_ ?_ ?_ · fun_prop · fun_prop · intro x hx - simp only [ne_eq, mul_eq_zero, not_or] - apply And.intro - · exact kB_neq_zero - · exact ne_of_gt hx + have hx0 : x ≠ 0 := ne_of_gt (by simpa using hx) + simp [mul_eq_zero, kB_neq_zero, hx0] · intro x hx - simp [ofβ_eq] - simp at hx - 
left + simp [ofβ_eq] at hx + have hx' : 0 < x := by simpa using hx + simp [ofβ_eq, hx'.le, Real.toNNReal, NNReal.coe_mk, hx'.ne'] + +/-! ### Convergence -/ + +open Filter Topology + +/-- Eventually, `ofβ β` is positive as β → ∞`. -/ +lemma eventually_pos_ofβ : ∀ᶠ b : ℝ≥0 in atTop, ((Temperature.ofβ b : Temperature) : ℝ) > 0 := by + have hge : ∀ᶠ b : ℝ≥0 in atTop, (1 : ℝ≥0) ≤ b := Filter.eventually_ge_atTop 1 + refine hge.mono ?_ + intro b hb + have hbpos : 0 < (b : ℝ) := (zero_lt_one.trans_le hb) + have hden : 0 < kB * (b : ℝ) := mul_pos kB_pos hbpos + have : 0 < (1 : ℝ) / (kB * (b : ℝ)) := one_div_pos.mpr hden + simpa [Temperature.ofβ, one_div, Temperature.toReal] using this + +/-- General helper: for any `a > 0`, we have `1 / (a * b) → 0` as `b → ∞` in `ℝ≥0`. -/ +private lemma tendsto_const_inv_mul_atTop (a : ℝ) (ha : 0 < a) : + Tendsto (fun b : ℝ≥0 => (1 : ℝ) / (a * (b : ℝ))) atTop (𝓝 (0 : ℝ)) := by + refine Metric.tendsto_nhds.2 ?_ + intro ε hε + have hεpos : 0 < ε := hε + let Breal : ℝ := (1 / (a * ε)) + 1 + have hBpos : 0 < Breal := by + have : 0 < (1 / (a * ε)) := by + have : 0 < a * ε := mul_pos ha hεpos + exact one_div_pos.mpr this linarith + let B : ℝ≥0 := ⟨Breal, le_of_lt hBpos⟩ + have h_ev : ∀ᶠ b : ℝ≥0 in atTop, b ≥ B := Filter.eventually_ge_atTop B + refine h_ev.mono ?_ + intro b hb + have hBposR : 0 < (B : ℝ) := hBpos + have hbposR : 0 < (b : ℝ) := by + have hBB : (B : ℝ) ≤ (b : ℝ) := by exact_mod_cast hb + exact lt_of_lt_of_le hBposR hBB + have hb0 : 0 < (a * (b : ℝ)) := mul_pos ha hbposR + have hB0 : 0 < (a * (B : ℝ)) := mul_pos ha hBposR + have hmono : (1 : ℝ) / (a * (b : ℝ)) ≤ (1 : ℝ) / (a * (B : ℝ)) := by + have hBB : (B : ℝ) ≤ (b : ℝ) := by exact_mod_cast hb + have hden_le : (a * (B : ℝ)) ≤ (a * (b : ℝ)) := + mul_le_mul_of_nonneg_left hBB (le_of_lt ha) + simpa [one_div] using one_div_le_one_div_of_le hB0 hden_le + have hB_gt_base : (1 / (a * ε)) < (B : ℝ) := by + simp [B, Breal] + have hden_gt : (1 / ε) < (a * (B : ℝ)) := by + have h' := 
mul_lt_mul_of_pos_left hB_gt_base ha + have hane : a ≠ 0 := ne_of_gt ha + have hx' : a * (ε⁻¹ * a⁻¹) = (1 / ε) := by + have : a * (ε⁻¹ * a⁻¹) = ε⁻¹ := by + simp [mul_comm, mul_left_comm, mul_assoc, hane] + simpa [one_div] using this + simpa [hx'] using h' + have hpos : 0 < (1 / ε) := by simpa [one_div] using inv_pos.mpr hεpos + have hBbound : (1 : ℝ) / (a * (B : ℝ)) < ε := by + have := one_div_lt_one_div_of_lt hpos hden_gt + simpa [one_div, inv_div] using this + set A : ℝ := (1 : ℝ) / (a * (b : ℝ)) with hA + have hA_nonneg : 0 ≤ A := by + have : 0 ≤ a * (b : ℝ) := + mul_nonneg (le_of_lt ha) (by exact_mod_cast (show 0 ≤ b from b.2)) + simpa [hA] using div_nonneg zero_le_one this + have hxlt : A < ε := by + have := lt_of_le_of_lt hmono hBbound + simpa [hA] using this + have hAbs : |A| < ε := by + simpa [abs_of_nonneg hA_nonneg] using hxlt + have hAbs' : |A - 0| < ε := by simpa [sub_zero] using hAbs + have hdist : dist A 0 < ε := by simpa [Real.dist_eq] using hAbs' + simpa [Real.dist_eq, hA, one_div, mul_comm, mul_left_comm, mul_assoc] using hdist + +/-- Core convergence: as β → ∞, `toReal (ofβ β) → 0` in `ℝ`. -/ +lemma tendsto_toReal_ofβ_atTop : + Tendsto (fun b : ℝ≥0 => (Temperature.ofβ b : ℝ)) + atTop (𝓝 (0 : ℝ)) := by + have hform : + (fun b : ℝ≥0 => (Temperature.ofβ b : ℝ)) + = (fun b : ℝ≥0 => (1 : ℝ) / (kB * (b : ℝ))) := by + funext b; simp [Temperature.ofβ, Temperature.toReal] + have hsrc : + Tendsto (fun b : ℝ≥0 => (1 : ℝ) / (kB * (b : ℝ))) atTop (𝓝 (0 : ℝ)) := + tendsto_const_inv_mul_atTop kB kB_pos + simpa [hform] using hsrc + +/-- As β → ∞, T = ofβ β → 0+ in ℝ (within Ioi 0). 
-/ +lemma tendsto_ofβ_atTop : + Tendsto (fun b : ℝ≥0 => (Temperature.ofβ b : ℝ)) + atTop (nhdsWithin 0 (Set.Ioi 0)) := by + have h_to0 := tendsto_toReal_ofβ_atTop + have h_into : + Tendsto (fun b : ℝ≥0 => (Temperature.ofβ b : ℝ)) atTop (𝓟 (Set.Ioi (0 : ℝ))) := + tendsto_principal.2 (by simpa using Temperature.eventually_pos_ofβ) + have : Tendsto (fun b : ℝ≥0 => (Temperature.ofβ b : ℝ)) + atTop ((nhds (0 : ℝ)) ⊓ 𝓟 (Set.Ioi (0 : ℝ))) := + tendsto_inf.2 ⟨h_to0, h_into⟩ + simpa [nhdsWithin] using this end Temperature From 0b11cb258951fe2a4f1dd56e99ad5bb121baa7ee Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Sun, 17 Aug 2025 00:14:52 +0200 Subject: [PATCH 11/15] feat(NeuralNetwork): generalize HN/BM to TwoState with update and EnergySpec abstractions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This file builds a general interface for two-state neural networks, provides three concrete Hopfield-style instances, develops a one-site Gibbs update kernel, and proves convergence of this kernel to a deterministic zero-temperature limit as `β → ∞` (equivalently, `T → 0+`). 
--- .../SpinGlasses/HopfieldNetwork/TwoState.lean | 1419 +++++++++++++++++ 1 file changed, 1419 insertions(+) create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean new file mode 100644 index 000000000..a4caa8dad --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean @@ -0,0 +1,1419 @@ + +import Mathlib.LinearAlgebra.Matrix.Symmetric +import Mathlib.Data.Matrix.Reflection +import Mathlib.Data.Vector.Defs +import Init.Data.Vector.Lemmas +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.aux +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork +import PhysLean.Thermodynamics.Temperature.Basic +import PhysLean.StatisticalMechanics.CanonicalEnsemble.TwoState + +import Mathlib.Probability.Kernel.Invariance +import Mathlib.Analysis.SpecialFunctions.Exp + +import Mathlib.Topology.Basic +import Mathlib.Topology.Instances.Real.Lemmas + +import Mathlib + +/-! +# Two-state Hopfield networks: Gibbs update and zero-temperature limit + +This file builds a general interface for two-state neural networks, provides +three concrete Hopfield-style instances, develops a one-site Gibbs update +kernel, and proves convergence of this kernel to a deterministic zero-temperature +limit as `β → ∞` (equivalently, `T → 0+`). + +## Overview + +- Abstract typeclass `TwoStateNeuralNetwork` for two-valued activations: + it exposes distinguished states `σ_pos`, `σ_neg`, a canonical index `θ0` for the + single threshold parameter, and order data connecting the numeric embedding `m` + of the two states (`m σ_neg < m σ_pos`). + +- Three concrete encodings of (symmetric) Hopfield parameters: + * `SymmetricBinary R U` with σ = R and activations in {-1, 1}. + * `SymmetricSignum R U` with σ = Signum (a type-level two-point type). 
+ * `ZeroOne R U` with σ = R and activations in {0, 1}. + +- A scale parameter, pushed along a ring hom `f`: + * `scale f : ℝ` and its ring-generalization `scaleS f : S`. + These quantify the numeric gap between `σ_pos` and `σ_neg` in the image of `f`. + +- Probabilistic update at positive temperature: + * `logisticProb x := 1 / (1 + exp(-x))` with basic bounds + `logisticProb_nonneg` and `logisticProb_le_one`. + * `probPos f p T s u : ℝ` gives `P(σ_u = σ_pos)` for one Gibbs update, formed from + logisticProb with argument `(scale f) * f(local field) * β(T)`. + +- One-site PMF update (and sweeps): + * `updPos`, `updNeg`: force a single site to `σ_pos`/`σ_neg`. + * `gibbsUpdate f p T s u : PMF State` is the one-site Gibbs update. + * `zeroTempDet p s u` is the deterministic threshold update at `T = 0`. + * `gibbsSweepAux`, `gibbsSweep`: sequential composition over a list of sites. + +- Energy specification abstraction: + * `EnergySpec` and a simplified `EnergySpec'` bundle together a global energy + `E` and a local-field `localField` satisfying: + `f (E p (updPos s u) - E p (updNeg s u)) + = - (scale f) * f (localField p s u)`. + It follows that: + `probPos f p T s u = logisticProb (- f(ΔE) * β(T))` + where `ΔE := E p (updPos s u) - E p (updNeg s u)`. + +- Zero-temperature limit: + * `scale_pos f` shows `scale f > 0` for an injective order-embedding `f`. + * `zeroTempLimitPMF p s u` is the limiting one-step kernel at `T = 0+`, which: + - moves deterministically to `updPos` if local field is positive, + - to `updNeg` if negative, + - and splits 1/2–1/2 between them on a tie. + * Main theorem: + `gibbs_update_tends_to_zero_temp_limit f hf p s u` + states the PMF `gibbsUpdate f p (tempOfBeta b) s u` converges (as `b → ∞`) + to `zeroTempLimitPMF p s u`, pointwise on states. + +## Key definitions and lemmas + +- Interfaces and instances: + * `class TwoStateNeuralNetwork NN`: + exposes `σ_pos`, `σ_neg`, `θ0`, and ordering data `m σ_neg < m σ_pos`. 
+ * Instances: + - `instTwoStateSymmetricBinary` for `SymmetricBinary`, + - `instTwoStateSignum` for `SymmetricSignum`, + - `instTwoStateZeroOne` for `ZeroOne`. + +- Scaling gadgets: + * `scale f : ℝ`, `scaleS f : S`: + gap between the numeric images of `σ_pos` and `σ_neg`. + Specializations: + - `scale_binary f = f 2` on `SymmetricBinary`, + - `scale_zeroOne f = f 1` on `ZeroOne`. + +- Gibbs probability and PMFs: + * `logisticProb` with bounds: + `logisticProb_nonneg`, `logisticProb_le_one`. + * `probPos f p T s u`: + `= logisticProb ((scale f) * f(localField) * β(T))`. + * `updPos`, `updNeg`: single-site state updates. + * `gibbsUpdate f p T s u : PMF State`: one-site Gibbs step. + * `gibbsSweepAux`, `gibbsSweep`: sequential composition over a list of sites. + +- Energy view: + * `EnergySpec`, `EnergySpec'` and the fundamental relation + `f (E p (updPos s u) - E p (updNeg s u)) + = - (scale f) * f (localField p s u)`. + * `EnergySpec.probPos_eq_of_energy`: + re-express `probPos` via the energy difference `ΔE`. + +- Convergence to zero temperature: + * `scale_pos f`: positivity of `scale f` under an injective order embedding. + * `zeroTempLimitPMF p s u : PMF State`: the `T → 0+` limit kernel. + * Pointwise convergence on the only two reachable states: + - `gibbs_update_tends_to_zero_temp_limit_apply_updPos` + - `gibbs_update_tends_to_zero_temp_limit_apply_updNeg` + and zero limit on any other target state + - `gibbs_update_tends_to_zero_temp_limit_apply_other`. + * Main PMF convergence: + `gibbs_update_tends_to_zero_temp_limit f hf p s u`. + +## Typeclass and notation prerequisites + +- Base field and order: + `[Field R] [LinearOrder R] [IsStrictOrderedRing R]`. +- Sites: + `[DecidableEq U] [Fintype U] [Nonempty U]` for finite networks and + decidable equality on sites. 
+- Embeddings for the scale and probabilities: + - For `scale` and `probPos` we use a typeclass-driven `f` that is simultaneously + a ring hom (`[RingHomClass F R ℝ]`) and (when needed for convergence) + an order embedding (`[OrderHomClass F R ℝ]`) with `Function.Injective f`. + +- Probability monad: + uses `PMF` from Mathlib. The constructions are purely discrete. + +## Design notes + +- The class `TwoStateNeuralNetwork` abstracts the parts of a network used by a + two-site threshold update: a canonical scalar threshold index `θ0`, the result + of `fact` at or above threshold (`σ_pos`) and strictly below (`σ_neg`), and + numeric ordering `m σ_neg < m σ_pos`. +- The concrete Hopfield instances share the same adjacency and `κ2 = 1` setup, + but differ in the activation alphabet and decoding map `m`. +- The scaling `scale f` is a thin adapter that makes formulas uniform across + different encodings, so that Gibbs updates depending on `f(local field)` and + `β(T)` can be stated once for all `NN`. + +## Usage + +- To run one Gibbs update at site `u`: + ``` + gibbsUpdate f p T s u : PMF _ + ``` +- To sweep a list of sites in order: + ``` + gibbsSweep order p T f s0 : PMF _ + ``` +- If you can provide an `EnergySpec`, then: + ``` + probPos f p T s u = logisticProb (- f(ΔE) * β(T)) + ``` + where `ΔE = E p (updPos s u) - E p (updNeg s u)`. + +- The zero-temperature limit theorem applies once you supply an `f` that is both + a ring hom and an injective order embedding (via the corresponding typeclasses). + +## TODO + +- Finish the (commented) section proving reversibility and invariance of the + random-scan Gibbs kernel w.r.t. the Boltzmann distribution, after adding a + finite enumeration of states and the necessary summation lemmas. +- Provide convenience embeddings `f : R →+* ℝ` for common `R` (e.g. `R = ℝ`). 
+ +-/ + +open Finset Matrix NeuralNetwork State Constants Temperature Filter Topology +open scoped ENNReal NNReal BigOperators +open NeuralNetwork + +--variable {R U σ : Type} +--variable {R U σ : Type*} +universe uR uU uσ + +-- (Optional) you can also parametrize earlier variables with these universes if desired: +variable {R : Type uR} {U : Type uU} {σ : Type uσ} + +/-- A minimal two-point activation alphabet. + +This class specifies: +- `σ_pos`: the distinguished “positive” activation, +- `σ_neg`: the distinguished “negative” activation, +- `embed`: a numeric embedding `σ → R` used to interpret activations in the ambient ring `R`. +-/ +class TwoPointActivation (R : Type uR) (σ : Type uσ) where + /-- Distinguished “positive” activation state. -/ + σ_pos : σ + /-- Distinguished “negative” activation state. -/ + σ_neg : σ + /-- Numeric embedding of activation states into the ambient ring `R`. -/ + embed : σ → R + +/-- Scale between the two distinguished activations in `σ`, computed in `R` via the embedding. +It is defined as `embed σ_pos - embed σ_neg`. -/ +@[simp] def twoPointScale {R σ} [Sub R] [TwoPointActivation R σ] : R := + TwoPointActivation.embed (R:=R) (σ:=σ) (TwoPointActivation.σ_pos (R:=R) (σ:=σ)) - + TwoPointActivation.embed (R:=R) (σ:=σ) (TwoPointActivation.σ_neg (R:=R) (σ:=σ)) + +/-- Two–state neural networks (abstract interface). + +This exposes: +- `σ_pos`, `σ_neg`: the two distinguished activation states, +- `θ0`: a canonical index into the `κ2`-vector of thresholds extracting the scalar threshold, +- facts that `fact` returns `σ_pos` at or above `θ0`, and `σ_neg` strictly below, +- that both `σ_pos` and `σ_neg` satisfy `pact`, +- and an order gap on the numeric embedding `m` (`m σ_neg < m σ_pos`). -/ +class TwoStateNeuralNetwork {R U σ} + [Field R] [LinearOrder R] [IsStrictOrderedRing R] + (NN : NeuralNetwork R U σ) where + /-- Distinguished “positive” activation state. -/ + σ_pos : σ + /-- Distinguished “negative” activation state. 
-/ + σ_neg : σ + /-- Proof that the two distinguished activation states are distinct. -/ + h_pos_ne_neg : σ_pos ≠ σ_neg + /-- Canonical index in `κ2 u` selecting the scalar threshold used by `fact`. -/ + θ0 : ∀ u : U, Fin (NN.κ2 u) + /-- At or above threshold `θ0 u`, `fact` returns `σ_pos`. -/ + h_fact_pos : + ∀ u (σcur : σ) (net : R) (θ : Vector R (NN.κ2 u)), + (θ.get (θ0 u)) ≤ net → NN.fact u σcur net θ = σ_pos + /-- Strictly below threshold `θ0 u`, `fact` returns `σ_neg`. -/ + h_fact_neg : + ∀ u (σcur : σ) (net : R) (θ : Vector R (NN.κ2 u)), + net < (θ.get (θ0 u)) → NN.fact u σcur net θ = σ_neg + /-- `σ_pos` satisfies the activation predicate `pact`. -/ + h_pact_pos : NN.pact σ_pos + /-- `σ_neg` satisfies the activation predicate `pact`. -/ + h_pact_neg : NN.pact σ_neg + /-- Numeric embedding separates the two states: `m σ_neg < m σ_pos`. -/ + m_order : NN.m σ_neg < NN.m σ_pos + +namespace TwoState +variable {R U σ : Type} +variable [Field R] [LinearOrder R] [IsStrictOrderedRing R] +--variable [DecidableEq U] [Fintype U] [Nonempty U] + +/-! Concrete network families (three encodings). -/ + +/-- Helper canonical index for all concrete networks with κ2 = 1. -/ +@[inline] def fin0 : Fin 1 := ⟨0, by decide⟩ + +/-- Standard symmetric Hopfield parameters with activations in {-1,1} (σ = R). 
-/ +def SymmetricBinary (R U : Type) [Field R] [LinearOrder R] + [DecidableEq U] [Fintype U] [Nonempty U] : NeuralNetwork R U R := +{ Adj := fun u v => u ≠ v + Ui := Set.univ + Uo := Set.univ + Uh := ∅ + hUi := by simp + hUo := by simp + hU := by simp + hhio := by simp + κ1 := fun _ => 1 + κ2 := fun _ => 1 + pw := fun W => W.IsSymm ∧ ∀ u, W u u = 0 + pact := fun a => a = 1 ∨ a = -1 + fnet := fun u row pred _ => ∑ v, if v ≠ u then row v * pred v else 0 + fact := fun _ _ net θ => if θ.get fin0 ≤ net then 1 else -1 + fout := fun _ a => a + m := id + hpact := by + classical + intro W _ _ _ θ cur hcur u + by_cases hth : + (θ u).get fin0 ≤ ∑ v, if v ≠ u then W u v * cur v else 0 + · simp [pact, fact, fnet, hth]; aesop + · simp [pact, fact, fnet, hth]; aesop } + +/-- Type level two-value signum variant. -/ +inductive Signum | pos | neg deriving DecidableEq + +instance : Fintype Signum where + elems := {Signum.pos, Signum.neg} + complete := by intro x; cases x <;> simp + +/-- Symmetric Hopfield parameters with σ = Signum. -/ +def SymmetricSignum (R U : Type) [Field R] [LinearOrder R] + [DecidableEq U] [Fintype U] [Nonempty U] : NeuralNetwork R U Signum := +{ Adj := fun u v => u ≠ v + Ui := Set.univ + Uo := Set.univ + Uh := ∅ + hUi := by simp + hUo := by simp + hU := by simp + hhio := by simp + κ1 := fun _ => 1 + κ2 := fun _ => 1 + pw := fun W => W.IsSymm ∧ ∀ u, W u u = 0 + fnet := fun u row pred _ => ∑ v, if v ≠ u then row v * pred v else 0 + pact := fun _ => True + fact := fun _ _ net θ => if θ.get fin0 ≤ net then Signum.pos else Signum.neg + fout := fun _ s => match s with | .pos => (1 : R) | .neg => (-1 : R) + m := fun s => match s with | .pos => (1 : R) | .neg => (-1 : R) + hpact := by intro; simp } + +/-- Zero / one network (σ ∈ {0,1}). 
-/ +def ZeroOne (R U : Type) [Field R] [LinearOrder R] + [DecidableEq U] [Fintype U] [Nonempty U] : NeuralNetwork R U R := +{ (SymmetricBinary R U) with + pact := fun a => a = 0 ∨ a = 1 + fact := fun _ _ net θ => if θ.get fin0 ≤ net then 1 else 0 + hpact := by + classical + intro W _ _ σ θ cur hcur u + by_cases hth : + (θ u).get fin0 ≤ ∑ v, if v ≠ u then W u v * cur v else 0 + · simp [SymmetricBinary, pact, fact, hth]; aesop + · simp [SymmetricBinary, pact, fact, hth]; aesop } + +variable [DecidableEq U] [Fintype U] [Nonempty U] + +instance instTwoStateSymmetricBinary : + TwoStateNeuralNetwork (SymmetricBinary R U) where + σ_pos := (1 : R); σ_neg := (-1 : R) + h_pos_ne_neg := by + have h0 : (0 : R) < 1 := zero_lt_one + have hneg : (-1 : R) < 0 := by simp + have hlt : (-1 : R) < 1 := hneg.trans h0 + exact (ne_of_lt hlt).symm + θ0 := fun _ => fin0 + h_fact_pos := by + intro u σcur net θ hle; simp [SymmetricBinary, hle] + h_fact_neg := by + intro u σcur net θ hlt + have : ¬ θ.get fin0 ≤ net := not_le.mpr hlt + simp [SymmetricBinary, this] + h_pact_pos := by left; rfl + h_pact_neg := by right; rfl + m_order := by + have h0 : (0 : R) < 1 := zero_lt_one + have hneg : (-1 : R) < 0 := by simp + exact hneg.trans h0 + +instance instTwoStateSignum : + TwoStateNeuralNetwork (SymmetricSignum R U) where + σ_pos := Signum.pos; σ_neg := Signum.neg + h_pos_ne_neg := by intro h; cases h + θ0 := fun _ => fin0 + h_fact_pos := by + intro u σcur net θ hn + change (if θ.get fin0 ≤ net then Signum.pos else Signum.neg) = Signum.pos + simp [hn] + h_fact_neg := by + intro u σcur net θ hlt + change (if θ.get fin0 ≤ net then Signum.pos else Signum.neg) = Signum.neg + have : ¬ θ.get fin0 ≤ net := not_le.mpr hlt + simp [this] + h_pact_pos := by trivial + h_pact_neg := by trivial + m_order := by + -- `m` maps pos ↦ 1, neg ↦ -1 + have h0 : (0 : R) < 1 := zero_lt_one + have hneg : (-1 : R) < 0 := by simp + simp [SymmetricSignum] + +instance instTwoStateZeroOne : + TwoStateNeuralNetwork (ZeroOne R 
U) where + σ_pos := (1 : R); σ_neg := (0 : R) + h_pos_ne_neg := one_ne_zero + θ0 := fun _ => fin0 + h_fact_pos := by + intro u σcur net θ hn + change (if θ.get fin0 ≤ net then (1 : R) else 0) = 1 + simp [hn] + h_fact_neg := by + intro u σcur net θ hlt + change (if θ.get fin0 ≤ net then (1 : R) else 0) = 0 + have : ¬ θ.get fin0 ≤ net := not_le.mpr hlt + simp [this] + h_pact_pos := by right; rfl + h_pact_neg := by left; rfl + m_order := by + -- m = id, so goal is 0 < 1 + simp [ZeroOne, SymmetricBinary] + +/-- Scale between numeric embeddings of the two states (pushed along f). -/ +noncomputable def scale + {F} [FunLike F R ℝ] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (f : F) : ℝ := + f (NN.m (TwoStateNeuralNetwork.σ_pos (NN:=NN))) - + f (NN.m (TwoStateNeuralNetwork.σ_neg (NN:=NN))) + +/-- Generalized scale in an arbitrary target ring S. -/ +noncomputable def scaleS + {S} [Ring S] {F} [FunLike F R S] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] (f : F) : S := + f (NN.m (TwoStateNeuralNetwork.σ_pos (NN:=NN))) - + f (NN.m (TwoStateNeuralNetwork.σ_neg (NN:=NN))) + +omit [DecidableEq U] [Fintype U] [Nonempty U] in +@[simp] lemma scaleS_apply_ℝ + {F} [FunLike F R ℝ] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] (f : F) : + scaleS (NN:=NN) (f:=f) = scale (NN:=NN) (f:=f) := rfl + +@[simp] lemma scale_binary (f : R →+* ℝ) : + scale (R:=R) (U:=U) (σ:=R) (NN:=SymmetricBinary R U) (f:=f) = f 2 := by + -- σ_pos = 1, σ_neg = -1, m = id + unfold scale + simp [instTwoStateSymmetricBinary, SymmetricBinary, sub_neg_eq_add, one_add_one_eq_two] + rw [@map_ofNat] + +@[simp] lemma scale_zeroOne (f : R →+* ℝ) : + scale (R:=R) (U:=U) (σ:=R) (NN:=ZeroOne R U) (f:=f) = f 1 := by + -- σ_pos = 1, σ_neg = 0, m = id + unfold scale + simp [instTwoStateZeroOne, ZeroOne, SymmetricBinary] + +/-- Logistic function used for Gibbs probabilities. 
-/
+noncomputable def logisticProb (x : ℝ) : ℝ := 1 / (1 + Real.exp (-x))
+
+/-- Lower bound: the logistic function is nonnegative (its denominator is positive). -/
+lemma logisticProb_nonneg (x : ℝ) : 0 ≤ logisticProb x := by
+  unfold logisticProb
+  -- the denominator `1 + exp (-x)` is strictly positive, so the quotient is ≥ 0
+  have hx : 0 < Real.exp (-x) := Real.exp_pos _
+  have hden : 0 < 1 + Real.exp (-x) := by linarith
+  exact div_nonneg zero_le_one hden.le
+
+/-- Upper bound: the logistic function is at most `1` (numerator ≤ denominator). -/
+lemma logisticProb_le_one (x : ℝ) : logisticProb x ≤ 1 := by
+  unfold logisticProb
+  have hx : 0 < Real.exp (-x) := Real.exp_pos _
+  have hden : 0 < 1 + Real.exp (-x) := by linarith
+  -- the numerator `1` is at most the denominator, hence the quotient is ≤ 1
+  have : 1 ≤ 1 + Real.exp (-x) := by
+    linarith
+  simpa using (div_le_one hden).mpr this
+
+/-- Probability P(σ_u = σ_pos) for one Gibbs update.
+Computed as `logisticProb (κ * f L * β T)`, where `L = s.net p u - θ` is the
+local field at `u` and `κ = scale f` is the numeric gap between the two
+activation values pushed along `f`. -/
+noncomputable def probPos
+    {F} [FunLike F R ℝ]
+    {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN]
+    (f : F) (p : Params NN) (T : Temperature) (s : NN.State) (u : U) : ℝ :=
+  let L := (s.net p u) - (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u)
+  let κ := scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f)
+  logisticProb (κ * (f L) * (β T))
+
+omit [DecidableEq U] [Fintype U] [Nonempty U] in
+lemma probPos_nonneg
+    {F} [FunLike F R ℝ]
+    {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN]
+    (f : F) (p : Params NN) (T : Temperature) (s : NN.State) (u : U) :
+    0 ≤ probPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) f p T s u := by
+  unfold probPos; apply logisticProb_nonneg
+
+omit [DecidableEq U] [Fintype U] [Nonempty U] in
+lemma probPos_le_one
+    {F} [FunLike F R ℝ]
+    {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN]
+    (f : F) (p : Params NN) (T : Temperature) (s : NN.State) (u : U) :
+    probPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) f p T s u ≤ 1 := by
+  unfold probPos; apply logisticProb_le_one
+
+/-- Force neuron u to σ_pos. 
-/
+def updPos {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN]
+    (s : NN.State) (u : U) : NN.State :=
+{ act := Function.update s.act u (TwoStateNeuralNetwork.σ_pos (NN:=NN))
+, hp := by
+  -- well-formedness: site `u` satisfies `pact` by `h_pact_pos`;
+  -- every other site keeps the proof it had in `s`
+  intro v
+  by_cases h : v = u
+  · subst h; simpa using TwoStateNeuralNetwork.h_pact_pos (NN:=NN)
+  · simpa [Function.update, h] using s.hp v }
+
+/-- Force neuron u to σ_neg. -/
+def updNeg {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN]
+    (s : NN.State) (u : U) : NN.State :=
+{ act := Function.update s.act u (TwoStateNeuralNetwork.σ_neg (NN:=NN))
+, hp := by
+  -- symmetric to `updPos`, using `h_pact_neg` at the flipped site
+  intro v
+  by_cases h : v = u
+  · subst h; simpa using TwoStateNeuralNetwork.h_pact_neg (NN:=NN)
+  · simpa [Function.update, h] using s.hp v }
+
+/-- One–site Gibbs update kernel (PMF). -/
+noncomputable def gibbsUpdate
+    {F} [FunLike F R ℝ]
+    {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN]
+    (f : F) (p : Params NN) (T : Temperature) (s : NN.State) (u : U) :
+    PMF (NN.State) := by
+  classical
+  -- success probability of the coin, coerced into `ℝ≥0∞`
+  let pPos := probPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) f p T s u
+  let pPosE := ENNReal.ofReal pPos
+  have h_le : pPosE ≤ 1 := by
+    rw [ENNReal.ofReal_le_one]
+    exact probPos_le_one (R:=R) (U:=U) (σ:=σ) (NN:=NN) f p T s u
+  -- draw `b ~ Bernoulli(pPos)`: `true` forces `σ_pos` at `u`, `false` forces `σ_neg`
+  exact
+    PMF.bernoulli pPosE h_le >>= fun b =>
+      if b then PMF.pure (updPos (s:=s) (u:=u)) else PMF.pure (updNeg (s:=s) (u:=u))
+
+/-- Zero–temperature deterministic (threshold) update at site u.
+    (Adjusted to avoid unused variable warning in the Prop-based `if`.) -/
+def zeroTempDet
+    {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN]
+    (p : Params NN) (s : NN.State) (u : U) : NN.State :=
+  let net := s.net p u
+  let θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u)
+  -- at or above threshold pick `σ_pos`, strictly below pick `σ_neg`
+  if h : θ ≤ net then
+    (have _ := h; updPos (s:=s) (u:=u))
+  else
+    (have _ := h; updNeg (s:=s) (u:=u))
+
+/-- Gibbs sweep auxiliary function. 
-/ +noncomputable def gibbsSweepAux + {F} [FunLike F R ℝ] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (f : F) (p : Params NN) (T : Temperature) : + List U → NN.State → PMF NN.State + | [], s => PMF.pure s + | u :: us, s => + gibbsUpdate (NN:=NN) f p T s u >>= fun s' => + gibbsSweepAux f p T us s' + +omit [Fintype U] [Nonempty U] in +@[simp] lemma gibbsSweepAux_nil + {F} [FunLike F R ℝ] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (f : F) (p : Params NN) (T : Temperature) (s : NN.State) : + gibbsSweepAux (NN:=NN) f p T [] s = PMF.pure s := rfl + +omit [Fintype U] [Nonempty U] in +@[simp] lemma gibbsSweepAux_cons + {F} [FunLike F R ℝ] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (f : F) (p : Params NN) (T : Temperature) (u : U) (us : List U) (s : NN.State) : + gibbsSweepAux (NN:=NN) f p T (u :: us) s = + gibbsUpdate (NN:=NN) f p T s u >>= fun s' => + gibbsSweepAux (NN:=NN) f p T us s' := rfl + +/-- Sequential Gibbs sweep over a list of sites, head applied first. 
-/ +noncomputable def gibbsSweep + {F} [FunLike F R ℝ] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (order : List U) (p : Params NN) (T : Temperature) (f : F) + (s0 : NN.State) : PMF NN.State := + gibbsSweepAux (NN:=NN) f p T order s0 + +omit [Fintype U] [Nonempty U] in +@[simp] lemma gibbsSweep_nil + {F} [FunLike F R ℝ] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (p : Params NN) (T : Temperature) (f : F) (s0 : NN.State) : + gibbsSweep (NN:=NN) ([] : List U) p T f s0 = PMF.pure s0 := rfl + +omit [Fintype U] [Nonempty U] in +lemma gibbsSweep_cons + {F} [FunLike F R ℝ] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (u : U) (us : List U) (p : Params NN) (T : Temperature) (f : F) (s0 : NN.State) : + gibbsSweep (NN:=NN) (u :: us) p T f s0 = + (gibbsUpdate (NN:=NN) f p T s0 u) >>= fun s => + gibbsSweep (NN:=NN) us p T f s := rfl + +@[simp] lemma probPos_nonneg_apply_binary + (f : R →+* ℝ) (p : Params (SymmetricBinary R U)) (T : Temperature) + (s : (SymmetricBinary R U).State) (u : U) : + 0 ≤ probPos (R:=R) (U:=U) (σ:=R) (NN:=SymmetricBinary R U) f p T s u := + probPos_nonneg (R:=R) (U:=U) (σ:=R) (NN:=SymmetricBinary R U) f p T s u + +@[simp] lemma probPos_le_one_apply_binary + (f : R →+* ℝ) (p : Params (SymmetricBinary R U)) (T : Temperature) + (s : (SymmetricBinary R U).State) (u : U) : + probPos (R:=R) (U:=U) (σ:=R) (NN:=SymmetricBinary R U) f p T s u ≤ 1 := + probPos_le_one (R:=R) (U:=U) (σ:=R) (NN:=SymmetricBinary R U) f p T s u + +/-- **Energy specification bundling a global energy and a local field**. + +This abstracts the thermodynamic view: +- `E p s` is the global energy of state `s` under parameters `p`; +- `localField p s u` is the local field at site `u` in state `s`. +- The specification `localField_spec` connects the local field to the + network primitives, and `flip_energy_relation` states the fundamental + relation between energy differences and local fields. 
+Together these properties connect energy differences +to local fields and underpin the Gibbs/zero–temperature analysis. -/ +structure EnergySpec + {R U σ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] + (NN : NeuralNetwork R U σ) [TwoStateNeuralNetwork NN] where + /-- Global energy function `E p s`. -/ + E : Params NN → NN.State → R + /-- Local field `L = localField p s u` at site `u`. -/ + localField : Params NN → NN.State → U → R + /-- Specification tying the abstract `localField` to the network primitives: + `localField p s u = s.net p u - (p.θ u).get (θ0 u)`. -/ + localField_spec : + ∀ (p : Params NN) (s : NN.State) (u : U), + localField p s u = + (s.net p u) - (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + /-- Fundamental flip–energy relation: the energy difference between the + `updPos` and `updNeg` flips at `u` equals `- scale f * f(localField p s u)`, + when pushed along a ring hom `f : R →+* ℝ`. -/ + flip_energy_relation : + ∀ (f : R →+* ℝ) + (p : Params NN) (s : NN.State) (u : U), + let sPos := updPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) (s:=s) (u:=u) + let sNeg := updNeg (R:=R) (U:=U) (σ:=σ) (NN:=NN) (s:=s) (u:=u) + f (E p sPos - E p sNeg) = + - (scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f)) * + f (localField p s u) + +/-- A simplified energy specification carrying the same data as `EnergySpec`, +with the flip relation stated using inlined `updPos`/`updNeg`. -/ +structure EnergySpec' + {R U σ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [DecidableEq U] + (NN : NeuralNetwork R U σ) [TwoStateNeuralNetwork NN] where + /-- Global energy function `E p s`. -/ + E : Params NN → NN.State → R + /-- Local field `L = localField p s u` at site `u`. -/ + localField : Params NN → NN.State → U → R + /-- Specification tying `localField` to the network primitives: + `localField p s u = s.net p u - (p.θ u).get (θ0 u)`. 
-/ + localField_spec : + ∀ (p : Params NN) (s : NN.State) (u : U), + localField p s u = + (s.net p u) - (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + /-- Fundamental flip–energy relation pushed along a ring hom `f : R →+* ℝ`: + `f (E p (updPos s u) - E p (updNeg s u)) + = - scale f * f (localField p s u)`. -/ + flip_energy_relation : + ∀ (f : R →+* ℝ) + (p : Params NN) (s : NN.State) (u : U), + f (E p (updPos (NN:=NN) s u) - E p (updNeg (NN:=NN) s u)) = + - (scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f)) * + f (localField p s u) + +namespace EnergySpec +variable {NN : NeuralNetwork R U σ} +variable [TwoStateNeuralNetwork NN] + +omit [Fintype U] [Nonempty U] in +lemma flip_energy_rel' + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] + (ES : TwoState.EnergySpec (NN:=NN)) (f : F) + (p : Params NN) (s : NN.State) (u : U) : + f (ES.E p (updPos (NN:=NN) s u) - ES.E p (updNeg (NN:=NN) s u)) = + - (scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f)) * + f ((s.net p u) - (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u)) := by + -- we build a bundled ring hom from f; its coercion is definitionally f + let f_hom : R →+* ℝ := + { toMonoidHom := + { toFun := f + map_one' := map_one f + map_mul' := map_mul f } + map_zero' := map_zero f + map_add' := map_add f } + have h := ES.flip_energy_relation f_hom p s u + simpa [ES.localField_spec] using h + +omit [Fintype U] [Nonempty U] in +lemma probPos_eq_of_energy + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] + (ES : EnergySpec (NN:=NN)) (f : F) (p : Params NN) (T : Temperature) + (s : NN.State) (u : U) : + probPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) f p T s u = + let sPos := updPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) (s:=s) (u:=u) + let sNeg := updNeg (R:=R) (U:=U) (σ:=σ) (NN:=NN) (s:=s) (u:=u) + let Δ := f (ES.E p sPos - ES.E p sNeg) + logisticProb (- Δ * (β T)) := by + classical + unfold probPos + have hΔ := + ES.flip_energy_rel' (f:=f) p s u + simp [ES.localField_spec, hΔ, logisticProb, + mul_comm, mul_left_comm, mul_assoc, + add_comm, add_left_comm, 
add_assoc] + +end EnergySpec + +namespace EnergySpec' +variable {NN : NeuralNetwork R U σ} +variable [TwoStateNeuralNetwork NN] + +/-- Convert an `EnergySpec` to an `EnergySpec'`. -/ +def ofOld + (ES : TwoState.EnergySpec (NN:=NN)) : EnergySpec' (NN:=NN) := +{ E := ES.E +, localField := ES.localField +, localField_spec := ES.localField_spec +, flip_energy_relation := by + intro f p s u + simpa using (ES.flip_energy_relation f p s u) } + +omit [Fintype U] [Nonempty U] in +lemma flip_energy_rel' + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] + (ES : EnergySpec' (NN:=NN)) (f : F) + (p : Params NN) (s : NN.State) (u : U) : + f (ES.E p (updPos (NN:=NN) s u) - ES.E p (updNeg (NN:=NN) s u)) = + - (scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f)) * + f ((s.net p u) - (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u)) := by + -- Build a bundled ring hom from f; its coercion is definitionally f + let f_hom : R →+* ℝ := + { toMonoidHom := + { toFun := f + map_one' := map_one f + map_mul' := map_mul f } + map_zero' := map_zero f + map_add' := map_add f } + have h := ES.flip_energy_relation f_hom p s u + simpa [ES.localField_spec] using h + +end EnergySpec' + +omit [Fintype U] [Nonempty U] in +lemma EnergySpec.flip_energy_rel'' + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (ES : TwoState.EnergySpec (NN:=NN)) + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] (f : F) + (p : Params NN) (s : NN.State) (u : U) : + f (ES.E p (updPos (NN:=NN) s u) - ES.E p (updNeg (NN:=NN) s u)) = + - (scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f)) * + f ((s.net p u) - (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u)) := by + classical + refine (EnergySpec'.flip_energy_rel' + (EnergySpec'.ofOld (NN:=NN) ES) (f:=f) p s u) + +/-! ### Convergence +As β → ∞ (i.e., T → 0+), the one–site Gibbs update PMF converges pointwise to the +zero–temperature limit kernel, for any TwoState NN and any order-embedding f. 
-/ + +open scoped Topology Filter +open PMF Filter + +section Convergence +variable {R U σ : Type} +variable [Field R] [LinearOrder R] [IsStrictOrderedRing R] +variable [DecidableEq U] [Fintype U] [Nonempty U] + +variable {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + +omit [Field R] [IsStrictOrderedRing R] in +-- strict monotonicity from order-hom + injective +lemma strictMono_of_injective_orderHom + {F} [FunLike F R ℝ] [OrderHomClass F R ℝ] + (f : F) (hf : Function.Injective f) : StrictMono f := + (OrderHomClass.mono f).strictMono_of_injective hf + +/-- If `a > 0`, the piecewise {1, 0, 1/2} based on the sign of `a*v` matches that of `v`. -/ +lemma piecewise01half_sign_mul_left_pos {a v : ℝ} (ha : 0 < a) : + (if 0 < a * v then 1 else if a * v < 0 then 0 else (1/2 : ℝ)) + = + (if 0 < v then 1 else if v < 0 then 0 else (1/2 : ℝ)) := by + by_cases hvpos : 0 < v + · have : 0 < a * v := Left.mul_pos ha hvpos + simp [hvpos, this, not_lt.mpr this.le] + · by_cases hvneg : v < 0 + · have : a * v < 0 := mul_neg_of_pos_of_neg ha hvneg + simp [hvpos, hvneg, this, not_lt.mpr this.le] + · have hv0 : v = 0 := le_antisymm (le_of_not_gt hvpos) (le_of_not_gt hvneg) + simp [hvpos, hvneg, hv0] + +/-- If `x < 0` then the {1,0,1/2} piecewise is ≤ 0 (it equals 0). -/ +private lemma piecewise01half_le_zero_of_neg {x : ℝ} (hx : x < 0) : + (if 0 < x then 1 else if x < 0 then 0 else (1/2 : ℝ)) ≤ 0 := by + have hnotpos : ¬ 0 < x := not_lt.mpr hx.le + simp [hnotpos, hx] + +omit [Fintype U] [Nonempty U] [DecidableEq U] in +/-- Rewrite the 1/0/1/2 piecewise expression by dropping the positive `1/kB` factor +in front of its argument. This aligns the “zero-temperature target” with the +real-limit expression using `c0 := κ * f L`. 
-/ +lemma sign_piecewise_rewrite_with_c0 + {F} [FunLike F R ℝ] + (f : F) + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (L : R) : + let κ := scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f) + (if 0 < ((κ / kB) * (f L)) then 1 + else if ((κ / kB) * (f L)) < 0 then 0 else (1/2 : ℝ)) + = + (if 0 < (κ * (f L)) then 1 + else if (κ * (f L)) < 0 then 0 else (1/2 : ℝ)) := by + intro κ + have ha : 0 < (kB : ℝ)⁻¹ := inv_pos.mpr kB_pos + have hrewrite : + ((κ / kB) * (f L)) = (kB : ℝ)⁻¹ * (κ * (f L)) := by + simp [div_eq_mul_inv, mul_left_comm, mul_assoc] + have h := piecewise01half_sign_mul_left_pos + (a := (kB : ℝ)⁻¹) (v := κ * (f L)) ha + simpa [hrewrite] using h + +omit [IsStrictOrderedRing R] in +/-- Map of a positive (resp. negative) argument remains positive (resp. negative) +under a strictly monotone embedding sending `0 ↦ 0`. -/ +lemma map_pos_of_pos + {F} [FunLike F R ℝ] [OrderHomClass F R ℝ] [RingHomClass F R ℝ] + (f : F) (hf : Function.Injective f) {x : R} (hx : 0 < x) : 0 < f x := by + have hsm := strictMono_of_injective_orderHom (R:=R) (f:=f) hf + simpa [map_zero f] using (hsm hx) + +omit [IsStrictOrderedRing R] in +lemma map_neg_of_neg + {F} [FunLike F R ℝ] [OrderHomClass F R ℝ] [RingHomClass F R ℝ] + (f : F) (hf : Function.Injective f) {x : R} (hx : x < 0) : f x < 0 := by + have hsm := strictMono_of_injective_orderHom (R:=R) (f:=f) hf + simpa [map_zero f] using (hsm hx) + +omit [DecidableEq U] [Fintype U] [Nonempty U] in +/-- κ := scale between the two numeric states (pushed along f) is positive + under a strictly monotone embedding and the class `m_order`. 
-/ +lemma scale_pos + {F} [FunLike F R ℝ] [OrderHomClass F R ℝ] + (f : F) (hf : Function.Injective f) + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] : + 0 < scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f) := by + -- TwoStateNeuralNetwork.m_order : m σ_neg < m σ_pos + have h := (strictMono_of_injective_orderHom (R:=R) (f:=f) hf) + have himg : + f (NN.m (TwoStateNeuralNetwork.σ_neg (NN:=NN))) < + f (NN.m (TwoStateNeuralNetwork.σ_pos (NN:=NN))) := + h (TwoStateNeuralNetwork.m_order (NN:=NN)) + exact sub_pos.mpr himg + +/-- One-step zero-temperature limit kernel (tie -> 1/2 mixture of updPos/updNeg). -/ +noncomputable def zeroTempLimitPMF + (p : Params NN) (s : NN.State) (u : U) : PMF NN.State := + let net := s.net p u + let θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + if h : θ < net then + PMF.pure (updPos (s:=s) (u:=u)) + else if h' : net < θ then + PMF.pure (updNeg (s:=s) (u:=u)) + else + let pHalf : ℝ≥0∞ := ENNReal.ofReal (1/2) + have hp : pHalf ≤ 1 := by + simpa [pHalf] using + (ENNReal.ofReal_le_one.mpr (by norm_num : (1/2 : ℝ) ≤ 1)) + PMF.bernoulli pHalf hp >>= fun b => + if b then PMF.pure (updPos (s:=s) (u:=u)) else PMF.pure (updNeg (s:=s) (u:=u)) + +omit [Fintype U] [Nonempty U] in +/-- Helper: the two updated states differ. -/ +private lemma updPos_ne_updNeg (s : NN.State) (u : U) : + updPos (s:=s) (u:=u) ≠ updNeg (s:=s) (u:=u) := by + intro h + have := congrArg (fun st => st.act u) h + simp [updPos, updNeg, Function.update, TwoStateNeuralNetwork.h_pos_ne_neg] at this + +/- + General limit lemmas for reals, used to analyze the zero-temperature limit. + These are independent of the neural-network context (mathlib-ready). +-/ +open Real Filter Topology Monotone + +/-- Multiplication by a positive constant maps `atTop` to `atTop`. 
-/
+lemma tendsto_mul_const_atTop_atTop_of_pos {c : ℝ} (hc : 0 < c) :
+    Tendsto (fun x : ℝ => c * x) atTop atTop := by
+  -- for any bound M, every x ≥ M / c satisfies c * x ≥ M
+  refine (Filter.tendsto_atTop_atTop).2 ?_
+  intro M
+  refine ⟨M / c, ?_⟩
+  intro x hx
+  exact (div_le_iff₀' hc).mp hx
+
+/-- Multiplication by a negative constant maps `atTop` to `atBot`. -/
+lemma tendsto_mul_const_atTop_atBot_of_neg {c : ℝ} (hc : c < 0) :
+    Tendsto (fun x : ℝ => c * x) atTop atBot := by
+  -- for any bound M, every x ≥ M / c (with c < 0) satisfies c * x ≤ M
+  refine (Filter.tendsto_atTop_atBot).2 ?_
+  intro M
+  refine ⟨M / c, ?_⟩
+  intro x hx
+  exact (div_le_iff_of_neg' hc).mp hx
+
+/-- As `x → +∞`, `logisticProb x → 1`. -/
+lemma logisticProb_tendsto_atTop :
+    Tendsto logisticProb atTop (𝓝 (1 : ℝ)) := by
+  -- logisticProb x = 1/(1 + exp(-x)); as x→+∞, -x→-∞, so exp(-x)→0, then 1/(1+r)→1
+  have hx_neg : Tendsto (fun x : ℝ => -x) atTop atBot :=
+    (tendsto_neg_atBot_iff).mpr tendsto_id
+  have h_exp : Tendsto (fun x => Real.exp (-x)) atTop (𝓝 0) :=
+    Real.tendsto_exp_atBot.comp hx_neg
+  -- `r ↦ 1/(1+r)` is continuous at 0 since the denominator `1 + 0 ≠ 0`
+  have h_cont : ContinuousAt (fun r : ℝ => (1 : ℝ) / (1 + r)) 0 :=
+    (continuousAt_const.div (continuousAt_const.add continuousAt_id) (by norm_num))
+  have h_comp :
+      Tendsto (fun x => (1 : ℝ) / (1 + Real.exp (-x))) atTop (𝓝 ((1 : ℝ) / (1 + 0))) :=
+    h_cont.tendsto.comp h_exp
+  unfold logisticProb
+  simpa [Real.exp_zero] using h_comp
+
+/-- As `x → -∞`, `logisticProb x → 0`. 
-/
+lemma logisticProb_tendsto_atBot :
+    Tendsto logisticProb atBot (𝓝 (0 : ℝ)) := by
+  -- 0 ≤ logistic ≤ exp, and exp x → 0 as x→-∞; conclude by squeezing
+  have h_le_exp : ∀ x : ℝ, logisticProb x ≤ Real.exp x := by
+    intro x
+    unfold logisticProb
+    have hxpos : 0 < Real.exp (-x) := Real.exp_pos _
+    have hz_le : Real.exp (-x) ≤ 1 + Real.exp (-x) := by linarith
+    -- 1/(1+e^{-x}) ≤ 1/e^{-x} = e^{x}
+    have : (1 : ℝ) / (1 + Real.exp (-x)) ≤ (1 : ℝ) / Real.exp (-x) :=
+      one_div_le_one_div_of_le hxpos hz_le
+    simpa [one_div, Real.exp_neg] using this
+  refine
+    tendsto_of_tendsto_of_tendsto_of_le_of_le
+      (tendsto_const_nhds) (Real.tendsto_exp_atBot)
+      (fun x => logisticProb_nonneg x)
+      (fun x => h_le_exp x)
+
+/-- As `T → 0+`, if `c > 0` then `c/T → +∞`. -/
+lemma tendsto_c_div_atTop_of_pos {c : ℝ} (hc : 0 < c) :
+    Tendsto (fun T : ℝ => c / T) (𝓝[>] (0 : ℝ)) atTop := by
+  -- compose `T ↦ T⁻¹ → +∞` on the right-neighborhood of 0 with `x ↦ c * x`
+  have h_inv : Tendsto (fun T : ℝ => T⁻¹) (𝓝[>] (0 : ℝ)) atTop :=
+    tendsto_inv_nhdsGT_zero
+  have h_mul := tendsto_mul_const_atTop_atTop_of_pos hc
+  simpa [div_eq_mul_inv] using (h_mul.comp h_inv)
+
+/-- As `T → 0+`, if `c < 0` then `c/T → -∞`. -/
+lemma tendsto_c_div_atBot_of_neg {c : ℝ} (hc : c < 0) :
+    Tendsto (fun T : ℝ => c / T) (𝓝[>] (0 : ℝ)) atBot := by
+  have h_inv : Tendsto (fun T : ℝ => T⁻¹) (𝓝[>] (0 : ℝ)) atTop :=
+    tendsto_inv_nhdsGT_zero
+  have h_mul := tendsto_mul_const_atTop_atBot_of_neg hc
+  simpa [div_eq_mul_inv] using (h_mul.comp h_inv)
+
+/-- As `T → 0+`, if `c > 0` then `logisticProb (c/T) → 1`. -/
+lemma tendsto_logistic_scaled_of_pos {c : ℝ} (hc : 0 < c) :
+    Tendsto (fun T : ℝ => logisticProb (c / T)) (𝓝[>] (0 : ℝ)) (𝓝 (1 : ℝ)) :=
+  logisticProb_tendsto_atTop.comp (tendsto_c_div_atTop_of_pos hc)
+
+/-- As `T → 0+`, if `c < 0` then `logisticProb (c/T) → 0`. -/
+lemma tendsto_logistic_scaled_of_neg {c : ℝ} (hc : c < 0) :
+    Tendsto (fun T : ℝ => logisticProb (c / T)) (𝓝[>] (0 : ℝ)) (𝓝 (0 : ℝ)) :=
+  logisticProb_tendsto_atBot.comp (tendsto_c_div_atBot_of_neg hc)
+
+/-- As `T → 0+`, if `c = 0` then `logisticProb (c/T)` is constantly `1/2`. 
-/
+lemma tendsto_logistic_scaled_of_eq_zero {c : ℝ} (hc : c = 0) :
+    Tendsto (fun T : ℝ => logisticProb (c / T)) (𝓝[>] (0 : ℝ)) (𝓝 ((1 : ℝ) / 2)) := by
+  -- with `c = 0` the function is eventually (in fact everywhere on `T > 0`) equal to 1/2
+  have : (fun T : ℝ => logisticProb (c / T)) =ᶠ[𝓝[>] (0 : ℝ)] fun _ => 1 / 2 := by
+    filter_upwards [self_mem_nhdsWithin] with T _
+    simp [logisticProb, hc, Real.exp_zero, one_add_one_eq_two]
+  exact (tendsto_congr' this).mpr tendsto_const_nhds
+
+/-- As `T → 0+`, `logisticProb (c / T)` tends to `1` if `c > 0`, to `0` if `c < 0`,
+and to `1/2` if `c = 0`. -/
+lemma tendsto_logistic_scaled
+    (c : ℝ) :
+    Tendsto (fun T : ℝ => logisticProb (c / T)) (nhdsWithin 0 (Set.Ioi 0))
+      (𝓝 (if 0 < c then 1 else if c < 0 then 0 else 1/2)) := by
+  -- trichotomy on the sign of `c`, dispatching to the three specialized lemmas
+  by_cases hcpos : 0 < c
+  · simpa [hcpos] using (tendsto_logistic_scaled_of_pos (c:=c) hcpos)
+  · by_cases hcneg : c < 0
+    · have := tendsto_logistic_scaled_of_neg (c:=c) hcneg
+      simpa [hcpos, hcneg] using this
+    · have hc0 : c = 0 := le_antisymm (le_of_not_gt hcpos) (le_of_not_gt hcneg)
+      simpa [hcpos, hcneg, hc0] using (tendsto_logistic_scaled_of_eq_zero (c:=c) hc0)
+
+/-- On ℝ≥0: as `b → ∞`, `logisticProb (c * b) → 1/0/1/2` depending on the sign of `c`. 
-/ +lemma tendsto_logistic_const_mul_coeNNReal + (c : ℝ) : + Tendsto (fun b : ℝ≥0 => logisticProb (c * (b : ℝ))) atTop + (𝓝 (if 0 < c then 1 else if c < 0 then 0 else 1/2)) := by + have h_coe : Tendsto (fun b : ℝ≥0 => (b : ℝ)) atTop atTop := by + refine (Filter.tendsto_atTop_atTop).2 ?_ + intro M + refine ⟨⟨max 0 (M + 1), by have : 0 ≤ max 0 (M + 1) := le_max_left _ _; exact this⟩, ?_⟩ + intro b hb + have hBR : (max 0 (M + 1) : ℝ) ≤ (b : ℝ) := by exact_mod_cast hb + have hM1 : (M + 1 : ℝ) ≤ (b : ℝ) := le_trans (le_max_right _ _) hBR + have : (M : ℝ) ≤ (b : ℝ) := by linarith + exact this + by_cases hcpos : 0 < c + · + have hmul := tendsto_mul_const_atTop_atTop_of_pos (c:=c) hcpos + simpa [hcpos, Function.comp] using + logisticProb_tendsto_atTop.comp (hmul.comp h_coe) + · by_cases hcneg : c < 0 + · + have hmul := tendsto_mul_const_atTop_atBot_of_neg (c:=c) hcneg + simpa [hcpos, hcneg, Function.comp] using + logisticProb_tendsto_atBot.comp (hmul.comp h_coe) + · + have hc0 : c = 0 := le_antisymm (le_of_not_gt hcpos) (le_of_not_gt hcneg) + have hconst : + (fun b : ℝ≥0 => logisticProb (c * (b : ℝ))) = fun _ => (1/2 : ℝ) := by + funext b; simp [hc0, logisticProb, Real.exp_zero, one_add_one_eq_two] + aesop + +omit [DecidableEq U] [Fintype U] [Nonempty U] in +/-- Real-valued probability limit P(T) for our model as T → 0+. 
-/ +private lemma tendsto_probPos_at_zero + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] [OrderHomClass F R ℝ] + (f : F) (_ : Function.Injective f) + (p : Params NN) (s : NN.State) (u : U) : + let L : R := (s.net p u) - (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + Tendsto (fun T : ℝ => probPos (NN:=NN) f p ⟨Real.toNNReal T⟩ s u) + (nhdsWithin 0 (Set.Ioi 0)) + (𝓝 (let κ := scale (NN:=NN) (f:=f); let c := (κ / kB) * (f L); + if 0 < c then 1 else if c < 0 then 0 else 1/2)) := by + intro L + have h_event : + (fun T : ℝ => probPos (NN:=NN) f p ⟨Real.toNNReal T⟩ s u) + =ᶠ[nhdsWithin 0 (Set.Ioi 0)] + (fun T : ℝ => + logisticProb (((scale (NN:=NN) (f:=f)) / kB) * (f L) / T)) := by + filter_upwards [self_mem_nhdsWithin] with T hTpos + have hT0 : 0 ≤ T := le_of_lt hTpos + have : (β ⟨Real.toNNReal T⟩ : ℝ) = 1 / (kB * T) := by + simp [Temperature.β, Temperature.toReal, Real.toNNReal_of_nonneg hT0, one_div] + unfold probPos + simp [this, logisticProb, one_div, mul_comm, mul_left_comm, mul_assoc, div_eq_mul_inv] + aesop + have hlim : + Tendsto (fun T : ℝ => + logisticProb (((scale (NN:=NN) (f:=f)) / kB) * (f L) / T)) + (nhdsWithin 0 (Set.Ioi 0)) + (𝓝 (let κ := scale (NN:=NN) (f:=f); let c := (κ / kB) * (f L); + if 0 < c then 1 else if c < 0 then 0 else 1/2)) := + tendsto_logistic_scaled _ + exact (tendsto_congr' h_event).mpr hlim + +/- Simple ENNReal evaluation lemmas for the Gibbs one-step with Bernoulli bind. 
-/
namespace PMF

variable {α : Type}

/-- A Bernoulli(`p`) choice bound into `pure x` / `pure y` assigns mass `p` to
the "success" point `x`, provided `x ≠ y` so the two branches do not overlap. -/
@[simp]
lemma bernoulli_bind_pure_apply_left_of_ne
    {p : ℝ≥0∞} (hp : p ≤ 1) {x y : α} (hxy : x ≠ y) :
    ((PMF.bernoulli p hp) >>= fun b => if b then PMF.pure x else PMF.pure y) x = p := by
  classical
  -- Unfold the monadic bind and reduce the `tsum` over `Bool` to the
  -- two-element sum over `{false, true}`.
  change (PMF.bind (PMF.bernoulli p hp) (fun b => if b then PMF.pure x else PMF.pure y)) x = p
  rw [PMF.bind_apply]
  simp only [PMF.bernoulli_apply, PMF.pure_apply, tsum_fintype]
  have : Finset.univ = ({false, true} : Finset Bool) := by
    ext b; simp
  rw [this, Finset.sum_pair (by simp : false ≠ true)]
  simp only [Bool.cond_false, Bool.cond_true]
  simp [hxy, if_neg hxy.symm]

/-- A Bernoulli(`p`) choice bound into `pure x` / `pure y` assigns mass `0` to
any point `z` distinct from both `x` and `y`. -/
@[simp]
lemma bernoulli_bind_pure_apply_other
    {p : ℝ≥0∞} (hp : p ≤ 1) {x y z : α} (hx : z ≠ x) (hy : z ≠ y) :
    ((PMF.bernoulli p hp) >>= fun b => if b then PMF.pure x else PMF.pure y) z = 0 := by
  classical
  change (PMF.bind (PMF.bernoulli p hp) (fun b => if b then PMF.pure x else PMF.pure y)) z = 0
  rw [PMF.bind_apply]
  simp only [PMF.bernoulli_apply, PMF.pure_apply, tsum_fintype]
  have : Finset.univ = ({false, true} : Finset Bool) := by
    ext b; simp
  rw [this, Finset.sum_pair (by simp : false ≠ true)]
  simp only [Bool.cond_false, Bool.cond_true]
  simp [hx, hy, if_neg hx.symm, if_neg hy.symm]

-- NOTE(review): this re-declares `α` with an extra `DecidableEq` instance,
-- shadowing the earlier `variable {α : Type}` for the remaining lemma.
variable {α : Type} [DecidableEq α]

/-- A Bernoulli(`p`) choice bound into `pure x` / `pure y` assigns mass `1 - p`
to the "failure" point `y`, provided `x ≠ y`. -/
@[simp]
lemma bernoulli_bind_pure_apply_right_of_ne
    {p : ℝ≥0∞} (hp : p ≤ 1) {x y : α} (hxy : x ≠ y) :
    ((PMF.bernoulli p hp) >>= fun b => if b then PMF.pure x else PMF.pure y) y = (1 - p) := by
  classical
  change (PMF.bind (PMF.bernoulli p hp) (fun b => if b then PMF.pure x else PMF.pure y)) y = (1 - p)
  rw [PMF.bind_apply]
  simp only [PMF.bernoulli_apply, PMF.pure_apply, tsum_fintype]
  have : Finset.univ = ({false, true} : Finset Bool) := by
    ext b; simp
  rw [this, Finset.sum_pair (by simp : false ≠ true)]
  simp only [Bool.cond_false, Bool.cond_true]
  simp [hxy, if_neg hxy]
  aesop

end PMF
open PMF
omit [Fintype U] [Nonempty U] in
/-- Pointwise evaluation at `updPos`: exact (not
just eventual) equality. -/
private lemma gibbsUpdate_apply_updPos
    {F} [FunLike F R ℝ] [RingHomClass F R ℝ]
    (f : F) (p : Params NN) (T : Temperature) (s : NN.State) (u : U) :
    (gibbsUpdate (NN:=NN) f p T s u) (updPos (s:=s) (u:=u))
      = ENNReal.ofReal (probPos (NN:=NN) f p T s u) := by
  classical
  unfold gibbsUpdate
  set pPos := probPos (NN:=NN) f p T s u
  set pPosE := ENNReal.ofReal pPos
  -- `probPos ≤ 1` lifts to the `ℝ≥0∞` bound required by `PMF.bernoulli`.
  have h_le : pPosE ≤ 1 := by
    simpa [pPosE] using probPos_le_one (NN:=NN) f p T s u
  have hne := updPos_ne_updNeg (s:=s) (u:=u)
  -- Evaluate via the generic Bernoulli-bind lemma at the "success" point.
  simp [PMF.bernoulli_bind_pure_apply_left_of_ne (α:=NN.State) h_le hne, pPosE, pPos]

omit [Fintype U] [Nonempty U] in
/-- Pointwise evaluation at `updNeg`: exact (not just eventual) equality. -/
private lemma gibbsUpdate_apply_updNeg
    {F} [FunLike F R ℝ] [RingHomClass F R ℝ]
    (f : F) (p : Params NN) (T : Temperature) (s : NN.State) (u : U) :
    (gibbsUpdate (NN:=NN) f p T s u) (updNeg (s:=s) (u:=u))
      = ENNReal.ofReal (1 - probPos (NN:=NN) f p T s u) := by
  classical
  unfold gibbsUpdate
  set pPos := probPos (NN:=NN) f p T s u
  set pPosE := ENNReal.ofReal pPos
  have h_le : pPosE ≤ 1 := by
    simpa [pPosE] using probPos_le_one (NN:=NN) f p T s u
  have hne := updPos_ne_updNeg (s:=s) (u:=u)
  -- The "failure" point gets mass `1 - pPosE` by the generic Bernoulli-bind lemma.
  have : ((PMF.bernoulli pPosE h_le) >>= fun b => if b then PMF.pure (updPos (s:=s) (u:=u)) else PMF.pure (updNeg (s:=s) (u:=u))) (updNeg (s:=s) (u:=u))
      = (1 - pPosE) := PMF.bernoulli_bind_pure_apply_right_of_ne (α:=NN.State) h_le hne
  rw [this]
  -- Convert the `ℝ≥0∞` subtraction back to `ENNReal.ofReal` of a real
  -- subtraction, using `0 ≤ pPos ≤ 1`.
  have hpos_nonneg : 0 ≤ pPos := probPos_nonneg (NN:=NN) f p T s u
  have hpos_le_one : pPos ≤ 1 := probPos_le_one (NN:=NN) f p T s u
  have h_eq : (1 : ℝ≥0∞) - pPosE = ENNReal.ofReal (1 - pPos) := by
    simp_rw [pPosE]
    rw [← ENNReal.ofReal_one]
    rw [ENNReal.ofReal_sub 1 hpos_nonneg]
  rw [h_eq]

omit [Fintype U] [Nonempty U] in
/-- Eventual equality rewriting Gibbs mass at updPos along β → ∞ to ENNReal.ofReal (probPos at T).
-/
private lemma eventually_eval_updPos_eq_ofReal_probPos
    {F} [FunLike F R ℝ] [RingHomClass F R ℝ]
    (f : F) (p : Params NN) (s : NN.State) (u : U) :
    (fun b : ℝ≥0 => (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) (updPos (s:=s) (u:=u)))
      =ᶠ[atTop]
    (fun b : ℝ≥0 => ENNReal.ofReal (probPos (NN:=NN) f p (Temperature.ofβ b) s u)) := by
  -- The equality is in fact pointwise (by `gibbsUpdate_apply_updPos`),
  -- so in particular it holds eventually along `atTop`.
  refine Filter.Eventually.of_forall ?_
  intro b; simp [gibbsUpdate_apply_updPos]

omit [Fintype U] [Nonempty U] in
/-- Eventual equality rewriting Gibbs mass at updNeg along β → ∞ to ENNReal.ofReal (1 - probPos). -/
private lemma eventually_eval_updNeg_eq_ofReal_one_sub_probPos
    {F} [FunLike F R ℝ] [RingHomClass F R ℝ]
    (f : F) (p : Params NN) (s : NN.State) (u : U) :
    (fun b : ℝ≥0 => (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) (updNeg (s:=s) (u:=u)))
      =ᶠ[atTop]
    (fun b : ℝ≥0 => ENNReal.ofReal (1 - probPos (NN:=NN) f p (Temperature.ofβ b) s u)) := by
  -- Same pattern as the `updPos` version, via `gibbsUpdate_apply_updNeg`.
  refine Filter.Eventually.of_forall ?_
  intro b; simp [gibbsUpdate_apply_updNeg]

omit [Fintype U] [Nonempty U] in
/-- Target evaluation: the zero-temperature PMF mass at `updPos` as an ENNReal.ofReal
    piecewise {1,0,1/2} driven by the sign of `((scale f)/kB) * f L`.
-/ +private lemma zeroTemp_target_updPos_as_ofReal_sign + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] [OrderHomClass F R ℝ] + (f : F) (hf : Function.Injective f) + (p : Params NN) (s : NN.State) (u : U) : + let net := s.net p u + let θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + (zeroTempLimitPMF (NN:=NN) p s u) (updPos (s:=s) (u:=u)) = + ENNReal.ofReal + (if 0 < ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) + then 1 else if ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) < 0 then 0 else (1/2 : ℝ)) := by + classical + intro net θ + by_cases hpos : θ < net + · -- positive local field: pure updPos, RHS selects 1 branch + have hLpos : 0 < (net - θ) := sub_pos.mpr hpos + have hfpos : 0 < f (net - θ) := map_pos_of_pos (R:=R) (f:=f) hf hLpos + have hκpos : 0 < ((scale (NN:=NN) (f:=f)) / kB) := + div_pos (scale_pos (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f) hf) kB_pos + have hprodpos : 0 < ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) := + mul_pos hκpos hfpos + have hne := updPos_ne_updNeg (s:=s) (u:=u) + simp only [zeroTempLimitPMF, net, θ, hpos, not_lt.mpr hpos.le, PMF.pure_apply, + hne, dite_eq_ite, if_true, ENNReal.ofReal_one] + simp [hprodpos] + aesop + · by_cases hneg : net < θ + · -- negative local field: pure updNeg at updPos gives 0, RHS selects 0 branch + have hLneg : (net - θ) < 0 := sub_lt_zero.mpr hneg + have hfneg : f (net - θ) < 0 := map_neg_of_neg (R:=R) (f:=f) hf hLneg + have hκpos : 0 < ((scale (NN:=NN) (f:=f)) / kB) := + div_pos (scale_pos (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f) hf) kB_pos + have hprodneg : ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) < 0 := + mul_neg_of_pos_of_neg hκpos hfneg + have hne := updPos_ne_updNeg (s:=s) (u:=u) + simp only [zeroTempLimitPMF, net, θ, hneg, not_lt.mpr hneg.le, PMF.pure_apply, + hne, dite_eq_ite, if_false, if_true, ENNReal.ofReal_zero] + rw [@ENNReal.zero_eq_ofReal] + have hprodneg' : + ((scale (NN:=NN) (f:=f)) / kB) * ((f net) - (f θ)) < 0 := by + simpa [map_sub f] using hprodneg + simpa using + 
(piecewise01half_le_zero_of_neg + (x := ((scale (NN:=NN) (f:=f)) / kB) * ((f net) - (f θ)))) hprodneg' + · -- tie: 1/2 mixture, RHS selects 1/2 branch + have h_eq : net = θ := le_antisymm (le_of_not_gt hpos) (le_of_not_gt hneg) + have hf0 : f (net - θ) = 0 := by simp [h_eq, map_zero f] + have hne := updPos_ne_updNeg (s:=s) (u:=u) + simp [zeroTempLimitPMF, net, θ, hpos, hneg, hf0, hne] + +omit [Fintype U] [Nonempty U] in +/-- Target evaluation: the zero-temperature PMF mass at `updNeg` as ENNReal.ofReal + of one minus the same 1/0/1/2 piecewise expression. -/ +private lemma zeroTemp_target_updNeg_as_ofReal_one_sub_sign + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] [OrderHomClass F R ℝ] + (f : F) (hf : Function.Injective f) + (p : Params NN) (s : NN.State) (u : U) : + let net := s.net p u + let θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + (zeroTempLimitPMF (NN:=NN) p s u) (updNeg (s:=s) (u:=u)) = + ENNReal.ofReal + (1 - (if 0 < ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) + then 1 else if ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) < 0 then 0 else (1/2 : ℝ))) := by + classical + intro net θ + by_cases hpos : θ < net + · have hLpos : 0 < (net - θ) := sub_pos.mpr hpos + have hfpos : 0 < f (net - θ) := map_pos_of_pos (R:=R) (f:=f) hf hLpos + have hκpos : 0 < ((scale (NN:=NN) (f:=f)) / kB) := + div_pos (scale_pos (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f) hf) kB_pos + have hprodpos : 0 < ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) := + mul_pos hκpos hfpos + have hne := updPos_ne_updNeg (s:=s) (u:=u) + simp [zeroTempLimitPMF, net, θ, hpos, not_lt.mpr hpos.le, hprodpos, PMF.pure_apply, hne] + aesop + · by_cases hneg : net < θ + · have hLneg : (net - θ) < 0 := sub_lt_zero.mpr hneg + have hfneg : f (net - θ) < 0 := map_neg_of_neg (R:=R) (f:=f) hf hLneg + have hκpos : 0 < ((scale (NN:=NN) (f:=f)) / kB) := + div_pos (scale_pos (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f) hf) kB_pos + have hprodneg : ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) < 0 := + 
mul_neg_of_pos_of_neg hκpos hfneg + have hne := updPos_ne_updNeg (s:=s) (u:=u) + have hprodneg' : + ((scale (NN:=NN) (f:=f)) / kB) * ((f net) - (f θ)) < 0 := by + simpa [map_sub f] using hprodneg + have hnotpos' : + ¬ 0 < ((scale (NN:=NN) (f:=f)) / kB) * ((f net) - (f θ)) := + not_lt.mpr hprodneg'.le + simp [zeroTempLimitPMF, net, θ, hneg, not_lt.mpr hneg.le, PMF.pure_apply, + hprodneg', hnotpos', one_div] + · -- tie branch: 1 - ofReal(1/2) = ofReal (1 - 1/2) + have h_eq : net = θ := le_antisymm (le_of_not_gt hpos) (le_of_not_gt hneg) + have hf0 : f (net - θ) = 0 := by simp [h_eq, map_zero f] + have hne := updPos_ne_updNeg (s:=s) (u:=u) + let pHalf : ℝ≥0∞ := ENNReal.ofReal (1 / 2 : ℝ) + have hp : pHalf ≤ 1 := by + simpa [pHalf] using + (ENNReal.ofReal_le_one.mpr (by norm_num : (1 / 2 : ℝ) ≤ 1)) + have hbind : + ((PMF.bernoulli pHalf hp) >>= fun b => + if b then PMF.pure (updPos (s:=s) (u:=u)) + else PMF.pure (updNeg (s:=s) (u:=u))) + (updNeg (s:=s) (u:=u)) = 1 - pHalf := + PMF.bernoulli_bind_pure_apply_right_of_ne (α:=NN.State) hp hne + have hle₁ : (1 / 2 : ℝ) ≤ 1 := by norm_num + have hnonneg : (0 : ℝ) ≤ (1 / 2 : ℝ) := by norm_num + have hsub : + (1 : ℝ≥0∞) - ENNReal.ofReal (1 / 2 : ℝ) = + ENNReal.ofReal (1 - (1 / 2 : ℝ)) := by + have h := ENNReal.ofReal_sub + (p := (1 : ℝ)) (q := (1 / 2 : ℝ)) (by norm_num : (0 : ℝ) ≤ (1 / 2 : ℝ)) + simpa [ENNReal.ofReal_one, one_div] using h.symm + have hbind' : + ((PMF.bernoulli pHalf hp) >>= fun b => + if b then PMF.pure (updPos (s:=s) (u:=u)) + else PMF.pure (updNeg (s:=s) (u:=u))) + (updNeg (s:=s) (u:=u)) + = (1 - ENNReal.ofReal (1 / 2 : ℝ)) := by + simpa [pHalf] using hbind + simp [zeroTempLimitPMF, net, θ, hpos, hneg, hf0, hne] + aesop + +omit [DecidableEq U] [Fintype U] [Nonempty U] in +/-- Real-valued limit along `β (ofβ b) = b`: `probPos` tends to 1/0/1/2 by the sign of `c0`. 
-/ +private lemma tendsto_probPos_along_ofβ_to_piecewise + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] [OrderHomClass F R ℝ] + (f : F) (p : Params NN) (s : NN.State) (u : U) : + let L := (s.net p u) - (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + let c0 : ℝ := (scale (NN:=NN) (f:=f)) * (f L) + Tendsto (fun b : ℝ≥0 => probPos (NN:=NN) f p (Temperature.ofβ b) s u) + atTop + (𝓝 (if 0 < c0 then 1 else if c0 < 0 then 0 else 1/2)) := by + intro L c0 + have hβ : ∀ b, (β (Temperature.ofβ b) : ℝ) = b := by intro b; simp + have h_form : + ∀ b, probPos (NN:=NN) f p (Temperature.ofβ b) s u + = logisticProb (c0 * (b : ℝ)) := by + intro b; unfold probPos; simp [hβ b, L, c0, mul_comm, mul_left_comm, mul_assoc, logisticProb] + simpa [h_form] using tendsto_logistic_const_mul_coeNNReal c0 + +omit [Fintype U] [Nonempty U] in +/-- Convergence on `updPos`: short proof using the split helpers. -/ +lemma gibbs_update_tends_to_zero_temp_limit_apply_updPos + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] [OrderHomClass F R ℝ] + (f : F) (hf : Function.Injective f) + (p : Params NN) (s : NN.State) (u : U) : + Tendsto (fun b : ℝ≥0 => + (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) (updPos (s:=s) (u:=u))) + atTop (𝓝 ((zeroTempLimitPMF (NN:=NN) p s u) (updPos (s:=s) (u:=u)))) := by + classical + have h_target := zeroTemp_target_updPos_as_ofReal_sign (NN:=NN) f hf p s u + have hev := eventually_eval_updPos_eq_ofReal_probPos (NN:=NN) f p s u + set net := s.net p u + set θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + set L : R := net - θ + set c0 : ℝ := (scale (NN:=NN) (f:=f)) * (f L) + have h_real := tendsto_probPos_along_ofβ_to_piecewise (NN:=NN) f p s u + have h_rewrite := sign_piecewise_rewrite_with_c0 (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f) L + have hlim := ENNReal.tendsto_ofReal (by simpa [L, c0, h_rewrite] using h_real) + have h := (tendsto_congr' hev).mpr hlim + have h' : Tendsto (fun b : ℝ≥0 => + (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) (updPos (s:=s) (u:=u))) + atTop 
(𝓝 (ENNReal.ofReal + (if 0 < ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) + then 1 else if ((scale (NN:=NN) (f:=f)) / kB) * (f (net - θ)) < 0 then 0 else (1/2 : ℝ)))) := by + aesop + simpa [h_target, net, θ] using h' + +omit [Fintype U] [Nonempty U] in +/-- Convergence on `updNeg`: short proof using the split helpers. -/ +lemma gibbs_update_tends_to_zero_temp_limit_apply_updNeg + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] [OrderHomClass F R ℝ] + (f : F) (hf : Function.Injective f) + (p : Params NN) (s : NN.State) (u : U) : + Tendsto (fun b : ℝ≥0 => + (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) (updNeg (s:=s) (u:=u))) + atTop (𝓝 ((zeroTempLimitPMF (NN:=NN) p s u) (updNeg (s:=s) (u:=u)))) := by + classical + have h_target := zeroTemp_target_updNeg_as_ofReal_one_sub_sign (NN:=NN) f hf p s u + have hev := eventually_eval_updNeg_eq_ofReal_one_sub_probPos (NN:=NN) f p s u + set net := s.net p u + set θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + set L : R := net - θ + set c0 : ℝ := (scale (NN:=NN) (f:=f)) * (f L) + have h_real := + tendsto_probPos_along_ofβ_to_piecewise (NN:=NN) f p s u + have h_sub : + Tendsto (fun b : ℝ≥0 => + (1 : ℝ) - probPos (NN:=NN) f p (Temperature.ofβ b) s u) + atTop + (𝓝 (1 - (if 0 < c0 then 1 else if c0 < 0 then 0 else (1/2)))) := + tendsto_const_nhds.sub h_real + have h_lift : + Tendsto (fun b : ℝ≥0 => + ENNReal.ofReal (1 - probPos (NN:=NN) f p (Temperature.ofβ b) s u)) + atTop + (𝓝 (ENNReal.ofReal (1 - (if 0 < c0 then 1 else if c0 < 0 then 0 else (1/2))))) := + ENNReal.tendsto_ofReal h_sub + have h_rewrite : + (if 0 < ((scale (NN:=NN) (f:=f)) / kB) * (f L) then 1 + else if ((scale (NN:=NN) (f:=f)) / kB) * (f L) < 0 then 0 else (1/2 : ℝ)) + = + (if 0 < c0 then 1 else if c0 < 0 then 0 else (1/2 : ℝ)) := by + simpa [c0, one_div] using + sign_piecewise_rewrite_with_c0 + (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f) L + have h_conv : + Tendsto (fun b : ℝ≥0 => + (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) (updNeg (s:=s) (u:=u))) + 
atTop + (𝓝 (ENNReal.ofReal + (1 - (if 0 < ((scale (NN:=NN) (f:=f)) / kB) * (f L) + then 1 else if ((scale (NN:=NN) (f:=f)) / kB) * (f L) < 0 + then 0 else (1/2 : ℝ))))) := by + have := (tendsto_congr' hev).mpr h_lift + aesop + aesop + +omit [Fintype U] [Nonempty U] in +/-- Convergence on any “other” state (neither updPos nor updNeg). -/ +lemma gibbs_update_tends_to_zero_temp_limit_apply_other + {F} [FunLike F R ℝ] + (f : F) + (p : Params NN) (s : NN.State) (u : U) + {state : NN.State} + (hpos : state ≠ updPos (s:=s) (u:=u)) + (hneg : state ≠ updNeg (s:=s) (u:=u)) : + Tendsto (fun b : ℝ≥0 => + (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) state) + atTop (𝓝 ((zeroTempLimitPMF (NN:=NN) p s u) state)) := by + classical + set net := s.net p u + set θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + have htarget0 : + (zeroTempLimitPMF (NN:=NN) p s u) state = 0 := by + by_cases hθnet : θ < net + · simp [zeroTempLimitPMF, net, θ, hθnet, PMF.pure_apply, hpos] + · by_cases hnetθ : net < θ + · simp [zeroTempLimitPMF, net, θ, hθnet, hnetθ, PMF.pure_apply, hneg] + · + have hp : ENNReal.ofReal (1/2) ≤ (1 : ℝ≥0∞) := by + simpa using (ENNReal.ofReal_le_one.2 (by norm_num : (1/2 : ℝ) ≤ 1)) + have hbind_zero : + ((PMF.bernoulli (ENNReal.ofReal (1/2)) hp) >>= fun b => + if b then PMF.pure (updPos (s:=s) (u:=u)) else PMF.pure (updNeg (s:=s) (u:=u))) state = 0 := + PMF.bernoulli_bind_pure_apply_other (α:=NN.State) hp hpos hneg + simpa [zeroTempLimitPMF, net, θ, hθnet, hnetθ] using hbind_zero + have hev : + (fun b : ℝ≥0 => + (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) state) + =ᶠ[atTop] (fun _ => (0 : ℝ≥0∞)) := by + refine Filter.Eventually.of_forall ?_ + intro b + simp [gibbsUpdate, PMF.bind_apply, tsum_fintype, PMF.pure_apply, hpos, hneg] + have : Tendsto (fun b : ℝ≥0 => + (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) state) atTop (𝓝 0) := + (tendsto_congr' hev).mpr tendsto_const_nhds + simpa [htarget0] using this + +omit [Fintype U] [Nonempty U] in +/-- 
**Theorem** Pointwise convergence of the one–site Gibbs PMF to the zero-temperature limit PMF, +for every state. This wraps the three evaluation lemmas into a single statement. -/ +theorem gibbs_update_tends_to_zero_temp_limit + {F} [FunLike F R ℝ] [RingHomClass F R ℝ] [OrderHomClass F R ℝ] + (f : F) (hf : Function.Injective f) + (p : Params NN) (s : NN.State) (u : U) : + ∀ state : NN.State, + Tendsto (fun b : ℝ≥0 => + (gibbsUpdate (NN:=NN) f p (Temperature.ofβ b) s u) state) + atTop (𝓝 ((zeroTempLimitPMF (NN:=NN) p s u) state)) := by + classical + intro state + by_cases hpos : state = updPos (s:=s) (u:=u) + · subst hpos + exact gibbs_update_tends_to_zero_temp_limit_apply_updPos + (NN:=NN) f hf p s u + · by_cases hneg : state = updNeg (s:=s) (u:=u) + · subst hneg + exact gibbs_update_tends_to_zero_temp_limit_apply_updNeg + (NN:=NN) f hf p s u + · exact gibbs_update_tends_to_zero_temp_limit_apply_other + (NN:=NN) f p s u (by simpa using hpos) (by simpa using hneg) + +end Convergence +end TwoState From bc1fe8779c1fa82864f5eae7bc1baf97cf7b1660 Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Tue, 26 Aug 2025 12:20:45 +0200 Subject: [PATCH 12/15] feat(TwoState): generalize toCanonicalEnsemble, general proof of detailed balance and invariance for gibbs random kernel, etc. major refactor with proof of detailed balance /reversibility) and invariance (stationary) for our general TwoState random Gibbs kernel. 
This covers, with one general API, most of specific discrete HN and BM in the literature --- .../CanonicalEnsemble/Finite.lean | 17 + .../HopfieldNetwork/BoltzmannMachine.lean | 764 +++++++++++ .../SpinGlasses/HopfieldNetwork/Core.lean | 19 +- .../HopfieldNetwork/DetailedBalanceBM.lean | 1145 +++++++++++++++++ .../HopfieldNetwork/NeuralNetwork.lean | 86 +- .../SpinGlasses/HopfieldNetwork/TwoState.lean | 436 ++++++- .../SpinGlasses/HopfieldNetwork/aux.lean | 2 +- .../HopfieldNetwork/toCanonicalEnsemble.lean | 346 +++++ 8 files changed, 2759 insertions(+), 56 deletions(-) create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceBM.lean create mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/toCanonicalEnsemble.lean diff --git a/PhysLean/StatisticalMechanics/CanonicalEnsemble/Finite.lean b/PhysLean/StatisticalMechanics/CanonicalEnsemble/Finite.lean index c0a0a45d1..5233d90fc 100644 --- a/PhysLean/StatisticalMechanics/CanonicalEnsemble/Finite.lean +++ b/PhysLean/StatisticalMechanics/CanonicalEnsemble/Finite.lean @@ -198,6 +198,23 @@ lemma μProd_of_fintype (T : Temperature) [IsFinite 𝓒] (i : ι) : rw [mul_comm] rfl +open scoped ENNReal + +/-- Finite singleton evaluation in `ℝ≥0∞` form. 
-/ +@[simp] +lemma μProd_singleton_of_fintype + (T : Temperature) [IsFinite 𝓒] [Nonempty ι] (i : ι) : + (𝓒.μProd T) {i} = ENNReal.ofReal (𝓒.probability T i) := by + have hReal := μProd_of_fintype (𝓒:=𝓒) (T:=T) i + have hfin : (𝓒.μProd T) {i} ≠ ∞ := (measure_ne_top _ _) + have hToReal : ((𝓒.μProd T) {i}).toReal = 𝓒.probability T i := by + simpa [measureReal_def, hfin] + using hReal + have hRewrite : + (𝓒.μProd T) {i} = ENNReal.ofReal (((𝓒.μProd T) {i}).toReal) := by + simp [ENNReal.ofReal_toReal, hfin] + rw [← hReal, ofReal_measureReal hfin] + lemma meanEnergy_of_fintype [IsFinite 𝓒] (T : Temperature) : 𝓒.meanEnergy T = ∑ i, 𝓒.energy i * 𝓒.probability T i := by simp [meanEnergy] diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine.lean new file mode 100644 index 000000000..de7a63363 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine.lean @@ -0,0 +1,764 @@ +/- +Copyright (c) 2025 Matteo Cipollina. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Matteo Cipollina +-/ + +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.TwoState +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.toCanonicalEnsemble +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.DetailedBalanceGen +import Mathlib.Probability.Kernel.Composition.Prod + +/-! ### Concrete Hopfield Energy and Fintype Instances +-/ + + + +/-! +Reintroduce (and simplify) a `Matrix.quadraticForm` helper and the update lemma +used later in the Hopfield energy flip relation proof (removed upstream). +-/ +namespace Matrix + +open scoped Classical Finset Set BigOperators + +variable {ι} [DecidableEq ι] [CommRing R] + +/-- Decomposition of an updated vector as original plus a single–site bump. 
-/
lemma update_decomp (x : ι → R) (i : ι) (v : R) :
    Function.update x i v =
      fun j => x j + (if j = i then v - x i else 0) := by
  -- Pointwise: at `j = i` both sides equal `v`; elsewhere the bump is `0`.
  funext j; by_cases hji : j = i
  · subst hji; simp
  · simp [hji, Function.update_of_ne hji, add_comm]

/-- Auxiliary single–site perturbation (Kronecker bump). -/
def singleBump (i : ι) (δ : R) : ι → R := fun j => if j = i then δ else 0

/-- Rewrite `Function.update x i v` as the original vector plus a
`singleBump` of height `v - x i` at site `i`. -/
lemma update_eq_add_bump (x : ι → R) (i : ι) (v : R) :
    Function.update x i v = (fun j => x j + singleBump i (v - x i) j) := by
  simp [singleBump, update_decomp]

variable [Fintype ι]

/-- Column-sum split: separate the i-th term from the rest (unordered finite type). -/
lemma sum_split_at
    (f : ι → R) (i : ι) :
    (∑ j, f j) = f i + ∑ j ∈ (Finset.univ.erase i), f j := by
  classical
  -- `univ` decomposes as the disjoint union `{i} ∪ univ.erase i`.
  have : (Finset.univ : Finset ι) = {i} ∪ Finset.univ.erase i := by
    ext j; by_cases hji : j = i <;> simp [hji]
  have hDisj : Disjoint ({i} : Finset ι) (Finset.univ.erase i) := by
    refine Finset.disjoint_left.mpr ?_
    intro j hj hj'
    have : j = i := by simpa using Finset.mem_singleton.mp hj
    simp [this] at hj'
  calc
    (∑ j, f j)
        = ∑ j ∈ ({i} ∪ Finset.univ.erase i), f j := by rw [← this]
    _ = (∑ j ∈ ({i} : Finset ι), f j) + ∑ j ∈ Finset.univ.erase i, f j := by
        simp [Finset.sum_union hDisj, add_comm, add_left_comm, add_assoc]
    _ = f i + ∑ j ∈ Finset.univ.erase i, f j := by simp

/-- Quadratic form xᵀ M x written via `mulVec`. -/
def quadraticForm (M : Matrix ι ι R) (x : ι → R) : R :=
  ∑ j, x j * (M.mulVec x) j

/-- Effect of a single-site bump on mulVec (only the i-th column contributes).
-/ +lemma mulVec_update_single + (M : Matrix ι ι R) (x : ι → R) (i : ι) (v : R) : + ∀ j, (M.mulVec (Function.update x i v)) j + = (M.mulVec x) j + M j i * (v - x i) := by + classical + intro j + have hUpd : Function.update x i v = fun k => if k = i then v else x k := by + funext k; by_cases hki : k = i + · subst hki; simp + · simp [Function.update_of_ne hki, hki] + unfold Matrix.mulVec dotProduct + have hSplitUpd := + sum_split_at (R:=R) (ι:=ι) (f:=fun k => M j k * (if k = i then v else x k)) i + have hSplitOrig := + sum_split_at (R:=R) (ι:=ι) (f:=fun k => M j k * x k) i + have hOthers : + (∑ k ∈ Finset.univ.erase i, M j k * (if k = i then v else x k)) + = ∑ k ∈ Finset.univ.erase i, M j k * x k := by + refine Finset.sum_congr rfl ?_ + intro k hk + rcases Finset.mem_erase.mp hk with ⟨hki, _⟩ + simp [hki] + have hLeft : + (∑ k, M j k * (Function.update x i v k)) + = M j i * v + ∑ k ∈ Finset.univ.erase i, M j k * x k := by + rw [hUpd, hSplitUpd, if_pos rfl, hOthers] + have hRightBase : + (∑ k, M j k * x k) + = M j i * x i + ∑ k ∈ Finset.univ.erase i, M j k * x k := by + simp only [hSplitOrig] + have hSplitv : M j i * v = M j i * x i + M j i * (v - x i) := by + rw [@mul_sub]; rw [@add_sub_cancel] + calc + (∑ k, M j k * (Function.update x i v k)) + = M j i * v + ∑ k ∈ Finset.univ.erase i, M j k * x k := hLeft + _ = (M j i * x i + M j i * (v - x i)) + ∑ k ∈ Finset.univ.erase i, M j k * x k := by + rw [hSplitv] + _ = (M j i * x i + ∑ k ∈ Finset.univ.erase i, M j k * x k) + M j i * (v - x i) := by + ac_rfl + _ = (∑ k, M j k * x k) + M j i * (v - x i) := by + rw [hRightBase] + +/- Raw single–site quadratic form update (no diagonal assumption). +Produces a δ-linear part plus a δ² * M i i remainder term. 
+ + Q(update x i v) - Q x + = (v - x i) * ((∑ j, x j * M j i) + (M.mulVec x) i) + + (v - x i)^2 * M i i +-/ +lemma quadraticForm_update_point + (M : Matrix ι ι R) (x : ι → R) (i : ι) (v : R) (j : ι) : + let δ : R := v - x i + (Function.update x i v) j * (M.mulVec (Function.update x i v)) j + - x j * (M.mulVec x) j + = + δ * (x j * M j i + (if j = i then (M.mulVec x) i else 0)) + + (δ * δ) * (if j = i then M j i else 0) := by + classical + intro δ + have hMv : + (M.mulVec (Function.update x i v)) j = + (M.mulVec x) j + M j i * (v - x i) := by + simpa using + (mulVec_update_single (M:=M) (x:=x) (i:=i) (v:=v) j : _) + by_cases hji : j = i + · have hUpd_i : (Function.update x i v) i = v := by simp + have hMv_i : + (M.mulVec (Function.update x i v)) i = + (M.mulVec x) i + M i i * (v - x i) := by + simpa [hji] using hMv + have hOnSite : + (v * (((M.mulVec x) i) + M i i * (v - x i)) - (x i) * ((M.mulVec x) i)) + = + (v - x i) * ((x i) * M i i + (M.mulVec x) i) + + (v - x i) * (v - x i) * M i i := by + ring + aesop + · have hUpd_off : (Function.update x i v) j = x j := by + simp [Function.update, hji] + have hIf1 : (if j = i then (M.mulVec x) i else 0) = 0 := by + simp [hji] + have hIf2 : (if j = i then M j i else 0) = 0 := by + simp [hji] + have hOffSite : + (x j) * (((M.mulVec x) j) + M j i * (v - x i)) + - (x j) * ((M.mulVec x) j) + = + (v - x i) * ((x j) * M j i) := by + ring + simpa [hMv, hUpd_off, hIf1, hIf2, δ] using hOffSite + +/-- Core raw single–site quadratic form update (separated into a standalone lemma). +Produces a δ-linear part plus a δ² * M i i remainder term. 
-/ +lemma quadraticForm_update_sum + (M : Matrix ι ι R) (x : ι → R) (i : ι) (v : R) : + quadraticForm M (Function.update x i v) - quadraticForm M x + = + (v - x i) * ( (∑ j, x j * M j i) + (M.mulVec x) i ) + + (v - x i) * (v - x i) * M i i := by + classical + set δ : R := v - x i + have hPoint := + quadraticForm_update_point (M:=M) (x:=x) (i:=i) (v:=v) + have hDiff : + quadraticForm M (Function.update x i v) - quadraticForm M x + = + ∑ j, + ((Function.update x i v) j * (M.mulVec (Function.update x i v)) j + - x j * (M.mulVec x) j) := by + unfold quadraticForm + simp [Finset.sum_sub_distrib] + have hExpand : + (∑ j, + ((Function.update x i v) j * (M.mulVec (Function.update x i v)) j + - x j * (M.mulVec x) j)) + = + ∑ j, (δ * (x j * M j i + if j = i then (M.mulVec x) i else 0) + + (δ * δ) * (if j = i then M j i else 0)) := by + refine Finset.sum_congr rfl ?_ + intro j _ + simp [hPoint, δ] + have hSplit : + (∑ j, (δ * (x j * M j i + if j = i then (M.mulVec x) i else 0) + + (δ * δ) * (if j = i then M j i else 0))) + = + (∑ j, δ * (x j * M j i + if j = i then (M.mulVec x) i else 0)) + + + (∑ j, (δ * δ) * (if j = i then M j i else 0)) := by + simp [Finset.sum_add_distrib] + have hSum_if1 : + (∑ j : ι, (if j = i then (M.mulVec x) i else 0)) + = (M.mulVec x) i := by + classical + have hfilter : (Finset.univ.filter (fun j : ι => j = i)) = {i} := by + ext j; by_cases hji : j = i <;> simp [hji] + calc + (∑ j : ι, (if j = i then (M.mulVec x) i else 0)) + = ∑ j ∈ Finset.univ.filter (fun j => j = i), (M.mulVec x) i := by + aesop + _ = (M.mulVec x) i := by + simp [hfilter] + have hSum_if2 : + (∑ j : ι, (if j = i then M j i else 0)) = M i i := by + classical + have hfilter : (Finset.univ.filter (fun j : ι => j = i)) = {i} := by + ext j; by_cases hji : j = i <;> simp [hji] + calc + (∑ j : ι, (if j = i then M j i else 0)) + = ∑ j ∈ Finset.univ.filter (fun j => j = i), M j i := by + aesop + _ = M i i := by + simp [hfilter] + have hPart1 : + (∑ j, δ * (x j * M j i + if j = i then 
(M.mulVec x) i else 0)) + = + δ * ((∑ j, x j * M j i) + (M.mulVec x) i) := by + have : + (∑ j, δ * (x j * M j i + if j = i then (M.mulVec x) i else 0)) + = δ * ∑ j, (x j * M j i + if j = i then (M.mulVec x) i else 0) := by + simp [Finset.mul_sum] + simp [this, Finset.sum_add_distrib, hSum_if1, add_comm, add_left_comm, add_assoc] + have hPart2 : + (∑ j, (δ * δ) * (if j = i then M j i else 0)) + = (δ * δ) * M i i := by + have : + (∑ j, (δ * δ) * (if j = i then M j i else 0)) + = (δ * δ) * ∑ j, (if j = i then M j i else 0) := by + simp [Finset.mul_sum] + simp [this, hSum_if2] + calc + quadraticForm M (Function.update x i v) - quadraticForm M x + = _ := hDiff + _ = _ := hExpand + _ = _ := hSplit + _ = δ * ((∑ j, x j * M j i) + (M.mulVec x) i) + + (δ * δ) * M i i := by + aesop + _ = (v - x i) * ( (∑ j, x j * M j i) + (M.mulVec x) i ) + + (v - x i) * (v - x i) * M i i := by + simp [δ, mul_comm, mul_left_comm, mul_assoc] + + +/-- Raw single–site quadratic form update (no diagonal assumption). +Old name kept; proof now delegates to `quadraticForm_update_sum`. -/ +lemma quadraticForm_update_raw + (M : Matrix ι ι R) (x : ι → R) (i : ι) (v : R) : + quadraticForm M (Function.update x i v) - quadraticForm M x + = + (v - x i) * ( (∑ j, x j * M j i) + (M.mulVec x) i ) + + (v - x i) * (v - x i) * M i i := + quadraticForm_update_sum (M:=M) (x:=x) (i:=i) (v:=v) + +/-- Version with only the i-th diagonal entry zero. 
-/ +lemma quadraticForm_update_single_index + {M : Matrix ι ι R} (i : ι) (hii : M i i = 0) + (x : ι → R) (v : R) : + quadraticForm M (Function.update x i v) - quadraticForm M x + = + (v - x i) * + ( (M.mulVec x) i + + ∑ j ∈ (Finset.univ.erase i), x j * M j i ) := by + classical + have hRaw := quadraticForm_update_raw (M:=M) (x:=x) (i:=i) (v:=v) + have hDiag0 : (v - x i) * (v - x i) * M i i = 0 := by simp [hii] + have h1 : + quadraticForm M (Function.update x i v) - quadraticForm M x + = + (v - x i) * ((∑ j, x j * M j i) + (M.mulVec x) i) := by + simpa [hDiag0, add_comm, add_left_comm, add_assoc] using hRaw + have hSplit : + (∑ j, x j * M j i) + = x i * M i i + ∑ j ∈ (Finset.univ.erase i), x j * M j i := by + have := sum_split_at (f:=fun j => x j * M j i) i + simp [add_comm, add_left_comm, add_assoc] + have hErase : + (∑ j, x j * M j i) + = ∑ j ∈ (Finset.univ.erase i), x j * M j i := by + simp_rw [hSplit, hii]; ring_nf + simp_rw [h1, hErase, add_comm] + + +/-- Original (stronger) version assuming all diagonal entries vanish (kept for backwards compatibility). -/ +lemma quadraticForm_update_single + {M : Matrix ι ι R} (hDiag : ∀ j, M j j = 0) + (x : ι → R) (i : ι) (v : R) : + quadraticForm M (Function.update x i v) - quadraticForm M x + = + (v - x i) * + ( (M.mulVec x) i + + ∑ j ∈ (Finset.univ.erase i), x j * M j i ) := + quadraticForm_update_single_index (M:=M) (x:=x) (i:=i) (v:=v) (hii:=hDiag i) +/-- +Optimized symmetric / zero–diagonal update for the quadratic form. +This is the version used in the Hopfield flip energy relation. 
+-/
+lemma quadratic_form_update_diag_zero
+    {M : Matrix ι ι R} (hSymm : M.IsSymm) (hDiag : ∀ j, M j j = 0)
+    (x : ι → R) (i : ι) (v : R) :
+    quadraticForm M (Function.update x i v) - quadraticForm M x
+      = (v - x i) * 2 * (M.mulVec x) i := by
+  classical
+  have hBase := quadraticForm_update_single (R:=R) (M:=M) hDiag x i v -- raw single-site identity
+  have hSwap : -- symmetry of `M` lets us swap the two indices in the off-diagonal sum
+      ∑ j ∈ (Finset.univ.erase i), x j * M j i
+        = ∑ j ∈ (Finset.univ.erase i), M i j * x j := by
+    refine Finset.sum_congr rfl ?_
+    intro j hj
+    simp [Matrix.IsSymm.apply hSymm, mul_comm]
+  have hMulVec : -- zero diagonal: the `i`-th `mulVec` entry equals the sum over `univ.erase i`
+      (M.mulVec x) i = ∑ j ∈ (Finset.univ.erase i), M i j * x j := by
+    unfold Matrix.mulVec dotProduct
+    classical
+    have : (Finset.univ : Finset ι) = {i} ∪ Finset.univ.erase i := by
+      ext j; by_cases hj : j = i <;> simp [hj]
+    rw [this, Finset.sum_union]; simp [Finset.disjoint_singleton_left, hDiag]
+    rw [← this]
+    simp
+  -- combine: the two equal halves of the off-diagonal contribution give the factor 2
+  simp_rw [hBase, hSwap, hMulVec]; simp [two_mul, add_comm, add_left_comm, add_assoc, mul_add,
+    mul_comm, mul_left_comm, mul_assoc]
+
+end Matrix
+
+open Finset Matrix NeuralNetwork State TwoState
+--variable [Fintype ι] [DecidableEq ι] [CommRing R]
+
+
+variable {R U σ : Type}
+variable [Field R] [LinearOrder R] [IsStrictOrderedRing R]
+-- Helper lemmas about `updPos`/`updNeg` activations, needed below for the detailed-balance argument.
+namespace TwoState
+
+variable {R U σ : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U]
+variable {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN]
+
+@[simp]
+lemma updPos_act_at_u (s : NN.State) (u : U) : -- at the updated site, `updPos` forces `σ_pos`
+    (updPos (NN:=NN) s u).act u = TwoStateNeuralNetwork.σ_pos (NN:=NN) := by
+  simp [updPos, Function.update_self]
+
+lemma updPos_act_noteq (s : NN.State) (u v : U) (h : v ≠ u) : -- off-site activations are untouched
+    (updPos (NN:=NN) s u).act v = s.act v := by
+  simp [updPos, Function.update_self h]
+  aesop
+
+@[simp]
+lemma updNeg_act_at_u (s : NN.State) (u : U) : -- at the updated site, `updNeg` forces `σ_neg`
+    (updNeg (NN:=NN) s u).act u = TwoStateNeuralNetwork.σ_neg (NN:=NN) := by
+  simp [updNeg, Function.update_self]
+
+lemma updNeg_act_noteq (s : NN.State) (u v : U) (h : v ≠ u) : -- off-site activations are untouched
+    (updNeg (NN:=NN) s u).act v = s.act v := by
+  simp [updNeg, Function.update_self h]
+  aesop
+
+-- Strict bounds `0 < logisticProb x < 1`, required for the detailed-balance probability ratios.
+lemma logisticProb_pos (x : ℝ) : 0 < logisticProb x := by
+  unfold logisticProb
+  have hden : 0 < 1 + Real.exp (-x) := by -- denominator `1 + e^{-x}` is positive
+    have hx : 0 < Real.exp (-x) := Real.exp_pos _
+    linarith
+  simpa using (one_div_pos.mpr hden)
+
+lemma logisticProb_lt_one (x : ℝ) : logisticProb x < 1 := by
+  unfold logisticProb
+  apply (div_lt_one (add_pos_of_pos_of_nonneg zero_lt_one (le_of_lt (Real.exp_pos _)))).mpr
+  simp; exact Real.exp_pos _
+
+/-- Symmetry: logisticProb (-x) = 1 - logisticProb x.
-/ +lemma logisticProb_neg (x : ℝ) : logisticProb (-x) = 1 - logisticProb x := by + unfold logisticProb + have h1 : 1 / (1 + Real.exp x) = Real.exp (-x) / (1 + Real.exp (-x)) := by + have hden : (1 + Real.exp x) ≠ 0 := + (add_pos_of_pos_of_nonneg zero_lt_one (le_of_lt (Real.exp_pos _))).ne' + calc + 1 / (1 + Real.exp x) + = (1 * Real.exp (-x)) / ((1 + Real.exp x) * Real.exp (-x)) := by + field_simp [hden] + _ = Real.exp (-x) / (Real.exp (-x) + 1) := by + simp [mul_add, add_comm, add_left_comm, add_assoc, Real.exp_neg, mul_comm] + ring_nf; rw [mul_eq_mul_left_iff]; simp + _ = Real.exp (-x) / (1 + Real.exp (-x)) := by simp [add_comm] + have h2 : Real.exp (-x) / (1 + Real.exp (-x)) = 1 - 1 / (1 + Real.exp (-x)) := by + have hden : (1 + Real.exp (-x)) ≠ 0 := + (add_pos_of_pos_of_nonneg zero_lt_one (le_of_lt (Real.exp_pos _))).ne' + field_simp [hden] + aesop + +end TwoState + +/-! +# Concrete Energy Specification for Hopfield Networks (SymmetricBinary) + +This section defines the standard Hopfield energy function and proves it satisfies +the `EnergySpec'` requirements for the `SymmetricBinary` architecture. +We leverage `Matrix.quadraticForm` for an elegant definition and proof. +-/ + +namespace HopfieldEnergy + +open Finset Matrix NeuralNetwork TwoState +open scoped Classical + +variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] +variable [Fintype U] [DecidableEq U] [Nonempty U] + +/-- +The standard Hopfield energy function (Hamiltonian) for SymmetricBinary networks. +E(s) = -1/2 * sᵀ W s + θᵀ s +-/ +noncomputable def hamiltonian + (p : Params (SymmetricBinary R U)) (s : (SymmetricBinary R U).State) : R := + let quad : R := ∑ i : U, s.act i * (p.w.mulVec s.act i) + let θ_vec := fun i : U => (p.θ i).get fin0 + (- (1/2 : R) * quad) + ∑ i : U, θ_vec i * s.act i + +/-- +Proof of the fundamental Flip Energy Relation for the SymmetricBinary network. +ΔE = E(s⁺) - E(s⁻) = -2 * Lᵤ. +This leverages Mathlib's `Matrix.quadratic_form_update_diag_zero`. 
+-/ +lemma hamiltonian_flip_relation (p : Params (SymmetricBinary R U)) (s : (SymmetricBinary R U).State) (u : U) : + let sPos := updPos (NN:=SymmetricBinary R U) s u + let sNeg := updNeg (NN:=SymmetricBinary R U) s u + let L := s.net p u - (p.θ u).get fin0 + (hamiltonian p sPos - hamiltonian p sNeg) = - (2 : R) * L := by + intro sPos sNeg L + unfold hamiltonian + let θ_vec := fun i => (p.θ i).get fin0 + + -- 1. Analyze the Quadratic Term Difference (ΔE_quad). + have h_quad_diff : + (- (1/2 : R) * Matrix.quadraticForm p.w sPos.act) - (- (1/2 : R) * Matrix.quadraticForm p.w sNeg.act) = + - (2 : R) * (p.w.mulVec s.act u) := by + + rw [← mul_sub] + -- We analyze Q(sPos) - Q(sNeg). sPos has 1 at u, sNeg has -1 at u. + + -- Express sPos as an update of sNeg. + have h_sPos_from_sNeg : sPos.act = Function.update sNeg.act u 1 := by + ext i + by_cases hi : i = u + · subst hi + -- At site u, sPos.act u is σ_pos which is definitionally 1 for SymmetricBinary. + simp_rw [sPos, sNeg, updPos, updNeg, Function.update] + aesop + · simp [sPos, sNeg, updPos, updNeg, Function.update, hi] + rw [h_sPos_from_sNeg] + -- Apply the identity for updating a quadratic form with W symmetric and W_uu=0. + -- Q(update(x, k, v)) - Q(x) = (v - x_k) * 2 * (W x)_k. + rw [Matrix.quadratic_form_update_diag_zero (p.hw'.1) (p.hw'.2)] + -- Here v=1, x=sNeg.act, k=u. sNeg.act u = -1. + have h_sNeg_u : sNeg.act u = -1 := updNeg_act_at_u s u + rw [h_sNeg_u] + -- (1 - (-1)) * 2 * (W sNeg.act)_u = 4 * (W sNeg.act)_u. + simp only [sub_neg_eq_add, one_add_one_eq_two] + ring_nf + -- Relate (W sNeg.act)_u back to s. Since W_uu=0, the activation at u doesn't matter. + have h_W_sNeg_eq_W_s : p.w.mulVec sNeg.act u = p.w.mulVec s.act u := by + unfold Matrix.mulVec dotProduct + apply Finset.sum_congr rfl + intro j _ + by_cases h_eq : j = u + · simp [h_eq, p.hw'.2 u] -- W_uu = 0 + · rw [updNeg_act_noteq s u j h_eq] + + rw [h_W_sNeg_eq_W_s] + -- 2. 
Linear term difference + have h_linear_diff : + dotProduct θ_vec sPos.act - dotProduct θ_vec sNeg.act + = (2 : R) * θ_vec u := by + rw [← dotProduct_sub] + -- Only coordinate u differs (-1 → 1), so the difference vector is 2·e_u. + have h_diff_vec : + sPos.act - sNeg.act = Pi.single u (2 : R) := by + ext v + by_cases hv : v = u + · subst hv + -- At site u: 1 - (-1) = 2 (for SymmetricBinary) + simp [sPos, sNeg, updPos, updNeg, + TwoState.SymmetricBinary, instTwoStateSymmetricBinary, + Pi.single, sub_eq_add_neg, one_add_one_eq_two] + · -- Off site: unchanged, difference 0 + simp [sPos, sNeg, updPos, updNeg, Pi.single, hv, sub_eq_add_neg] + rw [h_diff_vec, dotProduct_single] + simp [mul_comm] + + -- 3. Combine the terms. + erw [add_sub_add_comm, h_quad_diff, h_linear_diff] + -- Relate (W s.act)_u to L = net(s) - θ_u. We need to show net(s) = (W s.act)_u. + have h_net_eq_W_s : s.net p u = p.w.mulVec s.act u := by + unfold State.net SymmetricBinary fnet Matrix.mulVec dotProduct + apply Finset.sum_congr rfl + intro v _ + split_ifs with h_ne + · aesop + · -- Case v = u (since ¬ (v ≠ u)): the net integrand is 0; the mulVec term is W u u * s.act u = 0. + have hv : v = u := by + classical + by_contra hvne + exact h_ne hvne + subst hv + have hdiag : p.w v v = 0 := p.hw'.2 v + simp [hdiag] + + rw [← h_net_eq_W_s] + -- Goal: -2 * net + 2 * θ = -2 * (net - θ). + ring + +/-- The concrete Energy Specification for the SymmetricBinary Hopfield Network. -/ +noncomputable def symmetricBinaryEnergySpec : EnergySpec' (SymmetricBinary R U) where + E := hamiltonian + localField := fun p s u => s.net p u - (p.θ u).get fin0 + localField_spec := by intros; rfl + flip_energy_relation := by + intro f p s u + have h_rel := hamiltonian_flip_relation p s u + have h_scale : scale (NN:=SymmetricBinary R U) f = f 2 := scale_binary f + simp_rw [h_rel, map_mul, map_neg] + rw [h_scale] + +end HopfieldEnergy + +/-! 
+# Fintype Instance for Real-valued Binary States + +The bridge to `CanonicalEnsemble` requires `[Fintype NN.State]`. For `SymmetricBinary ℝ U`, +we must formally establish that the subtype restricted to {-1, 1} activations is finite. +-/ + +namespace SymmetricBinaryFintype +variable {U : Type} [Fintype U] [DecidableEq U] [Nonempty U] + +/-- Helper type representing the finite set {-1, 1} in ℝ. -/ +def BinarySetReal := {x : ℝ // x = 1 ∨ x = -1} + +/-- Decidable equality inherited from `ℝ` (classical). -/ +noncomputable instance : DecidableEq BinarySetReal := by + classical + infer_instance + +noncomputable instance : Fintype BinarySetReal := + Fintype.ofList + [⟨1, Or.inl rfl⟩, ⟨-1, Or.inr rfl⟩] + (by + intro x + classical + rcases x.property with h | h + · simp_rw [← h]; exact List.mem_cons_self + · simp_rw [← h]; exact List.mem_of_getLast? rfl) + +/-- Equivalence between the State space of SymmetricBinary ℝ U and (U → BinarySetReal). -/ +noncomputable def stateEquivBinarySet : (TwoState.SymmetricBinary ℝ U).State ≃ (U → BinarySetReal) where + toFun s := fun u => ⟨s.act u, s.hp u⟩ + invFun f := { + act := fun u => (f u).val, + hp := fun u => (f u).property + } + left_inv s := by ext u; simp + right_inv f := by ext u; simp + +-- The required Fintype instance. +noncomputable instance : Fintype (TwoState.SymmetricBinary ℝ U).State := + Fintype.ofEquiv (U → BinarySetReal) stateEquivBinarySet.symm + +end SymmetricBinaryFintype + +/-! +# Detailed Balance and the Boltzmann Distribution + +This section establishes that the Gibbs update kernel is reversible with respect to the +Boltzmann distribution derived from the associated Canonical Ensemble. This holds generically +for any exclusive two-state network with an EnergySpec'. 
+-/ + +namespace HopfieldBoltzmann + +open CanonicalEnsemble ProbabilityTheory TwoState PMF +open scoped Classical + +variable {U σ : Type} [Fintype U] [DecidableEq U] [Nonempty U] +variable (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [Nonempty NN.State] +variable [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] +variable (spec : TwoState.EnergySpec' (NN:=NN)) + +variable (p : Params NN) (T : Temperature) + +/-- The Canonical Ensemble obtained from params `p` (builds a local Hamiltonian instance from `spec`). -/ +noncomputable def CEparams (p : Params NN) : CanonicalEnsemble NN.State := + let _ : IsHamiltonian (U:=U) (σ:=σ) NN := + IsHamiltonian_of_EnergySpec' (NN:=NN) (spec:=spec) + hopfieldCE (U:=U) (σ:=σ) NN p + +/-- Boltzmann probability of state `s` at temperature `T`. -/ +noncomputable def P (p : Params NN) (T : Temperature) (s : NN.State) : ℝ := + (CEparams (NN:=NN) (spec:=spec) p).probability T s + +omit [Nonempty U] [Nonempty NN.State] in +@[simp] lemma energy_eq_spec (p : Params NN) (s : NN.State) : + let _ : IsHamiltonian (U:=U) (σ:=σ) NN := + IsHamiltonian_of_EnergySpec' (NN:=NN) (spec:=spec) + IsHamiltonian.energy (NN:=NN) p s = spec.E p s := by + rfl + +open scoped ENNReal Temperature Constants CanonicalEnsemble + +omit [Fintype U] [DecidableEq U] [Nonempty U] [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] in +/-- General canonical-ensemble probability ratio identity (finite state case). 
-/ +lemma CE_probability_ratio + (𝓒 : CanonicalEnsemble NN.State) [𝓒.IsFinite] (T : Temperature) + (s s' : NN.State) : + 𝓒.probability T s' / 𝓒.probability T s = + Real.exp (-(T.β : ℝ) * (𝓒.energy s' - 𝓒.energy s)) := by + classical + unfold CanonicalEnsemble.probability + set Z := 𝓒.mathematicalPartitionFunction T + have hZpos := mathematicalPartitionFunction_pos_finite (𝓒:=𝓒) (T:=T) + have hZne : Z ≠ 0 := hZpos.ne' + have hcancel : + (Real.exp (-(T.β : ℝ) * 𝓒.energy s') / Z) / + (Real.exp (-(T.β : ℝ) * 𝓒.energy s) / Z) + = + Real.exp (-(T.β : ℝ) * 𝓒.energy s') / + Real.exp (-(T.β : ℝ) * 𝓒.energy s) := by + have hc : + (Real.exp (-(T.β : ℝ) * 𝓒.energy s') * Z⁻¹) / + (Real.exp (-(T.β : ℝ) * 𝓒.energy s) * Z⁻¹) + = + Real.exp (-(T.β : ℝ) * 𝓒.energy s') / + Real.exp (-(T.β : ℝ) * 𝓒.energy s) := by + have hZinv_ne : Z⁻¹ ≠ 0 := inv_ne_zero hZne + simp; ring_nf; rw [mul_inv_cancel_right₀ hZinv_ne (Real.exp (-(↑T.β * 𝓒.energy s')))] + simpa [div_eq_mul_inv] using hc + simp [Z, hcancel] + have hratio : + Real.exp (-(T.β : ℝ) * 𝓒.energy s') / + Real.exp (-(T.β : ℝ) * 𝓒.energy s) + = + Real.exp (-(T.β : ℝ) * 𝓒.energy s' - (-(T.β : ℝ) * 𝓒.energy s)) := by + simpa using + (Real.exp_sub (-(T.β : ℝ) * 𝓒.energy s') + (-(T.β : ℝ) * 𝓒.energy s)).symm + have hexp : + -(T.β : ℝ) * 𝓒.energy s' - (-(T.β : ℝ) * 𝓒.energy s) + = -(T.β : ℝ) * (𝓒.energy s' - 𝓒.energy s) := by + ring + aesop + +omit [Nonempty U] in +/-- Ratio of Boltzmann probabilities P(s')/P(s) = exp(-β(E(s')-E(s))). 
-/ +lemma boltzmann_ratio (s s' : NN.State) : + P (NN:=NN) (spec:=spec) p T s' / P (NN:=NN) (spec:=spec) p T s = + Real.exp (-(T.β : ℝ) * (spec.E p s' - spec.E p s)) := by + classical + have _ : IsHamiltonian (U:=U) (σ:=σ) NN := + IsHamiltonian_of_EnergySpec' (NN:=NN) (spec:=spec) + set 𝓒 := CEparams (NN:=NN) (spec:=spec) p + have instFin : 𝓒.IsFinite := by + dsimp [𝓒, CEparams] -- unfolds to `hopfieldCE` + infer_instance + have h := CE_probability_ratio (NN:=NN) (𝓒:=𝓒) (T:=T) s s' + simpa [P, 𝓒, + energy_eq_spec (NN:=NN) (spec:=spec) (p:=p) (s:=s), + energy_eq_spec (NN:=NN) (spec:=spec) (p:=p) (s:=s'), + sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + +-- Define the transition probability K(s→s') in ℝ. +noncomputable def Kbm (u : U) (s s' : NN.State) : ℝ := + ((TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) s').toReal + +-- Helper lemmas to evaluate K explicitly. + +omit [Fintype U] [Nonempty U] [Fintype NN.State] [Nonempty NN.State] [TwoStateExclusive NN] in +/-- Pointwise evaluation at `updPos` (ENNReal helper). -/ +private lemma gibbsUpdate_apply_updPos + (f : ℝ →+* ℝ) (p : Params NN) (T : Temperature) (s : NN.State) (u : U) : + (gibbsUpdate (NN:=NN) f p T s u) (updPos (s:=s) (u:=u)) + = ENNReal.ofReal (probPos (NN:=NN) f p T s u) := by + classical + unfold gibbsUpdate + set pPos : ℝ := probPos (NN:=NN) f p T s u + set pPosE : ENNReal := ENNReal.ofReal pPos + have h_le_real : pPos ≤ 1 := probPos_le_one (NN:=NN) f p T s u + have h_le : pPosE ≤ 1 := by + aesop + have hne : updPos (NN:=NN) s u ≠ updNeg (NN:=NN) s u := by + intro h + have h' := congrArg (fun st => st.act u) h + simp [updPos, updNeg] at h' + exact TwoStateNeuralNetwork.h_pos_ne_neg (NN:=NN) h' + simp [pPos, pPosE, + PMF.bernoulli_bind_pure_apply_left_of_ne (α:=NN.State) h_le hne] + +omit [Fintype U] [Nonempty U] [Fintype NN.State] [Nonempty NN.State] [TwoStateExclusive NN] in +/-- Pointwise evaluation at `updNeg` (ENNReal helper). 
-/ +private lemma gibbsUpdate_apply_updNeg + (f : ℝ →+* ℝ) (p : Params NN) (T : Temperature) (s : NN.State) (u : U) : + (gibbsUpdate (NN:=NN) f p T s u) (updNeg (s:=s) (u:=u)) + = ENNReal.ofReal (1 - probPos (NN:=NN) f p T s u) := by + classical + unfold gibbsUpdate + set pPos : ℝ := probPos (NN:=NN) f p T s u + set pPosE : ENNReal := ENNReal.ofReal pPos + have h_le_real : pPos ≤ 1 := probPos_le_one (NN:=NN) f p T s u + have h_le : pPosE ≤ 1 := by + simpa [pPosE, ENNReal.ofReal_one] using h_le_real + have h_nonneg : 0 ≤ pPos := probPos_nonneg (NN:=NN) f p T s u + have hne : updPos (NN:=NN) s u ≠ updNeg (NN:=NN) s u := by + intro h + have h' := congrArg (fun st => st.act u) h + simp [updPos, updNeg] at h' + exact TwoStateNeuralNetwork.h_pos_ne_neg (NN:=NN) h' + have h_eval := + PMF.bernoulli_bind_pure_apply_right_of_ne (α:=NN.State) h_le hne + have hsub : ENNReal.ofReal (1 - pPos) = 1 - pPosE := by + simp [pPosE, ENNReal.ofReal_one] + ring_nf + rw [ENNReal.sub_eq_sInf, ENNReal.ofReal_sub 1 h_nonneg, + ENNReal.sub_eq_sInf, ENNReal.ofReal_one] + simp [pPos, pPosE, h_eval, hsub] + +omit [Fintype U] [Nonempty U] [Fintype NN.State] [Nonempty NN.State] [TwoStateExclusive NN] in +lemma Kbm_apply_updPos (u : U) (s : NN.State) : + Kbm NN p T u s (updPos (NN:=NN) s u) = TwoState.probPos (NN:=NN) (RingHom.id ℝ) p T s u := by + let f := RingHom.id ℝ + unfold Kbm; rw [gibbsUpdate_apply_updPos NN f] + exact ENNReal.toReal_ofReal (probPos_nonneg f p T s u) + +omit [Fintype U] [Nonempty U] [Fintype NN.State] [Nonempty NN.State] [TwoStateExclusive NN] in +lemma Kbm_apply_updNeg (u : U) (s : NN.State) : + Kbm NN p T u s (updNeg (NN:=NN) s u) = 1 - TwoState.probPos (NN:=NN) (RingHom.id ℝ) p T s u := by + let f := RingHom.id ℝ + unfold Kbm; rw [gibbsUpdate_apply_updNeg NN f] + have h_nonneg : 0 ≤ 1 - probPos (NN:=NN) f p T s u := sub_nonneg.mpr (probPos_le_one f p T s u) + exact ENNReal.toReal_ofReal h_nonneg + +omit [Fintype U] [Nonempty U] [Fintype NN.State] [Nonempty NN.State] 
[TwoStateExclusive NN] in +lemma Kbm_apply_other (u : U) (s s' : NN.State) + (h_pos : s' ≠ updPos (NN:=NN) s u) (h_neg : s' ≠ updNeg (NN:=NN) s u) : + Kbm NN p T u s s' = 0 := by + unfold Kbm gibbsUpdate + let f := RingHom.id ℝ + let pPosE := ENNReal.ofReal (TwoState.probPos (NN:=NN) f p T s u) + have h_le : pPosE ≤ 1 := by simp [pPosE, TwoState.probPos_le_one] + have h_K := PMF.bernoulli_bind_pure_apply_other h_le h_pos h_neg + simp [h_K] + aesop + +/-- Helper: (1 - logistic(x)) / logistic(x) = exp(-x). -/ +lemma one_sub_logistic_div_logistic (x : ℝ) : + (1 - logisticProb x) / logisticProb x = Real.exp (-x) := by + have h_pos := logisticProb_pos x + rw [div_eq_iff h_pos.ne'] + unfold logisticProb + have h_den_pos : 0 < 1 + Real.exp (-x) := by apply add_pos_of_pos_of_nonneg zero_lt_one; exact (Real.exp_pos _).le + field_simp [h_den_pos.ne'] + +end HopfieldBoltzmann diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean index b62411241..1c24084ea 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Core.lean @@ -180,11 +180,11 @@ lemma act_up_def : (s.Up wθ u).act u = @[simp] lemma act_of_non_up (huv : v2 ≠ u) : (s.Up wθ u).act v2 = s.act v2 := by - simp only [Up, if_neg huv] + simp [Up, huv] @[simp] lemma act_new_neg_one_if_net_lt_th (hn : s.net wθ u < θ' (wθ.θ u)) : (s.Up wθ u).act u = -1 := by - rw [act_up_def]; exact ite_eq_right_iff.mpr fun hyp => (hn.not_le hyp).elim + rw [act_up_def]; exact ite_eq_right_iff.mpr fun hyp => (hn.not_ge hyp).elim @[simp] lemma actnew_neg_one_if_net_lt_th (hn : s.net wθ u < θ' (wθ.θ u)) : (s.Up wθ u).act u = -1 := @@ -327,6 +327,7 @@ lemma Ew_diff' : (s.Up wθ u).Ew wθ - s.Ew wθ = rw [mul_sum, mul_sum, ← sum_neg_distrib, ← sum_add_distrib, sum_eq_zero] simp only [mem_filter, mem_univ, true_and, and_imp]; intro v2 _ hv1 hvneg2 simp_all only [Wact, 
Up, mul_ite, ite_mul, reduceIte, add_neg_cancel] + simp only [↓reduceDIte, add_neg_cancel] simp only [sub_neg_eq_add] @[simp] @@ -514,10 +515,10 @@ def stateLt (s1 s2 : State' wθ) : Prop := s1.E wθ < s2.E wθ ∨ s1.E wθ = s2 @[simp] lemma stateLt_antisym (s1 s2 : State' wθ) : stateLt s1 s2 → ¬stateLt s2 s1 := by rintro (h1 | ⟨_, h3⟩) (h2 | ⟨_, h4⟩) - · exact h1.not_lt h2 + · exact h1.not_gt h2 · simp_all only [lt_self_iff_false] · simp_all only [lt_self_iff_false] - · exact h3.not_lt h4 + · exact h3.not_gt h4 /-- Defines a partial order on states. The relation `stateOrd` holds between two states `s1` and `s2` @@ -560,8 +561,8 @@ lemma stateLt_lt (s1 s2 : State' wθ) : s1 < s2 ↔ stateLt s1 s2 := by constructor · intro hs; subst hs; have : ¬stateLt s2 s2:= fun - | Or.inl h1 => h1.not_lt h1 - | Or.inr ⟨_, h3⟩ => h3.not_lt h3 + | Or.inl h1 => h1.not_gt h1 + | Or.inr ⟨_, h3⟩ => h3.not_gt h3 exact this hs2 · intro hs; apply stateLt_antisym s1 s2 hs2 hs @@ -573,7 +574,7 @@ lemma state_act_eq (s1 s2 : State' wθ) : s1.act = s2.act → s1 = s2 := by @[simp] lemma state_Up_act (s : State' wθ) : (Up' s u).act u = s.act u → Up' s u = s := by intro h; cases' s with act hact; apply state_act_eq; ext v - by_cases huv : v = u; simp only [huv, h]; simp only [Up', Up, huv, reduceIte] + by_cases huv : v = u; simp only [huv, h]; simp [Up', Up, huv, reduceIte] @[simp] lemma up_act_eq_act_of_up_eq (s : State' wθ) : Up' s u = s → (Up' s u).act u = s.act u := fun hs => @@ -656,7 +657,7 @@ lemma num_of_states_decreases (hs : s < s') : apply Finset.card_lt_card rw [Finset.ssubset_iff_of_subset] simp only [mem_filter, mem_univ, true_and, not_lt] - use s; exact ⟨hs, gt_irrefl s⟩ + use s; exact ⟨hs, lt_irrefl s⟩ simp only [Finset.subset_iff, mem_filter, mem_univ, true_and] exact fun _ hx => hx.trans hs @@ -909,7 +910,7 @@ lemma stateisStablecondition {m : ℕ} (s : (HopfieldNetwork R U).State) (c : R) (hc : 0 < c) (hw : ∀ u, ((Hebbian ps).w).mulVec s.act u = c * s.act u) : s.isStable (Hebbian ps) := 
by intros u - unfold Up out + unfold State.Up simp only [reduceIte, Fin.isValue] rw [HNfnet_eq] simp_rw [mulVec, dotProduct] at hw u diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceBM.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceBM.lean new file mode 100644 index 000000000..418b5bae4 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceBM.lean @@ -0,0 +1,1145 @@ +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine +import Mathlib +import PhysLean.StatisticalMechanics.CanonicalEnsemble.Finite +import PhysLean.StatisticalMechanics.CanonicalEnsemble.Lemmas + +-- Provide a finite canonical ensemble instance for the Hopfield Boltzmann construction. +instance + {U σ : Type} [Fintype U] [DecidableEq U] + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [Nonempty NN.State] + [TwoStateNeuralNetwork NN] [TwoState.TwoStateExclusive NN] + (spec : TwoState.EnergySpec' (NN:=NN)) (p : Params NN) : + CanonicalEnsemble.IsFinite (HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p) := by + classical + have _ : IsHamiltonian (U:=U) (σ:=σ) NN := IsHamiltonian_of_EnergySpec' spec + dsimp [HopfieldBoltzmann.CEparams] + infer_instance + +variable [Fintype ι] [DecidableEq ι] [Ring R] --[CommRing R] +open CanonicalEnsemble Constants + +section DetailedBalance +open scoped Classical +open TwoState HopfieldBoltzmann + +variable {U σ : Type} [Fintype U] [DecidableEq U] [Nonempty U] +variable (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [Nonempty NN.State] +variable [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] +variable (spec : TwoState.EnergySpec' (NN:=NN)) +variable (p : Params NN) (T : Temperature) + +local notation "P" => P (NN:=NN) (spec:=spec) p T +local notation "K" => Kbm (NN:=NN) p T + +/-- Helper: states differ away from `u` (∃ other coordinate with different activation). 
-/ +def DiffAway (u : U) (s s' : NN.State) : Prop := + ∃ v, v ≠ u ∧ s.act v ≠ s'.act v + +omit [Fintype U] [Nonempty U] [Fintype NN.State] [Nonempty NN.State] [TwoStateExclusive NN] in +/-- If the states differ away from the update site, both transition probabilities vanish. -/ +lemma Kbm_zero_of_diffAway + {u : U} {s s' : NN.State} + (h : DiffAway (NN:=NN) u s s') : + K (u:=u) s s' = 0 ∧ K (u:=u) s' s = 0 := by + classical + rcases h with ⟨v, hv_ne, hv_diff⟩ + have h_ne_pos : s' ≠ updPos (NN:=NN) s u := by + intro h_eq + have hc := congrArg (fun st : NN.State => st.act v) h_eq + have hupd : (updPos (NN:=NN) s u).act v = s.act v := by + simp [updPos_act_noteq (NN:=NN) s u v hv_ne] + have : s'.act v = s.act v := by simpa [hupd] using hc + exact hv_diff (id (Eq.symm this)) + have h_ne_neg : s' ≠ updNeg (NN:=NN) s u := by + intro h_eq + have hc := congrArg (fun st : NN.State => st.act v) h_eq + have hupd : (updNeg (NN:=NN) s u).act v = s.act v := by + simp [updNeg_act_noteq (NN:=NN) s u v hv_ne] + have : s'.act v = s.act v := by simpa [hupd] using hc + exact hv_diff (id (Eq.symm this)) + have h_forward : Kbm (NN:=NN) p T u s s' = 0 := + Kbm_apply_other (NN:=NN) (p:=p) (T:=T) u s s' h_ne_pos h_ne_neg + have h_ne_pos' : s ≠ updPos (NN:=NN) s' u := by + intro h_eq + have hc := congrArg (fun st : NN.State => st.act v) h_eq + have hupd : (updPos (NN:=NN) s' u).act v = s'.act v := by + simp [updPos_act_noteq (NN:=NN) s' u v hv_ne] + have : s.act v = s'.act v := by simpa [hupd] using hc + exact hv_diff this + have h_ne_neg' : s ≠ updNeg (NN:=NN) s' u := by + intro h_eq + have hc := congrArg (fun st : NN.State => st.act v) h_eq + have hupd : (updNeg (NN:=NN) s' u).act v = s'.act v := by + simp [updNeg_act_noteq (NN:=NN) s' u v hv_ne] + have : s.act v = s'.act v := by simpa [hupd] using hc + exact hv_diff this + have h_backward : Kbm (NN:=NN) p T u s' s = 0 := + Kbm_apply_other (NN:=NN) (p:=p) (T:=T) u s' s h_ne_pos' h_ne_neg' + exact ⟨h_forward, h_backward⟩ + +omit [Nonempty 
U] [Nonempty NN.State] in +/-- Detailed balance holds trivially in the “diff-away” case (both transition probabilities are 0). -/ +lemma detailed_balance_diffAway + {u : U} {s s' : NN.State} + (h : DiffAway (NN:=NN) u s s') : + P s * K (u:=u) s s' = P s' * K (u:=u) s' s := by + classical + rcases Kbm_zero_of_diffAway (NN:=NN) (p:=p) (T:=T) h with ⟨h1, h2⟩ + simp [h1, h2] + +omit [Fintype U] [Nonempty U] [Fintype NN.State] [Nonempty NN.State] in +/-- Classification of the single-site difference at `u` (exclusive two-state case). -/ +lemma single_site_cases + {u : U} {s s' : NN.State} + (h_off : ∀ v ≠ u, s.act v = s'.act v) + (h_ne : s ≠ s') : + (s.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN) ∧ + s'.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) + ∨ (s.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN) ∧ + s'.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN)) := by + classical + have hx : s.act u ≠ s'.act u := by + intro hcontra + apply h_ne + ext v + by_cases hv : v = u + · simp [hv, hcontra] + · simpa [hv] using h_off v hv + rcases (TwoStateExclusive.pact_iff (NN:=NN) (a:=s.act u)).1 (s.hp u) with hs_pos | hs_neg + · rcases (TwoStateExclusive.pact_iff (NN:=NN) (a:=s'.act u)).1 (s'.hp u) with hs'_pos | hs'_neg + · exact False.elim (hx (hs_pos.trans hs'_pos.symm)) + · exact Or.inl ⟨hs_pos, hs'_neg⟩ + · rcases (TwoStateExclusive.pact_iff (NN:=NN) (a:=s'.act u)).1 (s'.hp u) with hs'_pos | hs'_neg + · exact Or.inr ⟨hs_neg, hs'_pos⟩ + · exact False.elim (hx (hs_neg.trans hs'_neg.symm)) + +/-- Convenience: `logisticProb (-x) = 1 - logisticProb x` (already available above as +`TwoState.logisticProb_neg`, re-exposed here in the local namespace for algebra lemmas). -/ +lemma logistic_neg (x : ℝ) : + logisticProb (-x) = 1 - logisticProb x := + TwoState.logisticProb_neg x + +/-- Algebraic lemma: `(1 - logisticProb x) = logisticProb (-x)`. 
-/
+lemma one_sub_logistic_eq_logistic_neg (x : ℝ) :
+    1 - logisticProb x = logisticProb (-x) := by
+  simp [logistic_neg x]
+
+/-- The denominator `1 - logisticProb x` is nonzero (since `logisticProb x < 1`). -/
+lemma one_sub_logistic_ne_zero (x : ℝ) :
+    1 - logisticProb x ≠ 0 := by
+  have hxlt : logisticProb x < 1 := TwoState.logisticProb_lt_one x
+  exact sub_ne_zero_of_ne (ne_of_lt hxlt).symm
+
+/-- Positivity: `logisticProb x > 0`. -/
+lemma logisticProb_pos' (x : ℝ) : 0 < logisticProb x :=
+  TwoState.logisticProb_pos x
+
+/-- Positivity of the complement: `0 < 1 - logisticProb x`. -/
+lemma one_sub_logistic_pos (x : ℝ) : 0 < 1 - logisticProb x := by
+  have hxlt : logisticProb x < 1 := TwoState.logisticProb_lt_one x
+  linarith
+
+lemma logistic_div_one_sub_logistic (x : ℝ) : -- odds form: `p/(1-p) = e^x` for `p = logisticProb x`
+    logisticProb x / (1 - logisticProb x) = Real.exp x := by
+  have hbase : (1 - logisticProb x) / logisticProb x = Real.exp (-x) :=
+    one_sub_logistic_div_logistic x
+  have hpos : logisticProb x ≠ 0 := (ne_of_gt (logisticProb_pos' x))
+  have hden : 1 - logisticProb x ≠ 0 := one_sub_logistic_ne_zero x
+  have hrecip : -- the target ratio is the reciprocal of the known one
+      logisticProb x / (1 - logisticProb x)
+        = ((1 - logisticProb x) / logisticProb x)⁻¹ := by
+    field_simp [hpos, hden]
+  calc
+    logisticProb x / (1 - logisticProb x)
+        = ((1 - logisticProb x) / logisticProb x)⁻¹ := hrecip
+    _ = (Real.exp (-x))⁻¹ := by simp [hbase]
+    _ = Real.exp x := by simp [Real.exp_neg]
+
+/-- Ratio identity: `logisticProb x / logisticProb (-x) = exp x`. -/
+lemma logistic_ratio_neg (x : ℝ) :
+    logisticProb x / logisticProb (-x) = Real.exp x := by
+  have hden_ne : logisticProb (-x) ≠ 0 := (ne_of_gt (logisticProb_pos' (-x)))
+  have hden : logisticProb (-x) = 1 - logisticProb x := logistic_neg x
+  calc
+    logisticProb x / logisticProb (-x)
+        = logisticProb x / (1 - logisticProb x) := by simp [hden]
+    _ = Real.exp x := logistic_div_one_sub_logistic x
+
+/-- Ratio identity in inverted orientation: `logisticProb (-x) / logisticProb x = exp (-x)`.
-/ +lemma logistic_ratio (x : ℝ) : + logisticProb (-x) / logisticProb x = Real.exp (-x) := by + simpa [neg_neg] using logistic_ratio_neg (-x) + +/-- Paired flip probabilities (general `EnergySpec'`). For a site `u`, set +`sPos := updPos s u`, `sNeg := updNeg s u` and +`Δ := f (spec.E p sPos - spec.E p sNeg)`. +Then +``` +probPos f p T s u = logisticProb (-Δ * β) +probPos f p T sNeg u = logisticProb (-Δ * β) +``` -/ +lemma TwoState.EnergySpec'.probPos_flip_pair + {U σ} [Fintype U] [DecidableEq U] + {NN : NeuralNetwork ℝ U σ} [TwoStateNeuralNetwork NN] + (spec : TwoState.EnergySpec' (NN:=NN)) + (p : Params NN) (T : Temperature) (s : NN.State) (u : U) : + let f := (RingHom.id ℝ) + let sPos := updPos (NN:=NN) s u + let sNeg := updNeg (NN:=NN) s u + let Δ := f (spec.E p sPos - spec.E p sNeg) + probPos (NN:=NN) f p T s u = logisticProb (-Δ * (T.β : ℝ)) ∧ + probPos (NN:=NN) f p T sNeg u = logisticProb (-Δ * (T.β : ℝ)) := by + classical + intro f sPos sNeg Δ + let ES : TwoState.EnergySpec (NN:=NN) := + { E := spec.E + , localField := spec.localField + , localField_spec := spec.localField_spec + , flip_energy_relation := by + intro f' p' s' u' + simpa using spec.flip_energy_relation f' p' s' u' } + have hPos : updPos (NN:=NN) s u = sPos := rfl + have hNeg : updNeg (NN:=NN) s u = sNeg := rfl + have h₁ : probPos (NN:=NN) f p T s u = + logisticProb (-Δ * (T.β : ℝ)) := by + have h := ES.probPos_eq_of_energy f p T s u + dsimp [ES] at h + simpa [hPos, hNeg, Δ, sub_eq_add_neg, + mul_comm, mul_left_comm, mul_assoc] using h + have hPos' : updPos (NN:=NN) sNeg u = sPos := by + ext v + by_cases hv : v = u + · subst hv; simp [sPos, sNeg, updPos, updNeg] + · simp [sPos, sNeg, updPos, updNeg, hv] + have hNeg' : updNeg (NN:=NN) sNeg u = sNeg := by + ext v + by_cases hv : v = u + · subst hv; simp [sNeg, updNeg] + · simp [sNeg, updNeg, hv] + have h₂ : probPos (NN:=NN) f p T sNeg u = + logisticProb (-Δ * (T.β : ℝ)) := by + have h := ES.probPos_eq_of_energy f p T sNeg u + dsimp [ES] at h + 
simpa [hPos', hNeg', Δ, sub_eq_add_neg, + mul_comm, mul_left_comm, mul_assoc] using h + exact ⟨h₁, h₂⟩ + +/-- Specialization of the previous pair lemma to the “neg→pos” orientation used +in `detailed_balance_neg_pos`. Here `ΔE = E s' - E s` with `s' = updPos s u` +and `s = updNeg s' u` (i.e. `s` carries σ_neg at `u`, `s'` carries σ_pos). -/ +lemma flip_prob_neg_pos + {U σ} [Fintype U] [DecidableEq U] + {NN : NeuralNetwork ℝ U σ} [TwoStateNeuralNetwork NN] + (spec : TwoState.EnergySpec' (NN:=NN)) + (p : Params NN) (T : Temperature) + {s s' : NN.State} {u : U} + (h_off : ∀ v ≠ u, s.act v = s'.act v) + (h_neg : s.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) + (h_pos : s'.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN)) : + let ΔE := spec.E p s' - spec.E p s + probPos (NN:=NN) (RingHom.id ℝ) p T s u = logisticProb (-(T.β : ℝ) * ΔE) ∧ + probPos (NN:=NN) (RingHom.id ℝ) p T s' u = logisticProb (-(T.β : ℝ) * ΔE) := by + classical + intro ΔE + have h_sPos : updPos (NN:=NN) s u = s' := by + ext v; by_cases hv : v = u + · subst hv; simp [updPos_act_at_u, h_pos] + · simp [updPos_act_noteq (NN:=NN) s u v hv, h_off v hv] + have h_sNeg : updNeg (NN:=NN) s u = s := by + ext v; by_cases hv : v = u + · subst hv; simp [updNeg_act_at_u, h_neg] + · simp [updNeg_act_noteq (NN:=NN) s u v hv] + obtain ⟨h_prob_s, _⟩ := + (TwoState.EnergySpec'.probPos_flip_pair (NN:=NN) spec p T s u) + have hΔ₁ : + (RingHom.id ℝ) + (spec.E p (updPos (NN:=NN) s u) - spec.E p (updNeg (NN:=NN) s u)) + = ΔE := by + simp [ΔE, h_sPos, h_sNeg] + have h1 : + probPos (NN:=NN) (RingHom.id ℝ) p T s u + = logisticProb (-(T.β : ℝ) * ΔE) := by + rw [h_prob_s, hΔ₁] + ring_nf + have h_s'Pos : updPos (NN:=NN) s' u = s' := by + ext v; by_cases hv : v = u + · subst hv; simp [updPos_act_at_u, h_pos] + · simp [updPos_act_noteq (NN:=NN) s' u v hv] + have h_s'Neg : updNeg (NN:=NN) s' u = s := by + ext v; by_cases hv : v = u + · subst hv; simp [updNeg_act_at_u, h_neg] + · simp [updNeg_act_noteq (NN:=NN) s' u v hv, (h_off v hv).symm] 
+ obtain ⟨_, h_prob_s'⟩ := + (TwoState.EnergySpec'.probPos_flip_pair (NN:=NN) spec p T s' u) + have hΔ₂ : + (RingHom.id ℝ) + (spec.E p (updPos (NN:=NN) s' u) - spec.E p (updNeg (NN:=NN) s' u)) + = ΔE := by + simp [ΔE, h_s'Pos, h_s'Neg, neg_sub] + have h2 : + probPos (NN:=NN) (RingHom.id ℝ) p T s' u + = logisticProb (-(T.β : ℝ) * ΔE) := by + simp [h_prob_s', hΔ₂] + ring_nf; aesop + exact ⟨h1, h2⟩ + +/-- Clean algebraic lemma: + if + • `Pfun s' / Pfun s = exp (-β ΔE)` and + • `Kfun s s' / Kfun s' s = exp ( β ΔE)` + then detailed balance holds: `Pfun s * Kfun s s' = Pfun s' * Kfun s' s`. -/ +lemma detailed_balance_from_opposite_ratios + {α : Type} + {Pfun : α → ℝ} -- one-argument (probability) function + {Kfun : α → α → ℝ} -- two-argument (kernel) function + {s s' : α} {β ΔE : ℝ} + (hP : Pfun s' / Pfun s = Real.exp (-β * ΔE)) + (hK : Kfun s s' / Kfun s' s = Real.exp (-β * ΔE)) + (hPpos : 0 < Pfun s) (hKpos : 0 < Kfun s' s) : + Pfun s * Kfun s s' = Pfun s' * Kfun s' s := by + have hPne : Pfun s ≠ 0 := (ne_of_gt hPpos) + have hKne : Kfun s' s ≠ 0 := (ne_of_gt hKpos) + have hEqRat : Pfun s' / Pfun s = Kfun s s' / Kfun s' s := by + simp [hP, hK] + have hP' : Pfun s' = (Kfun s s' / Kfun s' s) * Pfun s := by + have := hEqRat + exact (div_eq_iff hPne).1 this + have hFinal : + Pfun s' * Kfun s' s = Kfun s s' * Pfun s := by + have hK' : Kfun s' s ≠ 0 := hKne + calc + Pfun s' * Kfun s' s + = (Kfun s s' / Kfun s' s * Pfun s) * Kfun s' s := by simp [hP'] + _ = (Kfun s s' * (Kfun s' s)⁻¹ * Pfun s) * Kfun s' s := by + simp [div_eq_mul_inv, mul_comm, mul_left_comm, mul_assoc] + _ = Kfun s s' * Pfun s * ((Kfun s' s)⁻¹ * Kfun s' s) := by + ring_nf + _ = Kfun s s' * Pfun s := by + simp [hK'] + simpa [mul_comm, mul_left_comm, mul_assoc] using hFinal.symm + +omit [Nonempty U] in +lemma detailed_balance_neg_pos + {u : U} {s s' : NN.State} + (h_off : ∀ v ≠ u, s.act v = s'.act v) + (h_neg : s.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) + (h_pos : s'.act u = TwoStateNeuralNetwork.σ_pos 
(NN:=NN)) : + P s * K (u:=u) s s' = P s' * K (u:=u) s' s := by + classical + have h_updPos : s' = updPos (NN:=NN) s u := by + ext v; by_cases hv : v = u + · subst hv; simp [updPos_act_at_u, h_pos] + · simp [updPos_act_noteq (NN:=NN) s u v hv, h_off v hv] + have h_updNeg : s = updNeg (NN:=NN) s' u := by + ext v; by_cases hv : v = u + · subst hv; simp [updNeg_act_at_u, h_neg] + · have := h_off v hv; simp [updNeg_act_noteq (NN:=NN) s' u v hv, this.symm] + have hK_fwd : + K (u:=u) s s' = probPos (RingHom.id ℝ) p T s u := by + subst h_updPos + simpa [Kbm] using + (Kbm_apply_updPos (NN:=NN) (p:=p) (T:=T) u s) + have hK_bwd : + K (u:=u) s' s = 1 - probPos (RingHom.id ℝ) p T s' u := by + subst h_updNeg + simpa [Kbm] using + (Kbm_apply_updNeg (NN:=NN) (p:=p) (T:=T) u s') + let ΔE := spec.E p s' - spec.E p s + obtain ⟨hProb_fwd, hProb_bwd⟩ := + flip_prob_neg_pos (NN:=NN) (spec:=spec) p T + (s:=s) (s':=s') (u:=u) h_off h_neg h_pos + have hKf : + K (u:=u) s s' = logisticProb (-(T.β : ℝ) * ΔE) := by + simp [hK_fwd, hProb_fwd, ΔE, mul_comm, mul_left_comm, mul_assoc] + have hKb : + K (u:=u) s' s = logisticProb ((T.β : ℝ) * ΔE) := by + have hbwdprob : + probPos (RingHom.id ℝ) p T s' u + = logisticProb (-(T.β : ℝ) * ΔE) := by + simpa [ΔE, mul_comm, mul_left_comm, mul_assoc] using hProb_bwd + have hneg := logistic_neg ((T.β : ℝ) * ΔE) + have : 1 - probPos (RingHom.id ℝ) p T s' u + = logisticProb ((T.β : ℝ) * ΔE) := by + have hx : + probPos (RingHom.id ℝ) p T s' u + = 1 - logisticProb ((T.β : ℝ) * ΔE) := by + simpa [hneg] using hbwdprob + simp [hx] + simp [hK_bwd, this] + have hPratio := + boltzmann_ratio (NN:=NN) (spec:=spec) (p:=p) (T:=T) s s' + have hPratio' : + P s' / P s = Real.exp (-(T.β : ℝ) * ΔE) := by + simpa [ΔE, sub_eq_add_neg, add_comm, add_left_comm, add_assoc] + using hPratio + have hKratio : + K (u:=u) s s' / K (u:=u) s' s + = Real.exp (-(T.β : ℝ) * ΔE) := by + have := logistic_ratio ((T.β : ℝ) * ΔE) + simpa [hKf, hKb] using this + have hKpos : 0 < K (u:=u) s' s := by 
+ simp [hKb, logisticProb_pos'] + have hPpos : 0 < P s := by + classical + have _ : IsHamiltonian (U:=U) (σ:=σ) NN := + IsHamiltonian_of_EnergySpec' (NN:=NN) (spec:=spec) + set 𝓒 := CEparams (NN:=NN) (spec:=spec) p + have instFin : 𝓒.IsFinite := by + dsimp [𝓒, CEparams]; infer_instance + unfold HopfieldBoltzmann.P + simp only [HopfieldBoltzmann.CEparams] + unfold CanonicalEnsemble.probability + set Z := 𝓒.mathematicalPartitionFunction T + have hZpos := mathematicalPartitionFunction_pos_finite (𝓒:=𝓒) (T:=T) + have hExpPos : 0 < Real.exp (-(T.β : ℝ) * 𝓒.energy s) := Real.exp_pos _ + have : 0 < Real.exp (-(T.β : ℝ) * 𝓒.energy s) / Z := by + exact div_pos hExpPos hZpos + simpa [Z] using this + exact detailed_balance_from_opposite_ratios + (Pfun:=P) (Kfun:=fun a b => K (u:=u) a b) + (s:=s) (s':=s') (β:=T.β) (ΔE:=ΔE) + hPratio' hKratio hPpos hKpos + +omit [Nonempty U] in +/-- Symmetric orientation (pos→neg) obtained from `detailed_balance_neg_pos` by swapping `s,s'`. -/ +lemma detailed_balance_pos_neg + {u : U} {s s' : NN.State} + (h_off : ∀ v ≠ u, s.act v = s'.act v) + (h_pos : s.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN)) + (h_neg : s'.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) : + P s * K (u:=u) s s' = P s' * K (u:=u) s' s := by + classical + have hswap := + detailed_balance_neg_pos (NN:=NN) (spec:=spec) (p:=p) (T:=T) + (u:=u) (s:=s') (s':=s) + (h_off:=by + intro v hv + simp [h_off v hv]) + (h_neg:=h_neg) (h_pos:=h_pos) + simpa [mul_comm, mul_left_comm, mul_assoc] using hswap.symm + +omit [Nonempty U] in +/-- +**Theorem: Detailed Balance Condition (Reversibility)**. +The Gibbs update kernel satisfies the detailed balance condition with respect to the Boltzmann distribution. +P(s) K(s→s') = P(s') K(s'→s). 
+-/ +theorem detailed_balance + (u : U) (s s' : NN.State) : + P s * K (u:=u) s s' + = P s' * K (u:=u) s' s := by + classical + by_cases hDiff : DiffAway (NN:=NN) u s s' + · exact detailed_balance_diffAway (NN:=NN) (spec:=spec) (p:=p) (T:=T) hDiff + have h_off : ∀ v ≠ u, s.act v = s'.act v := by + intro v hv + by_contra H + exact hDiff ⟨v, hv, H⟩ + by_cases hEq : s = s' + · subst hEq; simp + have hClass := + single_site_cases (NN:=NN) (u:=u) (s:=s) (s':=s') h_off hEq + rcases hClass with hCase | hCase + · rcases hCase with ⟨hpos, hneg⟩ + exact detailed_balance_pos_neg (NN:=NN) (spec:=spec) (p:=p) (T:=T) + (u:=u) (s:=s) (s':=s') h_off hpos hneg + · rcases hCase with ⟨hneg, hpos⟩ + exact detailed_balance_neg_pos (NN:=NN) (spec:=spec) (p:=p) (T:=T) + (u:=u) (s:=s) (s':=s') h_off hneg hpos + + +end DetailedBalance + +variable [Fintype ι] [DecidableEq ι] [Ring R] --[CommRing R] +open CanonicalEnsemble Constants + +section DetailedBalance +open scoped Classical ENNReal Temperature Constants +open TwoState Temperature HopfieldBoltzmann ProbabilityTheory + +variable {U σ : Type} [Fintype U] [DecidableEq U] [Nonempty U] +variable (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [Nonempty NN.State] +variable [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] +variable (spec : TwoState.EnergySpec' (NN:=NN)) +variable (p : Params NN) (T : Temperature) + +/-- Lift a family of PMFs to a Markov kernel on a finite (hence countable) state space. +We reuse `Kernel.ofFunOfCountable`, which supplies the measurability proof. -/ +noncomputable def pmfToKernel + {α : Type*} [Fintype α] [DecidableEq α] + [MeasurableSpace α] [MeasurableSingletonClass α] (K : α → PMF α) : + Kernel α α := + Kernel.ofFunOfCountable (fun a => (K a).toMeasure) + +/-- Single–site Gibbs kernel at site `u` as a Kernel (uses existing `gibbsUpdate`). +`spec` is not needed here, so we underscore it to silence the unused-variable linter. 
-/ +noncomputable def singleSiteKernel + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [DecidableEq U] + [MeasurableSpace NN.State] [MeasurableSingletonClass NN.State] + [TwoStateNeuralNetwork NN] + (_spec : TwoState.EnergySpec' (NN:=NN)) (p : Params NN) (T : Temperature) (u : U) : + Kernel NN.State NN.State := + pmfToKernel (fun s => TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) + +/-- Random–scan Gibbs kernel as uniform mixture over sites. +`spec` is likewise unused in the construction of the kernel itself. -/ +noncomputable def randomScanKernel + (NN : NeuralNetwork ℝ U σ) [Fintype U] [DecidableEq U] [Nonempty U] + [Fintype NN.State] [DecidableEq NN.State] [MeasurableSpace NN.State] [MeasurableSingletonClass NN.State] + [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] + (_spec : TwoState.EnergySpec' (NN:=NN)) (p : Params NN) (T : Temperature) : + Kernel NN.State NN.State := + let sitePMF : PMF U := PMF.uniformOfFintype _ + pmfToKernel (fun s => + sitePMF.bind (fun u => + TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u)) + +section FiniteMeasureAPI +open scoped Classical +open MeasureTheory + +variable {α : Type*} + +/-- On a finite discrete measurable space (⊤ σ–algebra), every set is measurable. -/ +@[simp] lemma measurableSet_univ_of_fintype + [Fintype α] [MeasurableSpace α] (hσ : ‹MeasurableSpace α› = ⊤) + (s : Set α) : MeasurableSet s := by + subst hσ; trivial + +/-- For a finite type with counting measure, the (lower) integral +is the finite sum (specialization of the `tsum` version). + +FIX: Added `[MeasurableSingletonClass α]` which is required by `MeasureTheory.lintegral_count`. +NOTE(review): the auxiliary restricted lemmas named below are in fact present later in this +section (`lintegral_count_restrict`, `lintegral_fintype_prob_restrict`, `lintegral_restrict_as_sum_if`), +so the earlier note claiming they were removed was stale and has been corrected here.
-/ +lemma lintegral_count_fintype + [MeasurableSpace α] [MeasurableSingletonClass α] + [Fintype α] [DecidableEq α] + (f : α → ℝ≥0∞) : + ∫⁻ x, f x ∂(Measure.count : Measure α) = ∑ x : α, f x := by + classical + simpa [tsum_fintype] using (MeasureTheory.lintegral_count f) + +-- Finite-type restricted lintegral as a weighted finite sum (separated lemma). +lemma lintegral_fintype_measure_restrict + {α : Type*} + [Fintype α] [DecidableEq α] + [MeasurableSpace α] [MeasurableSingletonClass α] + (μ : Measure α) (A : Set α) --(hA : MeasurableSet A) + (f : α → ℝ≥0∞) : + ∫⁻ x in A, f x ∂μ + = ∑ x : α, (if x ∈ A then μ {x} * f x else 0) := by + classical + have hRestr : + ∫⁻ x in A, f x ∂μ + = ∑ x : α, f x * (μ.restrict A) {x} := by + simpa using (lintegral_fintype (μ:=μ.restrict A) (f:=f)) + have hSingleton : + ∀ x : α, (μ.restrict A) {x} = (if x ∈ A then μ {x} else 0) := by + intro x + by_cases hx : x ∈ A + · have hInter : ({x} : Set α) ∩ A = {x} := by + ext y; constructor + · intro hy; rcases hy with ⟨hy1, hy2⟩ + simp at hy1 + subst hy1 + simp [hx] + · intro hy + simp [hy, hx] + aesop + simp [Measure.restrict_apply, hx, hInter] + · have hInter : ({x} : Set α) ∩ A = (∅ : Set α) := by + apply Set.eq_empty_iff_forall_notMem.2 + intro y hy + rcases hy with ⟨hy1, hy2⟩ + have : y = x := by simpa [Set.mem_singleton_iff] using hy1 + subst this + exact hx hy2 + simp [Measure.restrict_apply, hx, hInter] + calc + ∫⁻ x in A, f x ∂μ + = ∑ x : α, f x * (μ.restrict A) {x} := hRestr + _ = ∑ x : α, f x * (if x ∈ A then μ {x} else 0) := by + simp [hSingleton] + _ = ∑ x : α, (if x ∈ A then μ {x} * f x else 0) := by + refine Finset.sum_congr rfl ?_ + intro x _ + by_cases hx : x ∈ A + · simp [hx, mul_comm] + · simp [hx] + +/-- Probability measure style formula for a finite type: +turn a restricted integral into a finite sum with point masses. 
-/ +lemma lintegral_fintype_prob_restrict + [Fintype α] [DecidableEq α] [MeasurableSpace α] [MeasurableSingletonClass α] + (μ : Measure α) [IsFiniteMeasure μ] + (A : Set α) (f : α → ℝ≥0∞) : + ∫⁻ x in A, f x ∂μ + = ∑ x : α, (if x ∈ A then μ {x} * f x else 0) := by + simpa using lintegral_fintype_measure_restrict μ A f + +/-- Restricted version over the counting measure (finite type). +Uses the probability-style formula specialized to `Measure.count`. -/ +lemma lintegral_count_restrict + [MeasurableSpace α] [MeasurableSingletonClass α] [Fintype α] [DecidableEq α] + (A : Set α) (f : α → ℝ≥0∞) : + ∫⁻ x in A, f x ∂(Measure.count : Measure α) + = ∑ x : α, (if x ∈ A then f x else 0) := by + classical + have h := + lintegral_fintype_prob_restrict (μ:=(Measure.count : Measure α)) A f + have hμ : ∀ x : α, (Measure.count : Measure α) {x} = 1 := by + intro x; simp + simpa [hμ, one_mul] using h + +/-- Convenience rewriting for the specific pattern used in detailed balance proofs: +move `μ {x}` factor to the left of function argument. -/ +lemma lintegral_restrict_as_sum_if + [Fintype α] [DecidableEq α] [MeasurableSpace α] [MeasurableSingletonClass α] + (μ : Measure α) (A : Set α) + (g : α → ℝ≥0∞) : + ∫⁻ x in A, g x ∂μ + = ∑ x : α, (if x ∈ A then μ {x} * g x else 0) := + lintegral_fintype_measure_restrict μ A g + +end FiniteMeasureAPI + +open MeasureTheory Set Finset Kernel TwoState HopfieldBoltzmann + +variable {α β : Type*} [MeasurableSpace α] [MeasurableSpace β] + +/-- (Helper) Every subset of a finite type is finite. -/ +lemma Set.finite_of_subsingleton_fintype + {γ : Type*} [Fintype γ] (S : Set γ) : S.Finite := + (Set.toFinite _) + +namespace ProbabilityTheory +namespace Kernel + +/-- Evaluation lemma for `Kernel.ofFunOfCountable`. Added for convenient rewriting/simp. 
-/ +@[simp] +lemma ofFunOfCountable_apply + {α β : Type*} [MeasurableSpace α] [MeasurableSpace β] + [Countable α] [MeasurableSingletonClass α] + (f : α → Measure β) (a : α) : + (Kernel.ofFunOfCountable f) a = f a := rfl + +end Kernel +end ProbabilityTheory + +namespace ProbabilityTheory + +open scoped Classical +open MeasureTheory + +variable {U σ : Type} +variable {NN : NeuralNetwork ℝ U σ} [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] +variable {spec : TwoState.EnergySpec' (NN:=NN)} +variable {p : Params NN} {T : Temperature} + +section AuxFiniteSum + +/-- General finite-type identity: +a sum over the whole type with an `if … ∈ S` guard can be rewritten +as a sum over the `Finset` of the elements that satisfy the guard. -/ +lemma Finset.sum_if_mem_eq_sum_filter + {α β : Type*} [Fintype α] [DecidableEq α] [AddCommMonoid β] + (S : Set α) (f : α → β) : + (∑ x : α, (if x ∈ S then f x else 0)) + = ∑ x ∈ S.toFinset, f x := by + classical + have h_univ : + (∑ x : α, (if x ∈ S then f x else 0)) + = ∑ x ∈ (Finset.univ : Finset α), (if x ∈ S then f x else 0) := by + simp + have h_filter : + (∑ x ∈ (Finset.univ : Finset α), (if x ∈ S then f x else 0)) + = ∑ x ∈ (Finset.univ.filter fun x : α => x ∈ S), f x := by + simpa using + (Finset.sum_filter + (s := (Finset.univ : Finset α)) + (p := fun x : α => x ∈ S) + (f := f)).symm + have h_ident : + (Finset.univ.filter fun x : α => x ∈ S) = S.toFinset := by + ext x + by_cases hx : x ∈ S + · simp [hx, Finset.mem_filter, Set.mem_toFinset] + · simp [hx, Finset.mem_filter, Set.mem_toFinset] + simp [h_univ, h_filter, h_ident] + +lemma Finset.sum_subset_of_subset + {α β : Type*} [Fintype α] [DecidableEq α] [AddCommMonoid β] + (S : Set α) (f : α → β) + (_h₁ : ∀ x, x ∈ S.toFinset → True) + (_h₂ : ∀ x, x ∈ S.toFinset → False → False) + (_h₃ : ∀ x, x ∈ S.toFinset → True) : + (∑ x : α, (if x ∈ S then f x else 0)) + = ∑ x ∈ S.toFinset, f x := + Finset.sum_if_mem_eq_sum_filter S f + +end AuxFiniteSum + +/-- Uniform random-scan kernel 
evaluation: +the kernel probability of a measurable set `B` equals the arithmetic +average of the single-site kernel probabilities. -/ +lemma randomScanKernel_eval_uniform + {U σ : Type} [Fintype U] [DecidableEq U] [Nonempty U] + {NN : NeuralNetwork ℝ U σ} + [Fintype NN.State] [DecidableEq NN.State] + [MeasurableSpace NN.State] [MeasurableSingletonClass NN.State] + [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] + (spec : TwoState.EnergySpec' (NN:=NN)) + (p : Params NN) (T : Temperature) + (x : NN.State) (B : Set NN.State) (_ : MeasurableSet B) : + (randomScanKernel (NN:=NN) spec p T) x B + = + (∑ u : U, (singleSiteKernel (NN:=NN) spec p T u) x B) +/ (Fintype.card U : ℝ≥0∞) := by + classical + unfold randomScanKernel singleSiteKernel pmfToKernel + simp [Kernel.ofFunOfCountable, Kernel.ofFunOfCountable_apply] + let sitePMF := PMF.uniformOfFintype U + let g : U → PMF NN.State := + fun u ↦ TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T x u + have hConst : + (sitePMF.bind g).toMeasure B + = + (∑ u : U, (g u).toMeasure B) +/ (Fintype.card U : ℝ≥0∞) := by + classical + have hμ : ∀ u : U, sitePMF u = (Fintype.card U : ℝ≥0∞)⁻¹ := by + intro u; simp [sitePMF, PMF.uniformOfFintype_apply] + simp [PMF.toMeasure_apply, tsum_fintype, PMF.bind_apply, hμ, + Finset.mul_sum, Finset.sum_mul, Finset.sum_comm, + ENNReal.div_eq_inv_mul, Set.indicator_apply, + mul_comm, mul_left_comm, mul_assoc] at * + aesop + +end ProbabilityTheory + +/-- On a finite (any finite subset) space with measurable singletons, the measure of a finite +set under a kernel is the finite sum of the singleton masses. (Refactored: Finset induction; +avoids problematic `hB.induction_on` elaboration.) 
-/ +lemma Kernel.measure_eq_sum_finset + [DecidableEq α] [MeasurableSingletonClass α] + (κ : Kernel β α) (x : β) {B : Set α} (hB : B.Finite) : + κ x B = ∑ y ∈ hB.toFinset, κ x {y} := by + classical + have hBset : B = (hB.toFinset : Finset α).toSet := by + ext a; aesop + set s : Finset α := hB.toFinset + suffices H : κ x s.toSet = ∑ y ∈ s, κ x {y} by aesop + refine s.induction_on ?h0 ?hstep + · simp + · intro a s ha_notin hIH + have hDisj : Disjoint ({a} : Set α) s.toSet := by + refine disjoint_left.mpr ?_ + intro y hy_in hy_in_s + have : y = a := by simpa using hy_in + subst this + aesop + have hMeas_s : MeasurableSet s.toSet := by + refine s.induction_on ?m0 ?mstep + · simp + · intro b t hb_notin ht + simpa [Finset.coe_insert, Set.image_eq_range, Set.union_comm, Set.union_left_comm, + Set.union_assoc] using (ht.union (measurableSet_singleton b)) + have hMeas_a : MeasurableSet ({a} : Set α) := measurableSet_singleton a + have hUnion : + (insert a s).toSet + = ({a} : Set α) ∪ s.toSet := by + ext y; by_cases hy : y = a + · subst hy; simp + · simp [hy] + have hAdd : + κ x ((insert a s).toSet) + = κ x ({a} : Set α) + κ x s.toSet := by + rw [← measure_union_add_inter {a} hMeas_s] + simp_rw [hUnion, measure_union_add_inter {a} hMeas_s] + exact measure_union hDisj hMeas_s + have hSum : + ∑ y ∈ insert a s, κ x {y} + = κ x ({a} : Set α) + ∑ y ∈ s, κ x {y} := by + simp [Finset.sum_insert, ha_notin] + calc + κ x ((insert a s).toSet) + = κ x ({a} : Set α) + κ x s.toSet := hAdd + _ = κ x ({a} : Set α) + ∑ y ∈ s, κ x {y} := by rw [hIH] + _ = ∑ y ∈ insert a s, κ x {y} := by simp [hSum] + +omit [Fintype U] [DecidableEq U] [Nonempty U] in +lemma lintegral_randomScanKernel_as_sum_div + (NN : NeuralNetwork ℝ U σ) [Fintype U] [DecidableEq U] [Nonempty U] + [Fintype NN.State] [DecidableEq NN.State] + [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] + (spec : TwoState.EnergySpec' (NN:=NN)) + (p : Params NN) (T : Temperature) + (π : Measure (NN.State)) + (A B : Set NN.State) (hA : 
MeasurableSet A) (hB : MeasurableSet B) : + ∫⁻ x in A, (randomScanKernel (NN:=NN) spec p T) x B ∂π + = + (∑ u : U, + ∫⁻ x in A, (singleSiteKernel (NN:=NN) spec p T u) x B ∂π) +/ (Fintype.card U : ℝ≥0∞) := by + classical + letI : MeasurableSpace NN.State := ⊤ + letI : MeasurableSingletonClass NN.State := ⟨fun _ => trivial⟩ + set κ := randomScanKernel (NN:=NN) spec p T + set κu := fun u : U => singleSiteKernel (NN:=NN) spec p T u + set c : ℝ≥0∞ := (Fintype.card U : ℝ≥0∞)⁻¹ with hc + have h_div : (↑(Fintype.card U) : ℝ≥0∞) ≠ 0 := by + exact_mod_cast (Fintype.card_ne_zero : Fintype.card U ≠ 0) + have hEval : + ∀ x, κ x B = c * ∑ u : U, (κu u) x B := by + intro x + have hx := + randomScanKernel_eval_uniform (NN:=NN) (spec:=spec) p T x B hB + simp [κ, κu, c, ENNReal.div_eq_inv_mul, hx, mul_comm, mul_left_comm, mul_assoc] + have hLHS : + ∫⁻ x in A, κ x B ∂π + = c * ∑ u : U, ∫⁻ x in A, (κu u) x B ∂π := by + have hEval' : + (fun x => κ x B) = + fun x => c * ∑ u : U, (κu u) x B := by + funext x; simp [hEval x] + calc + ∫⁻ x in A, κ x B ∂π + = ∫⁻ x in A, c * (∑ u : U, (κu u) x B) ∂π := by + simp [hEval', mul_comm, mul_left_comm, mul_assoc] + _ = c * ∫⁻ x in A, (∑ u : U, (κu u) x B) ∂π := by + erw [← lintegral_const_mul c fun ⦃t⦄ a => _] + exact fun ⦃t⦄ a => hA + _ = c * ∑ u : U, ∫⁻ x in A, (κu u) x B ∂π := by + have : + ∫⁻ x in A, (∑ u : U, (κu u) x B) ∂π + = ∑ u : U, ∫⁻ x in A, (κu u) x B ∂π := by + erw [lintegral_finset_sum Finset.univ fun b a ⦃t⦄ a => _] + exact fun b a ⦃t⦄ a => hA + simpa using congrArg (fun z => c * z) this + have hRHS : + (∑ u : U, ∫⁻ x in A, (κu u) x B ∂π) +/ (Fintype.card U : ℝ≥0∞) + = c * ∑ u : U, ∫⁻ x in A, (κu u) x B ∂π := by + rw [ENNReal.div_eq_inv_mul] + aesop + +omit [Fintype U] [DecidableEq U] [Nonempty U] in +/-- Averaging lemma: uniform average of reversible single–site kernels is reversible. 
-/ +lemma randomScanKernel_reversible_of_sites + (NN : NeuralNetwork ℝ U σ) [Fintype U] [DecidableEq U] [Nonempty U] + [Fintype NN.State] [DecidableEq NN.State] + [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] + (spec : TwoState.EnergySpec' (NN:=NN)) + (p : Params NN) (T : Temperature) + (π : Measure (NN.State)) + (hSite : + ∀ u, ProbabilityTheory.Kernel.IsReversible + (singleSiteKernel (NN:=NN) spec p T u) π) : + ProbabilityTheory.Kernel.IsReversible + (randomScanKernel (NN:=NN) spec p T) π := by + classical + letI : MeasurableSpace NN.State := ⊤ + letI : MeasurableSingletonClass NN.State := ⟨fun _ => trivial⟩ + intro A B hA hB + have hSum : + (∑ u : U, + ∫⁻ x in A, (singleSiteKernel (NN:=NN) spec p T u) x B ∂π) + = + (∑ u : U, + ∫⁻ x in B, (singleSiteKernel (NN:=NN) spec p T u) x A ∂π) := by + refine Finset.sum_congr rfl ?_ + intro u _; exact hSite u hA hA + have hAexpr := + lintegral_randomScanKernel_as_sum_div (NN:=NN) (spec:=spec) p T π A B hA hB + have hBexpr := + lintegral_randomScanKernel_as_sum_div (NN:=NN) (spec:=spec) p T π B A hB hA + simp [hAexpr, hBexpr, hSum] + +section ReversibilityFinite + +open scoped Classical +open MeasureTheory + +variable {α : Type*} +variable [Fintype α] [DecidableEq α] +variable [MeasurableSpace α] [MeasurableSingletonClass α] +variable (π : Measure α) (κ : Kernel α α) + +/-- Finite discrete expansion of a restricted lintegral of a kernel (measurable singletons). -/ +lemma lintegral_kernel_restrict_fintype + (A : Set α) : + ∫⁻ x in A, κ x A ∂π + = + ∑ x : α, (if x ∈ A then π {x} * κ x A else 0) := by + classical + simpa using + (lintegral_restrict_as_sum_if (μ:=π) (A:=A) (g:=fun x => κ x A)) + +/-- Finite discrete reversibility from pointwise detailed balance. 
-/ +lemma Kernel.isReversible_of_pointwise_fintype + (hPoint : + ∀ ⦃x y⦄, π {x} * κ x {y} = π {y} * κ y {x}) + : ProbabilityTheory.Kernel.IsReversible κ π := by + classical + intro A B hA hB + have hFinA : A.Finite := Set.finite_of_subsingleton_fintype A + have hFinB : B.Finite := Set.finite_of_subsingleton_fintype B + have hAexp : + ∫⁻ x in A, κ x B ∂π + = ∑ x ∈ hFinA.toFinset, π {x} * κ x B := by + have h1 : + ∫⁻ x in A, κ x B ∂π + = ∑ x : α, + (if x ∈ A then π {x} * κ x B else 0) := by + simpa using + (lintegral_restrict_as_sum_if (μ:=π) (A:=A) (g:=fun x => κ x B)) + have : + (∑ x : α, (if x ∈ A then π {x} * κ x B else 0)) + = + ∑ x ∈ hFinA.toFinset, π {x} * κ x B := by + classical + simp_rw + [(ProbabilityTheory.Finset.sum_if_mem_eq_sum_filter + (S:=A) (f:=fun x => π {x} * κ x B))] + rw [@toFinite_toFinset] + simp [h1, this] + have hBexp : + ∫⁻ x in B, κ x A ∂π + = ∑ x ∈ hFinB.toFinset, π {x} * κ x A := by + have h1 : + ∫⁻ x in B, κ x A ∂π + = ∑ x : α, + (if x ∈ B then π {x} * κ x A else 0) := by + simpa using + (lintegral_restrict_as_sum_if (μ:=π) (A:=B) (g:=fun x => κ x A)) + have : + (∑ x : α, (if x ∈ B then π {x} * κ x A else 0)) + = + ∑ x ∈ hFinB.toFinset, π {x} * κ x A := by + classical + simp_rw + [(ProbabilityTheory.Finset.sum_if_mem_eq_sum_filter + (S:=B) (f:=fun x => π {x} * κ x A))] + rw [@toFinite_toFinset] + simp [h1, this] + have hκB : + ∀ x, κ x B = ∑ y ∈ hFinB.toFinset, κ x {y} := by + intro x; simpa using + (Kernel.measure_eq_sum_finset (κ:=κ) x hFinB) + have hκA : + ∀ x, κ x A = ∑ y ∈ hFinA.toFinset, κ x {y} := by + intro x; simpa using + (Kernel.measure_eq_sum_finset (κ:=κ) x hFinA) + have hL : + ∑ x ∈ hFinA.toFinset, π {x} * κ x B + = + ∑ x ∈ hFinA.toFinset, ∑ y ∈ hFinB.toFinset, π {x} * κ x {y} := by + refine Finset.sum_congr rfl ?_ + intro x hx + simp_rw [hκB x, Finset.mul_sum] + have hR : + ∑ x ∈ hFinB.toFinset, π {x} * κ x A + = + ∑ x ∈ hFinB.toFinset, ∑ y ∈ hFinA.toFinset, π {x} * κ x {y} := by + refine Finset.sum_congr rfl ?_ + intro x 
hx + simp_rw [hκA x, Finset.mul_sum] + erw [hAexp, hBexp, hL, hR] + have hRew : + ∑ x ∈ hFinA.toFinset, ∑ y ∈ hFinB.toFinset, π {x} * κ x {y} + = + ∑ x ∈ hFinA.toFinset, ∑ y ∈ hFinB.toFinset, π {y} * κ y {x} := by + refine Finset.sum_congr rfl ?_ + intro x hx + refine Finset.sum_congr rfl ?_ + intro y hy + exact hPoint (x:=x) (y:=y) + calc + ∑ x ∈ hFinA.toFinset, ∑ y ∈ hFinB.toFinset, π {x} * κ x {y} + = ∑ x ∈ hFinA.toFinset, ∑ y ∈ hFinB.toFinset, π {y} * κ y {x} := hRew + _ = ∑ y ∈ hFinB.toFinset, ∑ x ∈ hFinA.toFinset, π {y} * κ y {x} := by + simpa using + (Finset.sum_comm : + (∑ x ∈ hFinA.toFinset, ∑ y ∈ hFinB.toFinset, + π {y} * κ y {x}) + = + ∑ y ∈ hFinB.toFinset, ∑ x ∈ hFinA.toFinset, + π {y} * κ y {x}) + _ = ∑ x ∈ hFinB.toFinset, ∑ y ∈ hFinA.toFinset, π {x} * κ x {y} := rfl + +end ReversibilityFinite + +/-- Singleton evaluation of a PMF turned into a measure. -/ +@[simp] +lemma PMF.toMeasure_singleton + {α : Type*} [MeasurableSpace α] [MeasurableSingletonClass α] + (p : PMF α) (a : α) : + p.toMeasure {a} = p a := by + rw [toMeasure_apply_eq_toOuterMeasure, toOuterMeasure_apply_singleton] + +-- ## Single–site pointwise detailed balance (finite two–state Hopfield) + +section SingleSitePointwise + +open scoped Classical ENNReal +open MeasureTheory TwoState HopfieldBoltzmann ProbabilityTheory + +variable {U σ : Type} [Fintype U] [DecidableEq U] [Nonempty U] +variable (NN : NeuralNetwork ℝ U σ) +variable [Fintype NN.State] [DecidableEq NN.State] [Nonempty NN.State] +variable [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] +variable (spec : TwoState.EnergySpec' (NN:=NN)) +variable (p : Params NN) (T : Temperature) + +/-- Helper: canonical Boltzmann measure we use below. -/ +private noncomputable abbrev πBoltz : Measure NN.State := + (HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p).μProd T + +omit [Fintype U] [Nonempty U] [DecidableEq NN.State] [Nonempty NN.State] [TwoStateExclusive NN] in +/-- Evaluation of the single–site Gibbs kernel on a singleton. 
-/ +lemma singleSiteKernel_singleton_eval + (u : U) (s t : NN.State) : + (singleSiteKernel (NN:=NN) spec p T u) s {t} + = ENNReal.ofReal (HopfieldBoltzmann.Kbm (NN:=NN) p T u s t) := by + classical + letI : MeasurableSpace NN.State := ⊤ + letI : MeasurableSingletonClass NN.State := ⟨fun _ => trivial⟩ + have hPMF : + (singleSiteKernel (NN:=NN) spec p T u) s {t} + = + (TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) t := by + unfold singleSiteKernel pmfToKernel + simp_rw [Kernel.ofFunOfCountable_apply, PMF.toMeasure_singleton] + have hfin : + (TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) t ≠ (⊤ : ℝ≥0∞) := by + have hle : + (TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) t ≤ 1 := by + simpa using + (TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u).coe_le_one t + have hlt : (TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) t + < (⊤ : ℝ≥0∞) := + lt_of_le_of_lt hle (by simp) + exact (ne_of_lt hlt) + calc + (singleSiteKernel (NN:=NN) spec p T u) s {t} + = (TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) t := hPMF + _ = ENNReal.ofReal ((TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) t).toReal := by + simp [ENNReal.ofReal_toReal, hfin] + _ = ENNReal.ofReal (HopfieldBoltzmann.Kbm (NN:=NN) p T u s t) := rfl + +omit [Nonempty U] [DecidableEq NN.State] in +/-- Evaluation of the Boltzmann measure on a singleton as `ofReal` of the Boltzmann probability. 
-/ +lemma boltzmann_singleton_eval + (s : NN.State) : + (πBoltz (NN:=NN) (spec:=spec) (p:=p) (T:=T)) {s} + = + ENNReal.ofReal (HopfieldBoltzmann.P (NN:=NN) (spec:=spec) p T s) := by + classical + have _ : IsHamiltonian (U:=U) (σ:=σ) NN := + IsHamiltonian_of_EnergySpec' (NN:=NN) (spec:=spec) + have : (HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p).μProd T {s} + = + ENNReal.ofReal + ((HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p).probability T s) := by + simp + simp [πBoltz, HopfieldBoltzmann.P, HopfieldBoltzmann.CEparams] + +omit [Nonempty U] [DecidableEq NN.State] in +lemma singleSite_pointwise_detailed_balance + (u : U) : + ∀ s t : NN.State, + (πBoltz (NN:=NN) (spec:=spec) (p:=p) (T:=T)) {s} + * (singleSiteKernel (NN:=NN) spec p T u) s {t} + = + (πBoltz (NN:=NN) (spec:=spec) (p:=p) (T:=T)) {t} + * (singleSiteKernel (NN:=NN) spec p T u) t {s} := by + classical + intro s t + have hReal := + detailed_balance (NN:=NN) (spec:=spec) (p:=p) (T:=T) (u:=u) s t + have hπs := boltzmann_singleton_eval (NN:=NN) (spec:=spec) (p:=p) (T:=T) s + have hπt := boltzmann_singleton_eval (NN:=NN) (spec:=spec) (p:=p) (T:=T) t + have hκst := + singleSiteKernel_singleton_eval (NN:=NN) (spec:=spec) (p:=p) (T:=T) u s t + have hκts := + singleSiteKernel_singleton_eval (NN:=NN) (spec:=spec) (p:=p) (T:=T) u t s + have hPs_nonneg : + 0 ≤ HopfieldBoltzmann.P (NN:=NN) (spec:=spec) p T s := by + have _ : IsHamiltonian (U:=U) (σ:=σ) NN := + IsHamiltonian_of_EnergySpec' (NN:=NN) (spec:=spec) + exact probability_nonneg_finite + (𝓒:=HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p) (T:=T) (i:=s) + have hPt_nonneg : + 0 ≤ HopfieldBoltzmann.P (NN:=NN) (spec:=spec) p T t := by + have _ : IsHamiltonian (U:=U) (σ:=σ) NN := + IsHamiltonian_of_EnergySpec' (NN:=NN) (spec:=spec) + exact probability_nonneg_finite + (𝓒:=HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p) (T:=T) (i:=t) + have hKst_nonneg : + 0 ≤ HopfieldBoltzmann.Kbm (NN:=NN) p T u s t := by + unfold HopfieldBoltzmann.Kbm; exact 
ENNReal.toReal_nonneg + have hKts_nonneg : + 0 ≤ HopfieldBoltzmann.Kbm (NN:=NN) p T u t s := by + unfold HopfieldBoltzmann.Kbm; exact ENNReal.toReal_nonneg + rw [hπs, hπt, hκst, hκts, + ← ENNReal.ofReal_mul, ← ENNReal.ofReal_mul, hReal] + all_goals + first + | exact mul_nonneg hPs_nonneg hKst_nonneg + | simp_all only [μProd_singleton_of_fintype] + +omit [Nonempty U] in +/-- Reversibility of the single–site kernel w.r.t. the Boltzmann measure (patched). -/ +lemma singleSiteKernel_reversible + (u : U) : + ProbabilityTheory.Kernel.IsReversible + (singleSiteKernel (NN:=NN) spec p T u) + (πBoltz (NN:=NN) (spec:=spec) (p:=p) (T:=T)) := by + classical + letI : MeasurableSpace NN.State := ⊤ + letI : MeasurableSingletonClass NN.State := ⟨fun _ => trivial⟩ + refine Kernel.isReversible_of_pointwise_fintype + (π:=πBoltz (NN:=NN) (spec:=spec) (p:=p) (T:=T)) + (κ:=singleSiteKernel (NN:=NN) spec p T u) ?_ + intro x y + simpa using + singleSite_pointwise_detailed_balance (NN:=NN) (spec:=spec) (p:=p) (T:=T) u x y + +end SingleSitePointwise + +section RandomScan + +open scoped Classical +open MeasureTheory +open TwoState HopfieldBoltzmann ProbabilityTheory + +variable {U σ : Type} [Fintype U] [DecidableEq U] [Nonempty U] +variable (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [DecidableEq NN.State] [Nonempty NN.State] +variable [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] +variable (spec : TwoState.EnergySpec' (NN:=NN)) +variable (p : Params NN) (T : Temperature) + +/-- Reversibility of the random–scan Gibbs kernel (uniform site choice) w.r.t. the Boltzmann measure. 
-/ +theorem randomScanKernel_reversible : + ProbabilityTheory.Kernel.IsReversible + (randomScanKernel (NN:=NN) spec p T) + ((HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p).μProd T) := by + classical + have hSite : + ∀ u : U, + ProbabilityTheory.Kernel.IsReversible + (singleSiteKernel (NN:=NN) spec p T u) + ((HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p).μProd T) := by + intro u + simpa [πBoltz, + HopfieldBoltzmann.CEparams] using + (singleSiteKernel_reversible (NN:=NN) (spec:=spec) (p:=p) (T:=T) u) + exact + randomScanKernel_reversible_of_sites + (NN:=NN) (spec:=spec) (p:=p) (T:=T) + ((HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p).μProd T) + hSite + +end RandomScan diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean index c2262553d..10ebcd32b 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NeuralNetwork.lean @@ -9,6 +9,8 @@ import Mathlib.Data.Vector.Basic open Mathlib Finset +universe uR uU uσ + /- A `NeuralNetwork` models a neural network with: @@ -19,7 +21,7 @@ A `NeuralNetwork` models a neural network with: It extends `Digraph U` and includes the network's architecture, activation functions, and constraints. -/ -structure NeuralNetwork (R U : Type) (σ : Type) [Zero R] extends Digraph U where +structure NeuralNetwork (R : Type uR) (U : Type uU) (σ : Type uσ) [Zero R] extends Digraph U where /-- Input neurons. -/ (Ui Uo Uh : Set U) /-- There is at least one input neuron. 
-/ @@ -55,7 +57,9 @@ structure NeuralNetwork (R U : Type) (σ : Type) [Zero R] extends Digraph U wher (fnet u_target (w u_target) (fun v => fout v (current v)) (σv u_target)) (θ u_target))) -variable {R U σ : Type} [Zero R] +variable {R : Type uR} {U : Type uU} {σ : Type uσ} [Zero R] + +--variable {R U σ : Type} [Zero R] /-- `Params` is a structure that holds the parameters for a neural network `NN`. -/ structure Params (NN : NeuralNetwork R U σ) where @@ -71,49 +75,53 @@ structure State (NN : NeuralNetwork R U σ) where act : U → σ hp : ∀ u : U, NN.pact (act u) -/-- Extensionality lemma for neural network states -/ -@[ext] -lemma ext {R U σ : Type} [Zero R] {NN : NeuralNetwork R U σ} - {s₁ s₂ : NN.State} : (∀ u, s₁.act u = s₂.act u) → s₁ = s₂ := by - intro h - cases s₁ - cases s₂ - simp only [NeuralNetwork.State.mk.injEq] - apply funext - exact h +@[ext] lemma ext {R U σ : Type} [Zero R] {NN : NeuralNetwork R U σ} + {s₁ s₂ : NN.State} : + (∀ u, s₁.act u = s₂.act u) → s₁ = s₂ := by + intro h; cases s₁; cases s₂; simp [State.mk.injEq, funext h]; aesop namespace State - -variable {NN : NeuralNetwork R U σ} (wσθ : Params NN) (s : NN.State) +variable {NN : NeuralNetwork R U σ} +variable (p : Params NN) (s : NN.State) def out (u : U) : R := NN.fout u (s.act u) -def net (u : U) : R := NN.fnet u (wσθ.w u) (fun v => s.out v) (wσθ.σ u) -def onlyUi : Prop := ∃ σ0 : σ, ∀ u : U, u ∉ NN.Ui → s.act u = σ0 +def net (u : U) : R := NN.fnet u (p.w u) (fun v => s.out v) (p.σ u) +def onlyUi : Prop := ∃ σ0 : σ, ∀ u : U, u ∉ NN.Ui → s.act u = σ0 variable [DecidableEq U] -def Up {NN_local : NeuralNetwork R U σ} (s : NN_local.State) (wσθ : Params NN_local) (u_upd : U) : NN_local.State := - { act := fun v => if v = u_upd then - NN_local.fact u_upd (s.act u_upd) - (NN_local.fnet u_upd (wσθ.w u_upd) (fun n => s.out n) (wσθ.σ u_upd)) - (wσθ.θ u_upd) - else - s.act v, - hp := by - intro v_target - rw [ite_eq_dite] - split_ifs with h_eq_upd_neuron - · exact NN_local.hpact wσθ.w wσθ.hw wσθ.hw' 
wσθ.σ wσθ.θ s.act s.hp u_upd - · exact s.hp v_target - } - -def workPhase (extu : NN.State) (_ : extu.onlyUi) (uOrder : List U) : NN.State := - uOrder.foldl (fun s_iter u_iter => s_iter.Up wσθ u_iter) extu - -def seqStates (useq : ℕ → U) : ℕ → NeuralNetwork.State NN - | 0 => s - | n + 1 => .Up (seqStates useq n) wσθ (useq n) - -def isStable : Prop := ∀ (u : U), (s.Up wσθ u).act u = s.act u +/-- Single–site asynchronous update: recompute neuron `u` using current state `s`. + -/ +def Up (s : NN.State) (p : Params NN) (u : U) : NN.State := +{ act := fun v => + if hv : v = u then + NN.fact u (s.act u) + (NN.fnet u (p.w u) (fun n => NN.fout n (s.act n)) (p.σ u)) + (p.θ u) + else + s.act v +, hp := by + intro v + by_cases hv : v = u + · subst hv + have hclosure_all := + NN.hpact p.w p.hw p.hw' p.σ p.θ s.act s.hp + have hclosure := hclosure_all v + simp only [dif_pos rfl] + exact hclosure + · simp only [dif_neg hv] + exact s.hp v } + +/-- Fold a list of update sites left-to-right starting from an extended state. -/ +def workPhase (ext : NN.State) (_ : ext.onlyUi) (uOrder : List U) : NN.State := + uOrder.foldl (fun st u => Up st p u) ext + +/-- Iterated sequence of single–site updates following a site stream `useq`. -/ +def seqStates (useq : ℕ → U) : ℕ → NN.State + | 0 => s + | n + 1 => Up (seqStates useq n) p (useq n) + +/-- A state is stable if every single–site update leaves the site unchanged. 
-/ +def isStable : Prop := ∀ u : U, (Up s p u).act u = s.act u end State end NeuralNetwork diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean index a4caa8dad..6659aec77 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean @@ -179,9 +179,11 @@ open NeuralNetwork --variable {R U σ : Type*} universe uR uU uσ --- (Optional) you can also parametrize earlier variables with these universes if desired: +-- We can also parametrize earlier variables with these universes if desired: variable {R : Type uR} {U : Type uU} {σ : Type uσ} +--variable {R U σ : Type} + /-- A minimal two-point activation alphabet. This class specifies: @@ -189,7 +191,7 @@ This class specifies: - `σ_neg`: the distinguished “negative” activation, - `embed`: a numeric embedding `σ → R` used to interpret activations in the ambient ring `R`. -/ -class TwoPointActivation (R : Type uR) (σ : Type uσ) where +class TwoPointActivation (R : Type) (σ : Type) where /-- Distinguished “positive” activation state. -/ σ_pos : σ /-- Distinguished “negative” activation state. -/ @@ -625,11 +627,10 @@ structure EnergySpec' `f (E p (updPos s u) - E p (updNeg s u)) = - scale f * f (localField p s u)`. 
-/ flip_energy_relation : - ∀ (f : R →+* ℝ) - (p : Params NN) (s : NN.State) (u : U), - f (E p (updPos (NN:=NN) s u) - E p (updNeg (NN:=NN) s u)) = - - (scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f)) * - f (localField p s u) + ∀ (f : R →+* ℝ) (p : Params NN) (s : NN.State) (u : U), + f (E p (updPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) (s:=s) (u:=u)) + - E p (updNeg (R:=R) (U:=U) (σ:=σ) (NN:=NN) (s:=s) (u:=u))) = + - (scale (R:=R) (U:=U) (σ:=σ) (NN:=NN) (f:=f)) * f (localField p s u) namespace EnergySpec variable {NN : NeuralNetwork R U σ} @@ -1417,3 +1418,424 @@ theorem gibbs_update_tends_to_zero_temp_limit end Convergence end TwoState + +open Finset Matrix NeuralNetwork State Constants Temperature Filter Topology +open scoped ENNReal NNReal BigOperators +open NeuralNetwork +namespace TwoState + +--variable {R U σ : Type*} [Field R] [LinearOrder R] [IsStrictOrderedRing R] +/-! +Fix: strengthen the typeclass assumptions in the final section (they were only `[Zero R]`) +and specialize to `R = ℝ` so that `EnergySpec` (which requires ordered field structure) +is usable. Then provide a complete implementation of +`EnergySpec.energy_order_from_flip_id` (previous version failed due to missing instances). +-/ + + +lemma updPos_eq_self_of_act_pos + {R U σ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [Fintype U] [DecidableEq U] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (s : NN.State) (u : U) + (h : s.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN)) : + updPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) s u = s := by + classical + ext v + by_cases hv : v = u + · subst hv; simp [updPos, Function.update, h] + · simp [updPos, Function.update, hv] + +/-- Helper: if the current activation at `u` is already `σ_neg`, `updNeg` is identity. + +(Version with fully explicit implicit parameters to avoid universe inference issues.) 
-/ +lemma updNeg_eq_self_of_act_neg + {R U σ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [Fintype U] [DecidableEq U] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (s : NN.State) (u : U) + (h : s.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) : + updNeg (R:=R) (U:=U) (σ:=σ) (NN:=NN) s u = s := by + classical + ext v + by_cases hv : v = u + · subst hv; simp [updNeg, Function.update, h] + · simp [updNeg, Function.update, hv] + +/-- Classification of a single asynchronous update at site `u`: +`Up` equals `updPos` if `θ ≤ net`, else `updNeg`. (Explicit parameters to +stabilize elaboration.) -/ +lemma Up_eq_updPos_or_updNeg + {R U σ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] + [Fintype U] [DecidableEq U] + {NN : NeuralNetwork R U σ} [TwoStateNeuralNetwork NN] + (p : Params NN) (s : NN.State) (u : U) : + let net := s.net p u + let θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + s.Up p u = + (if θ ≤ net then updPos (R:=R) (U:=U) (σ:=σ) (NN:=NN) s u + else updNeg (R:=R) (U:=U) (σ:=σ) (NN:=NN) s u) := by + classical + intro net θ + ext v + by_cases hv : v = u + · subst hv + unfold NeuralNetwork.State.Up + simp only + by_cases hθle : θ ≤ net + · + have hpos := + TwoStateNeuralNetwork.h_fact_pos (NN:=NN) v (s.act v) net (p.θ v) hθle + have : NN.fact v (s.act v) + (NN.fnet v (p.w v) (fun w => s.out w) (p.σ v)) + (p.θ v) = TwoStateNeuralNetwork.σ_pos (NN:=NN) := by + simpa [NeuralNetwork.State.net] using hpos + simp [updPos, Function.update, this, hθle] + aesop + · + have hlt : net < θ := lt_of_not_ge hθle + have hneg := + TwoStateNeuralNetwork.h_fact_neg (NN:=NN) v (s.act v) net (p.θ v) hlt + have : NN.fact v (s.act v) + (NN.fnet v (p.w v) (fun w => s.out w) (p.σ v)) + (p.θ v) = TwoStateNeuralNetwork.σ_neg (NN:=NN) := by + simpa [NeuralNetwork.State.net] using hneg + simp [updNeg, Function.update, this, hθle] + aesop + · + unfold NeuralNetwork.State.Up + simp_rw [hv, updPos, updNeg]; simp [Function.update] + aesop +end TwoState +namespace 
TwoState.EnergySpec' + +/-- Given an energy difference identity + E sPos - E sNeg = - κ * L with κ ≥ 0, +we deduce the two directional inequalities depending on the sign of L. -/ +lemma energy_order_from_flip_id + {U σ} [Fintype U] [DecidableEq U] + (NN : NeuralNetwork ℝ U σ) [TwoStateNeuralNetwork NN] + (spec : EnergySpec' (NN:=NN)) + {p : Params NN} + {κ L : ℝ} {sPos sNeg : NN.State} + (hdiff : spec.E p sPos - spec.E p sNeg = - κ * L) + (hκ : 0 ≤ κ) : + (0 ≤ L → spec.E p sPos ≤ spec.E p sNeg) ∧ + (L ≤ 0 → spec.E p sNeg ≤ spec.E p sPos) := by + constructor + · intro hL + have hκL : 0 ≤ κ * L := mul_nonneg hκ hL + have hDiffLe : spec.E p sPos - spec.E p sNeg ≤ 0 := by + rw [hdiff]; simp only [neg_mul, Left.neg_nonpos_iff]; exact hκL + exact sub_nonpos.mp hDiffLe + · intro hL + have hrev : spec.E p sNeg - spec.E p sPos = κ * L := by + have := congrArg Neg.neg hdiff + simpa [neg_sub, neg_mul, neg_neg] using this + have hκL : κ * L ≤ 0 := mul_nonpos_of_nonneg_of_nonpos hκ hL + have : spec.E p sNeg - spec.E p sPos ≤ 0 := by simpa [hrev] using hκL + exact sub_nonpos.mp this + +end EnergySpec' +end TwoState +namespace TwoState + +section ConcreteLyapunov +variable {U : Type} [Fintype U] [DecidableEq U] [Nonempty U] + +-- Note: The following lemmas are specialized for `SymmetricBinary ℝ U` +-- to simplify the proof development by avoiding universe polymorphism issues. +-- The original polymorphic versions can be restored later from this template. 
+ +lemma updPos_eq_self_of_act_pos_binary + (s : (SymmetricBinary ℝ U).State) (u : U) + (h : s.act u = 1) : + updPos (NN:=SymmetricBinary ℝ U) s u = s := by + ext v + by_cases hv : v = u + · subst hv; simp [updPos, Function.update, h, instTwoStateSymmetricBinary] + · simp [updPos, Function.update, hv] + +lemma updNeg_eq_self_of_act_neg_binary + (s : (SymmetricBinary ℝ U).State) (u : U) + (h : s.act u = -1) : + updNeg (NN:=SymmetricBinary ℝ U) s u = s := by + ext v + by_cases hv : v = u + · subst hv; simp [updNeg, Function.update, h, instTwoStateSymmetricBinary] + · simp [updNeg, Function.update, hv] + +lemma Up_eq_updPos_or_updNeg_binary + (p : Params (SymmetricBinary ℝ U)) (s : (SymmetricBinary ℝ U).State) (u : U) : + let net := s.net p u + let θ := (p.θ u).get fin0 + s.Up p u = + (if θ ≤ net + then updPos (NN:=SymmetricBinary ℝ U) s u + else updNeg (NN:=SymmetricBinary ℝ U) s u) := by + intro net θ + classical + -- Reuse the general classification lemma and unfold the lets. + simpa [net, θ] using + (TwoState.Up_eq_updPos_or_updNeg + (R:=ℝ) (U:=U) (σ:=ℝ) + (NN:=SymmetricBinary ℝ U) p s u) + +lemma energy_order_from_flip_id_binary + (spec : EnergySpec' (NN:=SymmetricBinary ℝ U)) + {p : Params (SymmetricBinary ℝ U)} + {κ L : ℝ} {sPos sNeg : (SymmetricBinary ℝ U).State} + (hdiff : spec.E p sPos - spec.E p sNeg = - κ * L) + (hκ : 0 ≤ κ) : + (0 ≤ L → spec.E p sPos ≤ spec.E p sNeg) ∧ + (L ≤ 0 → spec.E p sNeg ≤ spec.E p sPos) := by + constructor + · intro hL + have hκL : 0 ≤ κ * L := mul_nonneg hκ hL + -- Convert desired inequality to `sub ≤ 0`. + have hsub : spec.E p sPos - spec.E p sNeg ≤ 0 := by + rw [hdiff]; aesop + exact sub_nonpos.mp hsub + · intro hL + have hκL : κ * L ≤ 0 := mul_nonpos_of_nonneg_of_nonpos hκ hL + -- Derive the reversed difference. 
+ have hrev : spec.E p sNeg - spec.E p sPos = κ * L := by + have := congrArg Neg.neg hdiff + -- -(E sPos - E sNeg) = κ * L + simpa [neg_sub, neg_mul, neg_neg] using this + have hsub : spec.E p sNeg - spec.E p sPos ≤ 0 := by + rw [hrev]; exact hκL + exact sub_nonpos.mp hsub + +/-- Lyapunov (energy non‑increase) at a single site for `SymmetricBinary` networks. -/ +lemma energy_is_lyapunov_at_site_binary + (spec : EnergySpec' (NN:=SymmetricBinary ℝ U)) + (p : Params (SymmetricBinary ℝ U)) (s : (SymmetricBinary ℝ U).State) (u : U) + (hcur : s.act u = 1 ∨ s.act u = -1) : + spec.E p (s.Up p u) ≤ spec.E p s := by + -- Shorthand states and values + set sPos := updPos (NN:=SymmetricBinary ℝ U) s u + set sNeg := updNeg (NN:=SymmetricBinary ℝ U) s u + set net := s.net p u + set θ := (p.θ u).get fin0 + set L : ℝ := net - θ + let fid : ℝ →+* ℝ := RingHom.id _ + -- Get the energy difference relation from the spec + have hflip := spec.flip_energy_relation fid p s u + have hlocal : spec.localField p s u = L := by simp [L, net, θ, spec.localField_spec]; aesop + have hdiff : spec.E p sPos - spec.E p sNeg = - (scale (NN:=SymmetricBinary ℝ U) (f:=fid)) * L := by + simpa [sPos, sNeg, hlocal] using hflip + -- Define κ and prove it's non-negative + set κ := scale (NN:=SymmetricBinary ℝ U) (f:=fid) + have hκ_pos : 0 < κ := by aesop--simp [scale_binary, map_ofNat] + have hκ_nonneg : 0 ≤ κ := hκ_pos.le + -- Get the ordering implications from the energy difference + have hOrder := energy_order_from_flip_id_binary spec hdiff hκ_nonneg + -- Relate the abstract update `Up` to concrete `updPos`/`updNeg` + have hUp_cases := Up_eq_updPos_or_updNeg_binary p s u + -- Case split on the threshold condition + by_cases hθle : θ ≤ net + · -- Case 1: net ≥ θ, so update goes to sPos + have hUp_cases_eval := hUp_cases + simp [net, θ, hθle, sPos, sNeg] at hUp_cases_eval + rw [hUp_cases_eval] + cases hcur with + | inl h_is_pos => -- s is already sPos + rw [updPos_eq_self_of_act_pos_binary s u h_is_pos] + | 
inr h_is_neg => -- s = sNeg, need E(sPos) ≤ E(sNeg) + have hL_nonneg : 0 ≤ L := by simpa [L] using sub_nonneg.mpr hθle + -- Rewrite only the right-hand occurrence of s using hs + have hs : s = sNeg := (updNeg_eq_self_of_act_neg_binary s u h_is_neg).symm + have hLE : spec.E p sPos ≤ spec.E p sNeg := hOrder.left hL_nonneg + simp_rw [hs] + exact + le_of_eq_of_le (congrArg (spec.E p) (congrFun (congrArg updPos (id (Eq.symm hs))) u)) hLE + · -- Case 2: net < θ, so update goes to sNeg + have hUp_cases_eval := hUp_cases + simp [net, θ, hθle, sPos, sNeg] at hUp_cases_eval + rw [hUp_cases_eval] + cases hcur with + | inl h_is_pos => -- s = sPos, need E(sNeg) ≤ E(sPos) + have hL_nonpos : L ≤ 0 := by simpa [L] using (lt_of_not_ge hθle).le + have hs : s = sPos := (updPos_eq_self_of_act_pos_binary s u h_is_pos).symm + have hLE : spec.E p sNeg ≤ spec.E p sPos := hOrder.2 hL_nonpos + simp_rw [hs] + exact + le_of_eq_of_le (congrArg (spec.E p) (congrFun (congrArg updNeg (id (Eq.symm hs))) u)) hLE + | inr h_is_neg => -- s already sNeg + rw [updNeg_eq_self_of_act_neg_binary s u h_is_neg] + +end ConcreteLyapunov +namespace EnergySpec' +open TwoState + +/-- General (non-binary–specialized) Lyapunov single–site energy descent lemma. 
-/ +lemma energy_is_lyapunov_at_site + {U σ} [Fintype U] [DecidableEq U] + (NN : NeuralNetwork ℝ U σ) [TwoStateNeuralNetwork NN] + (spec : EnergySpec' (R:=ℝ) (NN:=NN)) + (p : Params NN) (s : NN.State) (u : U) + (hcur : s.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN) ∨ + s.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) : + spec.E p (s.Up p u) ≤ spec.E p s := by + classical + set sPos := updPos (NN:=NN) s u + set sNeg := updNeg (NN:=NN) s u + set net := s.net p u + set θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + set L : ℝ := net - θ + let fid : ℝ →+* ℝ := RingHom.id _ + have hflip := spec.flip_energy_relation fid p s u + have hlocal : spec.localField p s u = L := by + simp [L, net, θ, spec.localField_spec] + have hdiff : + spec.E p sPos - spec.E p sNeg + = - (scale (R:=ℝ) (U:=U) (σ:=σ) (NN:=NN) (f:=fid)) * L := by + simpa [sPos, sNeg, hlocal] using hflip + set κ := scale (R:=ℝ) (U:=U) (σ:=σ) (NN:=NN) (f:=fid) + have hm := TwoStateNeuralNetwork.m_order (NN:=NN) + have hκpos : 0 < κ := by + simp [κ, scale] + aesop + have hκ : 0 ≤ κ := hκpos.le + have hOrder := + TwoState.EnergySpec'.energy_order_from_flip_id + (NN:=NN) (spec:=spec) (p:=p) (κ:=κ) (L:=L) + (sPos:=sPos) (sNeg:=sNeg) hdiff hκ + have hup := TwoState.Up_eq_updPos_or_updNeg (NN:=NN) p s u + by_cases hθle : θ ≤ net + · have hUpPos : s.Up p u = sPos := by + simpa [net, θ, hθle] using hup + cases hcur with + | inl hpos => + have hFixed : sPos = s := by + have h := TwoState.updPos_eq_self_of_act_pos (NN:=NN) s u hpos + simpa [sPos] using h + simp [hUpPos, hFixed] + | inr hneg => + have hEqNeg : sNeg = s := by + have h := TwoState.updNeg_eq_self_of_act_neg (NN:=NN) s u hneg + simpa [sNeg] using h + have hL : 0 ≤ L := by + have : θ ≤ net := hθle + simpa [L] using sub_nonneg.mpr this + have hLE := hOrder.left hL + simp [hUpPos, hEqNeg] + aesop + · have hnetlt : net < θ := lt_of_not_ge hθle + have hLle : L ≤ 0 := by + have : net - θ < 0 := sub_lt_zero.mpr hnetlt + exact this.le + have hUpNeg : s.Up p u = 
sNeg := by + have hnot : ¬ θ ≤ net := hθle + simpa [net, θ, hnot] using hup + cases hcur with + | inl hpos => + have hEqPos : sPos = s := by + have h := TwoState.updPos_eq_self_of_act_pos (NN:=NN) s u hpos + simpa [sPos] using h + have hLE := hOrder.right hLle + simp [hUpNeg, hEqPos] + aesop + | inr hneg => + have hFixed : sNeg = s := by + have h := TwoState.updNeg_eq_self_of_act_neg (NN:=NN) s u hneg + simpa [sNeg] using h + simp [hUpNeg, hFixed] + +/-- Wrapper (argument order variant). -/ +lemma energy_is_lyapunov_at_site' + {U σ} [Fintype U] [DecidableEq U] + {NN : NeuralNetwork ℝ U σ} [TwoStateNeuralNetwork NN] + (spec : EnergySpec' (R:=ℝ) (NN:=NN)) + (p : Params NN) (s : NN.State) (u : U) + (hcur : s.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN) ∨ + s.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) : + spec.E p (s.Up p u) ≤ spec.E p s := + energy_is_lyapunov_at_site (NN:=NN) (spec:=spec) p s u hcur + +/-- Lyapunov (energy non‑increase) at a single site (completed proof). +Uses `energy_order_from_flip_id` and the flip relation with the identity hom. 
-/ +lemma energy_is_lyapunov_at_site'' + {U σ} [Fintype U] [DecidableEq U] + (NN : NeuralNetwork ℝ U σ) [TwoStateNeuralNetwork NN] + (spec : EnergySpec' (R:=ℝ) (NN:=NN)) + (p : Params NN) (s : NN.State) (u : U) + (hcur : s.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN) ∨ + s.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) : + spec.E p (NeuralNetwork.State.Up s p u) ≤ spec.E p s := by + classical + set sPos := updPos (NN:=NN) s u + set sNeg := updNeg (NN:=NN) s u + set net := s.net p u + set θ := (p.θ u).get (TwoStateNeuralNetwork.θ0 (NN:=NN) u) + set L : ℝ := net - θ + let fid : ℝ →+* ℝ := RingHom.id _ + have hflip := spec.flip_energy_relation fid p s u + have hlocal : spec.localField p s u = L := by + simp [L, net, θ, spec.localField_spec] + have hdiff : + spec.E p sPos - spec.E p sNeg = + - (scale (R:=ℝ) (U:=U) (σ:=σ) (NN:=NN) (f:=fid)) * L := by + simpa [sPos, sNeg, hlocal] using hflip + set κ := scale (R:=ℝ) (U:=U) (σ:=σ) (NN:=NN) (f:=fid) + have hκpos : 0 < κ := by + have hmo := TwoStateNeuralNetwork.m_order (NN:=NN) + have : 0 < (NN.m (TwoStateNeuralNetwork.σ_pos (NN:=NN)) + - NN.m (TwoStateNeuralNetwork.σ_neg (NN:=NN))) := sub_pos.mpr hmo + simpa [κ, scale, fid, RingHom.id_apply] + have hκ : 0 ≤ κ := hκpos.le + have hOrder := + energy_order_from_flip_id + (NN:=NN) (spec:=spec) (p:=p) (κ:=κ) (L:=L) + (sPos:=sPos) (sNeg:=sNeg) hdiff hκ + have hup := TwoState.Up_eq_updPos_or_updNeg (NN:=NN) p s u + by_cases hθle : θ ≤ net + · have hUpPos : s.Up p u = sPos := by + simpa [net, θ, hθle] using hup + cases hcur with + | inl hPos => + have hs : sPos = s := + TwoState.updPos_eq_self_of_act_pos (NN:=NN) s u hPos + have htriv : spec.E p sPos ≤ spec.E p sPos := le_rfl + simp [hUpPos, hs] + | inr hNeg => + have hs : sNeg = s := + TwoState.updNeg_eq_self_of_act_neg (NN:=NN) s u hNeg + have hL : 0 ≤ L := by + have : θ ≤ net := hθle + simpa [L] using sub_nonneg.mpr this + have hLE : spec.E p sPos ≤ spec.E p sNeg := hOrder.left hL + simp [hUpPos, hs] + aesop + · have hLt : net < 
θ := lt_of_not_ge hθle + have hLle : L ≤ 0 := (sub_lt_zero.mpr hLt).le + have hUpNeg : s.Up p u = sNeg := by + have hnot : ¬ θ ≤ net := hθle + simpa [net, θ, hnot] using hup + cases hcur with + | inl hPos => + have hs : sPos = s := + TwoState.updPos_eq_self_of_act_pos (NN:=NN) s u hPos + have hLE : spec.E p sNeg ≤ spec.E p sPos := hOrder.right hLle + simp [hUpNeg, hs] + aesop + | inr hNeg => + have hs : sNeg = s := + TwoState.updNeg_eq_self_of_act_neg (NN:=NN) s u hNeg + have htriv : spec.E p sNeg ≤ spec.E p sNeg := le_rfl + simp [hUpNeg, hs] + +/-- Restated helper with identical conclusion (wrapper). -/ +lemma energy_is_lyapunov_at_site''' + {U σ} [Fintype U] [DecidableEq U] + {NN : NeuralNetwork ℝ U σ} [TwoStateNeuralNetwork NN] + (spec : EnergySpec' (R:=ℝ) (NN:=NN)) + (p : Params NN) (s : NN.State) (u : U) + (hcur : s.act u = TwoStateNeuralNetwork.σ_pos (NN:=NN) ∨ + s.act u = TwoStateNeuralNetwork.σ_neg (NN:=NN)) : + spec.E p (NeuralNetwork.State.Up s p u) ≤ spec.E p s := + energy_is_lyapunov_at_site (NN:=NN) (spec:=spec) p s u hcur + +end EnergySpec' +end TwoState diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/aux.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/aux.lean index ed8c21ae6..14559f2e1 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/aux.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/aux.lean @@ -220,7 +220,7 @@ lemma filter_sum_pos_iff_exists_pos {α β : Type} [Fintype α] exact h a ha · exact zero_le (p a) have sum_zero := Finset.sum_eq_zero all_zero - exact not_lt_of_le (by exact nonpos_iff_eq_zero.mpr sum_zero) h_pos + exact not_lt_of_ge (by exact nonpos_iff_eq_zero.mpr sum_zero) h_pos rcases exists_pos with ⟨x, hx_mem, hx_pos⟩ exact ⟨x, filter_mem_iff.mp hx_mem, hx_pos⟩ · rintro ⟨x, hx_mem, hx_pos⟩ diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/toCanonicalEnsemble.lean 
b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/toCanonicalEnsemble.lean new file mode 100644 index 000000000..7ca85c463 --- /dev/null +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/toCanonicalEnsemble.lean @@ -0,0 +1,346 @@ +import PhysLean.StatisticalMechanics.CanonicalEnsemble.TwoState +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.TwoState + +open MeasureTheory + +/-! +# The Bridge from Neural Networks to Statistical Mechanics + +This file defines the `IsHamiltonian` typeclass, which provides the formal bridge +between the constructive, algorithmic definition of a `NeuralNetwork` (Layer 4) +and the physical, probabilistic framework of a `CanonicalEnsemble` (Layer 2). +-/ +variable {U σ : Type} [Fintype U] [DecidableEq U] + +/-- For any finite-state neural network we use the trivial (⊤) measurable space. -/ +instance (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] : MeasurableSpace NN.State := ⊤ + +omit [Fintype U] [DecidableEq U] in +@[simp] lemma measurable_of_fintype_state + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] (f : NN.State → ℝ) : + Measurable f := by + classical + unfold Measurable; intro s _; simp + +/-- +A typeclass asserting that a `NeuralNetwork`'s dynamics are governed by an energy function. +This is the formal statement that the network is a physical system with a well-defined +Hamiltonian, and that its deterministic dynamics are a form of energy minimization. + +- `NN`: The concrete `NeuralNetwork` structure. +-/ +class IsHamiltonian (NN : NeuralNetwork ℝ U σ) [MeasurableSpace NN.State] where + /-- The energy function (Hamiltonian) of a given state of the network. -/ + energy : Params NN → NN.State → ℝ + /-- A proof that the energy function is measurable (trivial for discrete (⊤) state spaces). -/ + energy_measurable : ∀ p, Measurable (energy p) + /-- The core axiom: The energy is a Lyapunov function for the network's asynchronous + Glauber dynamics. 
Any single update step does not increase the energy. -/ + energy_is_lyapunov : + ∀ (p : Params NN) (s : NN.State) (u : U), + energy p ((NeuralNetwork.State.Up s p u)) ≤ energy p s + +/-- +A formal constructor that lifts any `NeuralNetwork` proven to be `IsHamiltonian` +into the `CanonicalEnsemble` framework. + +This function is the bridge that allows us to apply the full power of statistical +mechanics (free energy, entropy, etc.) to a structurally-defined neural network. +-/ +@[simps!] +noncomputable def toCanonicalEnsemble + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [IsHamiltonian NN] + (p : Params NN) : + CanonicalEnsemble NN.State where + energy := IsHamiltonian.energy p + dof := 0 -- For discrete spin systems, there are no continuous degrees of freedom. + phase_space_unit := 1 -- For counting measures, the unit is 1. + energy_measurable := IsHamiltonian.energy_measurable p + μ := Measure.count -- The natural base measure for a discrete state space. + μ_sigmaFinite := by infer_instance + +variable {U σ : Type} [Fintype U] [DecidableEq U] +variable {NN : NeuralNetwork ℝ U σ} [TwoStateNeuralNetwork NN] + +/- +This instance is the formal bridge. It is a theorem stating that any `NeuralNetwork` +for which we can provide an `EnergySpec` is guaranteed to be an `IsHamiltonian` system. + +Lean's typeclass system will use this instance automatically. If you define an `EnergySpec` +for a network, Lean will now know that it is also `IsHamiltonian`. +-/ + +/-! ## Generic Hamiltonian bridge (refactored) + +We generalize the previous `IsHamiltonian_of_EnergySpecSymmetricBinary` to any +two–state neural network for which the activation predicate `pact` is *exactly* +the two distinguished states `σ_pos` and `σ_neg`. This is captured by the +`TwoStateExclusive` predicate below. For such networks, an `EnergySpec'` +immediately yields an `IsHamiltonian` instance, reusing the generic +`energy_is_lyapunov_at_site''` lemma (no binary–specialized reproofs). 
-/ + +namespace TwoState + +/-- Exclusivity predicate: the allowed activations are precisely `σ_pos` or `σ_neg`. -/ +class TwoStateExclusive + {U σ} (NN : NeuralNetwork ℝ U σ) + [TwoStateNeuralNetwork NN] : Prop where + (pact_iff : ∀ a, NN.pact a ↔ + a = TwoStateNeuralNetwork.σ_pos (NN:=NN) ∨ + a = TwoStateNeuralNetwork.σ_neg (NN:=NN)) + +attribute [simp] TwoStateExclusive.pact_iff + +/-- Instance: `SymmetricBinary` activations are exactly `{1,-1}`. -/ +instance (U) [Fintype U] [DecidableEq U] [Nonempty U] : + TwoStateExclusive (TwoState.SymmetricBinary ℝ U) where + pact_iff a := by + -- pact definition: a = 1 ∨ a = -1 + simp [TwoState.SymmetricBinary] + aesop + +/-- Instance: `ZeroOne` activations are exactly `{0,1}`. -/ +instance zeroOneExclusive (U) [Fintype U] [DecidableEq U] [Nonempty U] : + TwoStateExclusive (TwoState.ZeroOne ℝ U) where + pact_iff a := by + -- pact definition: a = 0 ∨ a = 1 + simp [TwoState.ZeroOne, TwoState.SymmetricBinary] + aesop + +/-- Instance: `SymmetricSignum` activations (two-point type) are exactly the two constructors. -/ +instance signumExclusive (U) [Fintype U] [DecidableEq U] [Nonempty U] : + TwoStateExclusive (TwoState.SymmetricSignum ℝ U) where + pact_iff a := by + -- pact is `True`; every `a` is either `pos` or `neg` by exhaustive cases. + cases a <;> simp [TwoState.SymmetricSignum] + all_goals aesop + +end TwoState + +open TwoState + +variable {U σ : Type} [Fintype U] [DecidableEq U] +variable {NN : NeuralNetwork ℝ U σ} [TwoStateNeuralNetwork NN] + +/-- Generic bridge: any exclusive two–state NN with an `EnergySpec'` is Hamiltonian. +Added `[Fintype NN.State]` so the `(⊤)` measurable space instance is available, +fixing the missing `MeasurableSpace NN.State` error. 
-/ +noncomputable instance IsHamiltonian_of_EnergySpec' + (spec : TwoState.EnergySpec' (NN:=NN)) + [Fintype NN.State] -- NEW: ensures MeasurableSpace instance via the earlier `[Fintype]` → `⊤` + [TwoStateExclusive (NN:=NN)] : + IsHamiltonian (U:=U) (σ:=σ) NN where + energy := spec.E + energy_measurable := by + intro p + -- finite state space ⇒ every function measurable + have : Measurable (spec.E p) := + measurable_of_fintype_state (NN:=NN) (f:=spec.E p) + simp + energy_is_lyapunov := by + intro p s u + classical + have hcur := + (TwoStateExclusive.pact_iff (NN:=NN) (a:=s.act u)).1 (s.hp u) + exact TwoState.EnergySpec'.energy_is_lyapunov_at_site'' + (NN:=NN) spec p s u hcur + +/-- +Backward compatibility: the old SymmetricBinary-specific instance is now +substantiated by the generic one. (Kept for clarity; can be removed safely.) +-/ +@[deprecated IsHamiltonian_of_EnergySpec' (since := "2025-08-24")] +noncomputable def IsHamiltonian_of_EnergySpecSymmetricBinary + {U : Type} [Fintype U] [DecidableEq U] [Nonempty U] + [Fintype (TwoState.SymmetricBinary ℝ U).State] + (spec : TwoState.EnergySpec' (NN:=TwoState.SymmetricBinary ℝ U)) : + IsHamiltonian (TwoState.SymmetricBinary ℝ U) := + (IsHamiltonian_of_EnergySpec' (NN:=TwoState.SymmetricBinary ℝ U) (spec:=spec)) + +open CanonicalEnsemble +open scoped BigOperators + +variable {U : Type} [Fintype U] [DecidableEq U] + +/-- Abbreviation: the canonical ensemble associated to a Hamiltonian neural network. -/ +noncomputable abbrev hopfieldCE + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [IsHamiltonian (U:=U) (σ:=σ) NN] + (p : Params NN) : + CanonicalEnsemble NN.State := + toCanonicalEnsemble (U:=U) (σ:=σ) NN p + +/-- The induced finite–ensemble structure (counting measure, dof = 0, unit = 1). 
-/ +instance + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [IsHamiltonian (U:=U) (σ:=σ) NN] + (p : Params NN) : + CanonicalEnsemble.IsFinite (hopfieldCE (U:=U) (σ:=σ) NN p) where + μ_eq_count := rfl + dof_eq_zero := rfl + phase_space_unit_eq_one := rfl + +omit [Fintype U] in +@[simp] +lemma hopfieldCE_dof + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [IsHamiltonian (U:=U) (σ:=σ) NN] + (p : Params NN) : + (hopfieldCE (U:=U) (σ:=σ) NN p).dof = 0 := rfl + +omit [Fintype U] in +@[simp] +lemma hopfieldCE_phase_space_unit + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [IsHamiltonian (U:=U) (σ:=σ) NN] + (p : Params NN) : + (hopfieldCE (U:=U) (σ:=σ) NN p).phase_space_unit = 1 := rfl + +omit [Fintype U] in +/-- Uniform probability for a constant-energy Hamiltonian (sanity test of the bridge). -/ +lemma hopfieldCE_probability_const_energy + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [IsHamiltonian (U:=U) (σ:=σ) NN] + (p : Params NN) (c : ℝ) + (hE : ∀ s, IsHamiltonian.energy (U:=U) (σ:=σ) (NN:=NN) p s = c) + (T : Temperature) (s : NN.State) : + (hopfieldCE (U:=U) (σ:=σ) NN p).probability T s + = (1 : ℝ) / (Fintype.card NN.State) := by + classical + set 𝓒 := hopfieldCE (U:=U) (σ:=σ) NN p + have hZ := + (mathematicalPartitionFunction_of_fintype (𝓒:=𝓒) T) + have hZconst : + 𝓒.mathematicalPartitionFunction T + = (Fintype.card NN.State : ℝ) * Real.exp (-(T.β : ℝ) * c) := by + have hsum : + (∑ s : NN.State, Real.exp (-(T.β : ℝ) * 𝓒.energy s)) + = (∑ _ : NN.State, Real.exp (-(T.β : ℝ) * c)) := by + refine Finset.sum_congr rfl ?_ + intro s _; simp [𝓒, toCanonicalEnsemble, hE] + have hsumConst : + (∑ _ : NN.State, Real.exp (-(T.β : ℝ) * c)) + = (Fintype.card NN.State : ℕ) • Real.exp (-(T.β : ℝ) * c) := by + simp [Finset.sum_const, Finset.card_univ] + have hnsmul : + ((Fintype.card NN.State : ℕ) • Real.exp (-(T.β : ℝ) * c)) + = (Fintype.card NN.State : ℝ) * Real.exp (-(T.β : ℝ) * c) := by + simp + simp [hZ, hsum, hsumConst, hnsmul] + aesop + unfold CanonicalEnsemble.probability 
+ have hexp_ne : Real.exp (-(T.β : ℝ) * c) ≠ 0 := (Real.exp_pos _).ne' + simp_rw [𝓒, toCanonicalEnsemble, hE] + erw [hZconst] + simp [one_div, div_eq_mul_inv, mul_comm, mul_left_comm, mul_assoc, hexp_ne] + +omit [Fintype U] in +/-- Corollary: mean energy = constant `c` under the induced canonical ensemble, + for a constant-energy network. -/ +lemma hopfieldCE_meanEnergy_const + (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [Nonempty NN.State] + [IsHamiltonian (U:=U) (σ:=σ) NN] + (p : Params NN) (c : ℝ) + (hE : ∀ s, IsHamiltonian.energy (U:=U) (σ:=σ) (NN:=NN) p s = c) + (T : Temperature) : + (hopfieldCE (U:=U) (σ:=σ) NN p).meanEnergy T = c := by + classical + set 𝓒 := hopfieldCE (U:=U) (σ:=σ) NN p + have hZeq : + 𝓒.mathematicalPartitionFunction T + = (Fintype.card NN.State : ℝ) * Real.exp (-(T.β : ℝ) * c) := by + have hZform := (mathematicalPartitionFunction_of_fintype (𝓒:=𝓒) T) + have : + (∑ s : NN.State, Real.exp (-(T.β : ℝ) * 𝓒.energy s)) + = (Fintype.card NN.State : ℝ) * Real.exp (-(T.β : ℝ) * c) := by + have hconst : + (∑ _ : NN.State, Real.exp (-(T.β : ℝ) * c)) + = (Fintype.card NN.State : ℕ) • Real.exp (-(T.β : ℝ) * c) := by + simp [Finset.sum_const, Finset.card_univ] + simp [𝓒, toCanonicalEnsemble, hE, hconst, nsmul_eq_mul, + mul_comm, mul_left_comm, mul_assoc] + simpa [hZform] + have hNum' : + (∑ s : NN.State, + IsHamiltonian.energy (U:=U) (σ:=σ) (NN:=NN) p s * + Real.exp (-(T.β : ℝ) * + IsHamiltonian.energy (U:=U) (σ:=σ) (NN:=NN) p s)) + = c * 𝓒.mathematicalPartitionFunction T := by + have hNumEq : + (∑ s : NN.State, + IsHamiltonian.energy (U:=U) (σ:=σ) (NN:=NN) p s * + Real.exp (-(T.β : ℝ) * + IsHamiltonian.energy (U:=U) (σ:=σ) (NN:=NN) p s)) + = c * ((Fintype.card NN.State : ℝ) * + Real.exp (-(T.β : ℝ) * c)) := by + have hconst : + (∑ _ : NN.State, + c * Real.exp (-(T.β : ℝ) * c)) + = (Fintype.card NN.State : ℕ) • (c * Real.exp (-(T.β : ℝ) * c)) := by + simp [Finset.sum_const, Finset.card_univ] + simp [𝓒, toCanonicalEnsemble, hE, hconst, nsmul_eq_mul, + 
mul_comm, mul_left_comm, mul_assoc] + simpa [hZeq, mul_comm, mul_left_comm, mul_assoc] using hNumEq + unfold CanonicalEnsemble.meanEnergy + have hZne : 𝓒.mathematicalPartitionFunction T ≠ 0 := by + have hcard : 0 < (Fintype.card NN.State : ℝ) := by + have : 0 < Fintype.card NN.State := Fintype.card_pos_iff.mpr inferInstance + exact_mod_cast this + have hpos : + 0 < (Fintype.card NN.State : ℝ) * Real.exp (-(T.β : ℝ) * c) := + mul_pos hcard (Real.exp_pos _) + simp [hZeq] + aesop + +-- Inheritance showcase: canonical–ensemble facts usable for Hopfield networks. +section CanonicalEnsembleInheritanceExamples +variable {U σ : Type} [Fintype U] [DecidableEq U] +variable (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [Nonempty NN.State] +variable [IsHamiltonian (U:=U) (σ:=σ) NN] +variable (p : Params NN) (T : Temperature) +variable (s : NN.State) + +-- Basic objects +#check (hopfieldCE (U:=U) (σ:=σ) NN p).partitionFunction +#check (hopfieldCE (U:=U) (σ:=σ) NN p).mathematicalPartitionFunction + +-- Positivity (finite specialization) +#check (mathematicalPartitionFunction_pos_finite + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (T:=T)) +#check (partitionFunction_pos_finite + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (T:=T)) + +-- Probability normalization & basic bounds +#check (sum_probability_eq_one + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (T:=T)) +#check (probability_nonneg_finite + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (T:=T) (i:=s)) + +-- Entropy identifications in finite case +#check (shannonEntropy_eq_differentialEntropy + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (T:=T)) +#check (thermodynamicEntropy_eq_shannonEntropy + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (T:=T)) + +-- Additivity for two independent Hopfield ensembles (same phase_space_unit = 1) +variable (NN₁ NN₂ : NeuralNetwork ℝ U σ) +variable [Fintype NN₁.State] [Nonempty NN₁.State] [IsHamiltonian (U:=U) (σ:=σ) NN₁] +variable [Fintype NN₂.State] [Nonempty NN₂.State] [IsHamiltonian (U:=U) (σ:=σ) NN₂] +variable (p₁ : Params NN₁) (p₂ : Params NN₂) 
+ +#check partitionFunction_add + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN₁ p₁) + (𝓒1:=hopfieldCE (U:=U) (σ:=σ) NN₂ p₂) + (T:=T) (by simp) +#check helmholtzFreeEnergy_add + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN₁ p₁) + (𝓒1:=hopfieldCE (U:=U) (σ:=σ) NN₂ p₂) + (T:=T) (by simp) +#check meanEnergy_add + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN₁ p₁) + (𝓒1:=hopfieldCE (U:=U) (σ:=σ) NN₂ p₂) + +-- n independent copies (scaling laws) +#check partitionFunction_nsmul + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (n:=3) (T:=T) +#check helmholtzFreeEnergy_nsmul + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (n:=3) (T:=T) +#check meanEnergy_nsmul + (𝓒:=hopfieldCE (U:=U) (σ:=σ) NN p) (n:=3) (T:=T) + +end CanonicalEnsembleInheritanceExamples From 4a88bc7fd1eeb0272d8b751b100e7200d4b849d3 Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Tue, 26 Aug 2025 12:39:02 +0200 Subject: [PATCH 13/15] minor fixes --- PhysLean.lean | 17 ++---- .../HopfieldNetwork/BoltzmannMachine.lean | 58 +++++-------------- 2 files changed, 19 insertions(+), 56 deletions(-) diff --git a/PhysLean.lean b/PhysLean.lean index 3aee74608..72bed8683 100644 --- a/PhysLean.lean +++ b/PhysLean.lean @@ -347,25 +347,18 @@ import PhysLean.Units.Pressure import PhysLean.Units.Speed import PhysLean.Units.Velocity import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Asym -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Core -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Markov -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux import 
PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.aux -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.test -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Asym +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Core import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.DetailedBalanceBM +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.DetailedBalancegen import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Markov import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.aux import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.test - +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.toCanonicalEnsemble +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.TwoState diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine.lean index de7a63363..04e3a4137 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/BoltzmannMachine.lean @@ -12,8 +12,6 @@ import Mathlib.Probability.Kernel.Composition.Prod /-! ### Concrete Hopfield Energy and Fintype Instances -/ - - /-! 
Reintroduce (and simplify) a `Matrix.quadraticForm` helper and the update lemma used later in the Hopfield energy flip relation proof (removed upstream). @@ -109,7 +107,6 @@ lemma mulVec_update_single /- Raw single–site quadratic form update (no diagonal assumption). Produces a δ-linear part plus a δ² * M i i remainder term. - Q(update x i v) - Q x = (v - x i) * ((∑ j, x j * M j i) + (M.mulVec x) i) + (v - x i)^2 * M i i @@ -156,7 +153,7 @@ lemma quadraticForm_update_point ring simpa [hMv, hUpd_off, hIf1, hIf2, δ] using hOffSite -/-- Core raw single–site quadratic form update (separated into a standalone lemma). +/-- Core raw single–site quadratic form update Produces a δ-linear part plus a δ² * M i i remainder term. -/ lemma quadraticForm_update_sum (M : Matrix ι ι R) (x : ι → R) (i : ι) (v : R) : @@ -247,8 +244,7 @@ lemma quadraticForm_update_sum simp [δ, mul_comm, mul_left_comm, mul_assoc] -/-- Raw single–site quadratic form update (no diagonal assumption). -Old name kept; proof now delegates to `quadraticForm_update_sum`. -/ +/-- Raw single–site quadratic form update (no diagonal assumption). -/ lemma quadraticForm_update_raw (M : Matrix ι ι R) (x : ι → R) (i : ι) (v : R) : quadraticForm M (Function.update x i v) - quadraticForm M x @@ -286,7 +282,7 @@ lemma quadraticForm_update_single_index simp_rw [h1, hErase, add_comm] -/-- Original (stronger) version assuming all diagonal entries vanish (kept for backwards compatibility). -/ +/-- Stronger version assuming all diagonal entries vanish -/ lemma quadraticForm_update_single {M : Matrix ι ι R} (hDiag : ∀ j, M j j = 0) (x : ι → R) (i : ι) (v : R) : @@ -296,6 +292,7 @@ lemma quadraticForm_update_single ( (M.mulVec x) i + ∑ j ∈ (Finset.univ.erase i), x j * M j i ) := quadraticForm_update_single_index (M:=M) (x:=x) (i:=i) (v:=v) (hii:=hDiag i) + /-- Optimized symmetric / zero–diagonal update for the quadratic form. This is the version used in the Hopfield flip energy relation. 
@@ -333,7 +330,6 @@ open Finset Matrix NeuralNetwork State TwoState variable {R U σ : Type} variable [Field R] [LinearOrder R] [IsStrictOrderedRing R] --- We need these helper lemmas about updPos/updNeg which were not in the prompt's snippet but are essential. namespace TwoState variable {R U σ : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] @@ -420,11 +416,8 @@ noncomputable def hamiltonian let θ_vec := fun i : U => (p.θ i).get fin0 (- (1/2 : R) * quad) + ∑ i : U, θ_vec i * s.act i -/-- -Proof of the fundamental Flip Energy Relation for the SymmetricBinary network. -ΔE = E(s⁺) - E(s⁻) = -2 * Lᵤ. -This leverages Mathlib's `Matrix.quadratic_form_update_diag_zero`. --/ +/-- Proof of the fundamental Flip Energy Relation for the SymmetricBinary network. +ΔE = E(s⁺) - E(s⁻) = -2 * Lᵤ. -/ lemma hamiltonian_flip_relation (p : Params (SymmetricBinary R U)) (s : (SymmetricBinary R U).State) (u : U) : let sPos := updPos (NN:=SymmetricBinary R U) s u let sNeg := updNeg (NN:=SymmetricBinary R U) s u @@ -433,75 +426,54 @@ lemma hamiltonian_flip_relation (p : Params (SymmetricBinary R U)) (s : (Symmetr intro sPos sNeg L unfold hamiltonian let θ_vec := fun i => (p.θ i).get fin0 - - -- 1. Analyze the Quadratic Term Difference (ΔE_quad). have h_quad_diff : (- (1/2 : R) * Matrix.quadraticForm p.w sPos.act) - (- (1/2 : R) * Matrix.quadraticForm p.w sNeg.act) = - (2 : R) * (p.w.mulVec s.act u) := by - rw [← mul_sub] - -- We analyze Q(sPos) - Q(sNeg). sPos has 1 at u, sNeg has -1 at u. - - -- Express sPos as an update of sNeg. have h_sPos_from_sNeg : sPos.act = Function.update sNeg.act u 1 := by ext i by_cases hi : i = u · subst hi - -- At site u, sPos.act u is σ_pos which is definitionally 1 for SymmetricBinary. simp_rw [sPos, sNeg, updPos, updNeg, Function.update] aesop · simp [sPos, sNeg, updPos, updNeg, Function.update, hi] rw [h_sPos_from_sNeg] - -- Apply the identity for updating a quadratic form with W symmetric and W_uu=0. 
- -- Q(update(x, k, v)) - Q(x) = (v - x_k) * 2 * (W x)_k. rw [Matrix.quadratic_form_update_diag_zero (p.hw'.1) (p.hw'.2)] - -- Here v=1, x=sNeg.act, k=u. sNeg.act u = -1. have h_sNeg_u : sNeg.act u = -1 := updNeg_act_at_u s u rw [h_sNeg_u] - -- (1 - (-1)) * 2 * (W sNeg.act)_u = 4 * (W sNeg.act)_u. simp only [sub_neg_eq_add, one_add_one_eq_two] ring_nf - -- Relate (W sNeg.act)_u back to s. Since W_uu=0, the activation at u doesn't matter. have h_W_sNeg_eq_W_s : p.w.mulVec sNeg.act u = p.w.mulVec s.act u := by unfold Matrix.mulVec dotProduct apply Finset.sum_congr rfl intro j _ by_cases h_eq : j = u - · simp [h_eq, p.hw'.2 u] -- W_uu = 0 + · simp [h_eq, p.hw'.2 u] · rw [updNeg_act_noteq s u j h_eq] - rw [h_W_sNeg_eq_W_s] - -- 2. Linear term difference have h_linear_diff : dotProduct θ_vec sPos.act - dotProduct θ_vec sNeg.act = (2 : R) * θ_vec u := by rw [← dotProduct_sub] - -- Only coordinate u differs (-1 → 1), so the difference vector is 2·e_u. have h_diff_vec : sPos.act - sNeg.act = Pi.single u (2 : R) := by ext v by_cases hv : v = u · subst hv - -- At site u: 1 - (-1) = 2 (for SymmetricBinary) simp [sPos, sNeg, updPos, updNeg, TwoState.SymmetricBinary, instTwoStateSymmetricBinary, Pi.single, sub_eq_add_neg, one_add_one_eq_two] - · -- Off site: unchanged, difference 0 - simp [sPos, sNeg, updPos, updNeg, Pi.single, hv, sub_eq_add_neg] + · simp [sPos, sNeg, updPos, updNeg, Pi.single, hv, sub_eq_add_neg] rw [h_diff_vec, dotProduct_single] simp [mul_comm] - - -- 3. Combine the terms. erw [add_sub_add_comm, h_quad_diff, h_linear_diff] - -- Relate (W s.act)_u to L = net(s) - θ_u. We need to show net(s) = (W s.act)_u. have h_net_eq_W_s : s.net p u = p.w.mulVec s.act u := by unfold State.net SymmetricBinary fnet Matrix.mulVec dotProduct apply Finset.sum_congr rfl intro v _ split_ifs with h_ne · aesop - · -- Case v = u (since ¬ (v ≠ u)): the net integrand is 0; the mulVec term is W u u * s.act u = 0. 
- have hv : v = u := by + · have hv : v = u := by classical by_contra hvne exact h_ne hvne @@ -510,7 +482,6 @@ lemma hamiltonian_flip_relation (p : Params (SymmetricBinary R U)) (s : (Symmetr simp [hdiag] rw [← h_net_eq_W_s] - -- Goal: -2 * net + 2 * θ = -2 * (net - θ). ring /-- The concrete Energy Specification for the SymmetricBinary Hopfield Network. -/ @@ -574,11 +545,10 @@ end SymmetricBinaryFintype /-! # Detailed Balance and the Boltzmann Distribution -This section establishes that the Gibbs update kernel is reversible with respect to the -Boltzmann distribution derived from the associated Canonical Ensemble. This holds generically -for any exclusive two-state network with an EnergySpec'. +This section and the DetailedBalanceBM file establish that the Gibbs update kernel is reversible +with respect to the Boltzmann distribution derived from the associated Canonical Ensemble. +This holds generically for any exclusive two-state network with an EnergySpec'. -/ - namespace HopfieldBoltzmann open CanonicalEnsemble ProbabilityTheory TwoState PMF @@ -662,7 +632,7 @@ lemma boltzmann_ratio (s s' : NN.State) : IsHamiltonian_of_EnergySpec' (NN:=NN) (spec:=spec) set 𝓒 := CEparams (NN:=NN) (spec:=spec) p have instFin : 𝓒.IsFinite := by - dsimp [𝓒, CEparams] -- unfolds to `hopfieldCE` + dsimp [𝓒, CEparams] infer_instance have h := CE_probability_ratio (NN:=NN) (𝓒:=𝓒) (T:=T) s s' simpa [P, 𝓒, @@ -752,7 +722,7 @@ lemma Kbm_apply_other (u : U) (s s' : NN.State) simp [h_K] aesop -/-- Helper: (1 - logistic(x)) / logistic(x) = exp(-x). -/ +/-- (1 - logistic(x)) / logistic(x) = exp(-x). 
-/ lemma one_sub_logistic_div_logistic (x : ℝ) : (1 - logisticProb x) / logisticProb x = Real.exp (-x) := by have h_pos := logisticProb_pos x From 272eb00b50de3a4db986abdc009aefdec92fec48 Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Tue, 26 Aug 2025 12:53:17 +0200 Subject: [PATCH 14/15] remove old files --- PhysLean.lean | 3 +- .../HopfieldNetwork/DetailedBalanceBM.lean | 63 +- .../SpinGlasses/HopfieldNetwork/Markov.lean | 316 ------- .../HopfieldNetwork/Stochastic.lean | 814 ------------------ .../HopfieldNetwork/StochasticAux.lean | 471 ---------- 5 files changed, 27 insertions(+), 1640 deletions(-) delete mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Markov.lean delete mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean delete mode 100644 PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean diff --git a/PhysLean.lean b/PhysLean.lean index 72bed8683..4dd46cae0 100644 --- a/PhysLean.lean +++ b/PhysLean.lean @@ -353,12 +353,11 @@ import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachin import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine.Markov import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.DetailedBalanceBM -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.DetailedBalancegen +import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.DetailedBalanceGen import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Markov import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.test import 
PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.toCanonicalEnsemble import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.TwoState diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceBM.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceBM.lean index 418b5bae4..082fe7e73 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceBM.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/DetailedBalanceBM.lean @@ -1,9 +1,8 @@ import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.BoltzmannMachine -import Mathlib import PhysLean.StatisticalMechanics.CanonicalEnsemble.Finite import PhysLean.StatisticalMechanics.CanonicalEnsemble.Lemmas --- Provide a finite canonical ensemble instance for the Hopfield Boltzmann construction. +-- We provide a finite canonical ensemble instance for the Hopfield Boltzmann construction. instance {U σ : Type} [Fintype U] [DecidableEq U] (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [Nonempty NN.State] @@ -15,7 +14,7 @@ instance dsimp [HopfieldBoltzmann.CEparams] infer_instance -variable [Fintype ι] [DecidableEq ι] [Ring R] --[CommRing R] +variable [Fintype ι] [DecidableEq ι] [Ring R] open CanonicalEnsemble Constants section DetailedBalance @@ -31,7 +30,7 @@ variable (p : Params NN) (T : Temperature) local notation "P" => P (NN:=NN) (spec:=spec) p T local notation "K" => Kbm (NN:=NN) p T -/-- Helper: states differ away from `u` (∃ other coordinate with different activation). -/ +/-- States differ away from `u` (∃ other coordinate with different activation). -/ def DiffAway (u : U) (s s' : NN.State) : Prop := ∃ v, v ≠ u ∧ s.act v ≠ s'.act v @@ -78,7 +77,8 @@ lemma Kbm_zero_of_diffAway exact ⟨h_forward, h_backward⟩ omit [Nonempty U] [Nonempty NN.State] in -/-- Detailed balance holds trivially in the “diff-away” case (both transition probabilities are 0). 
-/ +/-- Detailed balance holds trivially in the “diff-away” case (both transition probabilities +are 0). -/ lemma detailed_balance_diffAway {u : U} {s s' : NN.State} (h : DiffAway (NN:=NN) u s s') : @@ -113,8 +113,7 @@ lemma single_site_cases · exact Or.inr ⟨hs_neg, hs'_pos⟩ · exact False.elim (hx (hs_neg.trans hs'_neg.symm)) -/-- Convenience: `logisticProb (-x) = 1 - logisticProb x` (already available above as -`TwoState.logisticProb_neg`, re-exposed here in the local namespace for algebra lemmas). -/ +/-- Convenience: `logisticProb (-x) = 1 - logisticProb x` -/ lemma logistic_neg (x : ℝ) : logisticProb (-x) = 1 - logisticProb x := TwoState.logisticProb_neg x @@ -224,9 +223,9 @@ lemma TwoState.EnergySpec'.probPos_flip_pair mul_comm, mul_left_comm, mul_assoc] using h exact ⟨h₁, h₂⟩ -/-- Specialization of the previous pair lemma to the “neg→pos” orientation used -in `detailed_balance_neg_pos`. Here `ΔE = E s' - E s` with `s' = updPos s u` -and `s = updNeg s' u` (i.e. `s` carries σ_neg at `u`, `s'` carries σ_pos). -/ +/-- Specialization to the “neg→pos” orientation used in `detailed_balance_neg_pos`. +Here `ΔE = E s' - E s` with `s' = updPos s u`and `s = updNeg s' u` (i.e. `s` carries σ_neg at `u`, +`s'` carries σ_pos). -/ lemma flip_prob_neg_pos {U σ} [Fintype U] [DecidableEq U] {NN : NeuralNetwork ℝ U σ} [TwoStateNeuralNetwork NN] @@ -283,8 +282,7 @@ lemma flip_prob_neg_pos ring_nf; aesop exact ⟨h1, h2⟩ -/-- Clean algebraic lemma: - if +/-- if • `Pfun s' / Pfun s = exp (-β ΔE)` and • `Kfun s s' / Kfun s' s = exp ( β ΔE)` then detailed balance holds: `Pfun s * Kfun s s' = Pfun s' * Kfun s' s`. -/ @@ -421,7 +419,8 @@ lemma detailed_balance_pos_neg omit [Nonempty U] in /-- **Theorem: Detailed Balance Condition (Reversibility)**. -The Gibbs update kernel satisfies the detailed balance condition with respect to the Boltzmann distribution. +The Gibbs update kernel satisfies the detailed balance condition with respect to the +Boltzmann distribution. 
P(s) K(s→s') = P(s') K(s'→s). -/ theorem detailed_balance @@ -447,10 +446,9 @@ theorem detailed_balance exact detailed_balance_neg_pos (NN:=NN) (spec:=spec) (p:=p) (T:=T) (u:=u) (s:=s) (s':=s') h_off hneg hpos - end DetailedBalance -variable [Fintype ι] [DecidableEq ι] [Ring R] --[CommRing R] +variable [Fintype ι] [DecidableEq ι] [Ring R] open CanonicalEnsemble Constants section DetailedBalance @@ -463,16 +461,14 @@ variable [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] variable (spec : TwoState.EnergySpec' (NN:=NN)) variable (p : Params NN) (T : Temperature) -/-- Lift a family of PMFs to a Markov kernel on a finite (hence countable) state space. -We reuse `Kernel.ofFunOfCountable`, which supplies the measurability proof. -/ +/-- Lift a family of PMFs to a Markov kernel on a finite (hence countable) state space. -/ noncomputable def pmfToKernel {α : Type*} [Fintype α] [DecidableEq α] [MeasurableSpace α] [MeasurableSingletonClass α] (K : α → PMF α) : Kernel α α := Kernel.ofFunOfCountable (fun a => (K a).toMeasure) -/-- Single–site Gibbs kernel at site `u` as a Kernel (uses existing `gibbsUpdate`). -`spec` is not needed here, so we underscore it to silence the unused-variable linter. -/ +/-- Single–site Gibbs kernel at site `u` as a Kernel (uses existing `gibbsUpdate`). -/ noncomputable def singleSiteKernel (NN : NeuralNetwork ℝ U σ) [Fintype NN.State] [DecidableEq U] [MeasurableSpace NN.State] [MeasurableSingletonClass NN.State] @@ -481,8 +477,7 @@ noncomputable def singleSiteKernel Kernel NN.State NN.State := pmfToKernel (fun s => TwoState.gibbsUpdate (NN:=NN) (RingHom.id ℝ) p T s u) -/-- Random–scan Gibbs kernel as uniform mixture over sites. -`spec` is likewise unused in the construction of the kernel itself. -/ +/-- Random–scan Gibbs kernel as uniform mixture over sites. 
-/ noncomputable def randomScanKernel (NN : NeuralNetwork ℝ U σ) [Fintype U] [DecidableEq U] [Nonempty U] [Fintype NN.State] [DecidableEq NN.State] [MeasurableSpace NN.State] [MeasurableSingletonClass NN.State] @@ -507,12 +502,7 @@ variable {α : Type*} subst hσ; trivial /-- For a finite type with counting measure, the (lower) integral -is the finite sum (specialization of the `tsum` version). - -FIX: Added `[MeasurableSingletonClass α]` which is required by `MeasureTheory.lintegral_count`. -Removed the auxiliary restricted / probability-specialized lemmas that caused build errors -(`lintegral_count_restrict`, `lintegral_fintype_prob_restrict`, `lintegral_restrict_as_sum_if`) -since they were unused and referenced a non‑existent lemma. -/ +is the finite sum (specialization of the `tsum` version). -/ lemma lintegral_count_fintype [MeasurableSpace α] [MeasurableSingletonClass α] [Fintype α] [DecidableEq α] @@ -526,7 +516,7 @@ lemma lintegral_fintype_measure_restrict {α : Type*} [Fintype α] [DecidableEq α] [MeasurableSpace α] [MeasurableSingletonClass α] - (μ : Measure α) (A : Set α) --(hA : MeasurableSet A) + (μ : Measure α) (A : Set α) (f : α → ℝ≥0∞) : ∫⁻ x in A, f x ∂μ = ∑ x : α, (if x ∈ A then μ {x} * f x else 0) := by @@ -579,8 +569,7 @@ lemma lintegral_fintype_prob_restrict = ∑ x : α, (if x ∈ A then μ {x} * f x else 0) := by simpa using lintegral_fintype_measure_restrict μ A f -/-- Restricted version over the counting measure (finite type). -Uses the probability-style formula specialized to `Measure.count`. -/ +/-- Restricted version over the counting measure (finite type). -/ lemma lintegral_count_restrict [MeasurableSpace α] [MeasurableSingletonClass α] [Fintype α] [DecidableEq α] (A : Set α) (f : α → ℝ≥0∞) : @@ -609,7 +598,7 @@ open MeasureTheory Set Finset Kernel TwoState HopfieldBoltzmann variable {α β : Type*} [MeasurableSpace α] [MeasurableSpace β] -/-- (Helper) Every subset of a finite type is finite. -/ +/-- Every subset of a finite type is finite. 
-/ lemma Set.finite_of_subsingleton_fintype {γ : Type*} [Fintype γ] (S : Set γ) : S.Finite := (Set.toFinite _) @@ -720,8 +709,7 @@ lemma randomScanKernel_eval_uniform end ProbabilityTheory /-- On a finite (any finite subset) space with measurable singletons, the measure of a finite -set under a kernel is the finite sum of the singleton masses. (Refactored: Finset induction; -avoids problematic `hB.induction_on` elaboration.) -/ +set under a kernel is the finite sum of the singleton masses. -/ lemma Kernel.measure_eq_sum_finset [DecidableEq α] [MeasurableSingletonClass α] (κ : Kernel β α) (x : β) {B : Set α} (hB : B.Finite) : @@ -826,7 +814,7 @@ lemma lintegral_randomScanKernel_as_sum_div aesop omit [Fintype U] [DecidableEq U] [Nonempty U] in -/-- Averaging lemma: uniform average of reversible single–site kernels is reversible. -/ +/-- Uniform average of reversible single–site kernels is reversible. -/ lemma randomScanKernel_reversible_of_sites (NN : NeuralNetwork ℝ U σ) [Fintype U] [DecidableEq U] [Nonempty U] [Fintype NN.State] [DecidableEq NN.State] @@ -993,7 +981,7 @@ variable [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] variable (spec : TwoState.EnergySpec' (NN:=NN)) variable (p : Params NN) (T : Temperature) -/-- Helper: canonical Boltzmann measure we use below. -/ +/-- Canonical Boltzmann measure from `CanonicalEnsemble.Basic` -/ private noncomputable abbrev πBoltz : Measure NN.State := (HopfieldBoltzmann.CEparams (NN:=NN) (spec:=spec) p).μProd T @@ -1091,7 +1079,7 @@ lemma singleSite_pointwise_detailed_balance | simp_all only [μProd_singleton_of_fintype] omit [Nonempty U] in -/-- Reversibility of the single–site kernel w.r.t. the Boltzmann measure (patched). -/ +/-- Reversibility of the single–site kernel w.r.t. the Boltzmann measure. 
-/ lemma singleSiteKernel_reversible (u : U) : ProbabilityTheory.Kernel.IsReversible @@ -1121,7 +1109,8 @@ variable [TwoStateNeuralNetwork NN] [TwoStateExclusive NN] variable (spec : TwoState.EnergySpec' (NN:=NN)) variable (p : Params NN) (T : Temperature) -/-- Reversibility of the random–scan Gibbs kernel (uniform site choice) w.r.t. the Boltzmann measure. -/ +/-- Reversibility of the random–scan Gibbs kernel (uniform site choice) w.r.t. +the Boltzmann measure. -/ theorem randomScanKernel_reversible : ProbabilityTheory.Kernel.IsReversible (randomScanKernel (NN:=NN) spec p T) diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Markov.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Markov.lean deleted file mode 100644 index 33cf96039..000000000 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Markov.lean +++ /dev/null @@ -1,316 +0,0 @@ -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Stochastic -import Mathlib.MeasureTheory.Measure.WithDensity -import Mathlib.Probability.Kernel.Invariance -import Mathlib.Probability.Kernel.Basic -import Mathlib.Probability.Kernel.Composition.MeasureComp -import Mathlib.Analysis.BoundedVariation - -open ProbabilityTheory.Kernel - -namespace ProbabilityTheory.Kernel - -/-- `Kernel.pow κ n` is the `n`-fold composition of the kernel `κ`, with `pow κ 0 = id`. -/ -noncomputable def pow {α : Type*} [MeasurableSpace α] (κ : Kernel α α) : ℕ → Kernel α α -| 0 => Kernel.id -| n + 1 => κ ∘ₖ (pow κ n) - -end ProbabilityTheory.Kernel - -/-! 
-# Markov Chain Framework - -## Main definitions - -* `stochasticHopfieldMarkovProcess`: A Markov process on Hopfield network states -* `gibbsTransitionKernel`: The transition kernel for Gibbs sampling -* `DetailedBalance`: The detailed balance condition for reversible Markov chains -* `mixingTime`: The time needed to approach the stationary distribution (TODO) - --/ - -open MeasureTheory ProbabilityTheory ENNReal Finset Function ProbabilityTheory.Kernel Set - -namespace MarkovChain - --- Using the discrete sigma-algebra implicitly for the finite state space -instance (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] - [DecidableEq U] [Fintype U] [Nonempty U] : - MeasurableSpace ((HopfieldNetwork R U).State) := ⊤ - --- Prove all sets are measurable in the discrete sigma-algebra -lemma measurableSet_discrete {α : Type*} [MeasurableSpace α] (h : ‹_› = ⊤) - (s : Set α) : MeasurableSet s := by - rw [h] - trivial - -instance (R U : Type) [Field R] [LinearOrder R] [IsStrictOrderedRing R] - [DecidableEq U] [Fintype U] [Nonempty U] : - DiscreteMeasurableSpace ((HopfieldNetwork R U).State) where - forall_measurableSet := fun s => measurableSet_discrete rfl s - -/-! -### Core Markov Chain Definitions --/ - -/-- -A `StationaryDistribution` for a transition kernel is a measure that remains -invariant under the action of the kernel. --/ -structure StationaryDistribution {α : Type*} [MeasurableSpace α] (K : Kernel α α) where - /-- The probability measure that is stationary with respect to the kernel K. -/ - measure : Measure α - /-- Proof that the measure is a probability measure (sums to 1). -/ - isProbability : IsProbabilityMeasure measure - /-- Proof that the measure is invariant under the kernel K. -/ - isStationary : ∀ s, MeasurableSet s → (Measure.bind measure K) s = measure s - -/-- -The detailed balance condition for a Markov kernel with respect to a measure. 
-`μ(dx) K(x,dy) = μ(dy) K(y,dx)` for all measurable sets --/ -def DetailedBalance {α : Type*} [MeasurableSpace α] (μ : Measure α) (K : Kernel α α) : Prop := - ∀ A B : Set α, MeasurableSet A → MeasurableSet B → - ∫⁻ x in A, (K x B) ∂μ = ∫⁻ y in B, (K y A) ∂μ - -/-- When detailed balance holds, the measure is stationary -/ -def stationaryOfDetailedBalance {α : Type*} [MeasurableSpace α] {μ : Measure α} - [IsProbabilityMeasure μ] {K : Kernel α α} [IsMarkovKernel K] - (h : DetailedBalance μ K) : StationaryDistribution K where - measure := μ - isProbability := inferInstance - isStationary := by - intro s hs - have bind_def : (μ.bind K) s = ∫⁻ x, (K x s) ∂μ := by - apply Measure.bind_apply hs (Kernel.aemeasurable K) - have h_balance := h Set.univ s MeasurableSet.univ hs - rw [bind_def] - have h_univ : ∫⁻ x, K x s ∂μ = ∫⁻ x in Set.univ, K x s ∂μ := by - simp only [Measure.restrict_univ] - rw [h_univ, h_balance] - have univ_one : ∀ y, K y Set.univ = 1 := by - intro y - exact measure_univ - have h_one : ∫⁻ y in s, K y Set.univ ∂μ = ∫⁻ y in s, 1 ∂μ := by - apply lintegral_congr_ae - exact ae_of_all (μ.restrict s) univ_one - rw [h_one, MeasureTheory.lintegral_const, Measure.restrict_apply MeasurableSet.univ, - Set.univ_inter, one_mul] - -/-! 
-### Markov Chain on Hopfield Networks --/ - -section HopfieldMarkovChain - -variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] - [Fintype U] [Nonempty U] [Coe R ℝ] - -instance : Nonempty ((HopfieldNetwork R U).State) := by - let defaultState : (HopfieldNetwork R U).State := { - act := fun _ => -1, - hp := fun _ => Or.inr rfl - } - exact ⟨defaultState⟩ - --- Fintype instance for the state space -noncomputable instance : Fintype ((HopfieldNetwork R U).State) := by - let f : ((HopfieldNetwork R U).State) → (U → {r : R | r = 1 ∨ r = -1}) := - fun s u => ⟨s.act u, s.hp u⟩ - have h_inj : Function.Injective f := by - intro s1 s2 h - cases s1 with | mk act1 hp1 => - cases s2 with | mk act2 hp2 => - simp at * - ext u - have h_u := congr_fun h u - simp [f] at h_u - exact h_u - have h_surj : Function.Surjective f := by - intro g - let act := fun u => (g u).val - have hp : ∀ u, act u = 1 ∨ act u = -1 := fun u => (g u).property - exists ⟨act, hp⟩ - exact _root_.instFintypeStateHopfieldNetwork - -noncomputable def gibbsTransitionKernel (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : - Kernel ((HopfieldNetwork R U).State) ((HopfieldNetwork R U).State) where - toFun := fun state => (NN.State.gibbsSamplingStep wθ state T).toMeasure - measurable' := Measurable.of_discrete - --- Mark the kernel as a Markov kernel (preserves probability) -instance gibbsIsMarkovKernel (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : - IsMarkovKernel (gibbsTransitionKernel wθ T) where - isProbabilityMeasure := by - intro s - simp [gibbsTransitionKernel] - exact PMF.toMeasure.isProbabilityMeasure (NN.State.gibbsSamplingStep wθ s T) - -/-- -The stochastic Hopfield Markov process, which models the evolution of Hopfield network states -over discrete time steps using Gibbs sampling at fixed temperature. -In this simplified model, the transition kernel is time-homogeneous (same for all steps). 
--/ -noncomputable def stochasticHopfieldMarkovProcess (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : - ℕ → Kernel ((HopfieldNetwork R U).State) ((HopfieldNetwork R U).State) := - fun _ => gibbsTransitionKernel wθ T - -/-- -The n-step transition probability, which gives the probability of moving from -state x to state y in exactly n steps. --/ -noncomputable def nStepTransition (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (n : ℕ) : - ((HopfieldNetwork R U).State) → ((HopfieldNetwork R U).State) → ENNReal := - fun x y => (Kernel.pow (gibbsTransitionKernel wθ T) n x) {y} -- Correct application of Kernel.pow - -/-- -The total variation distance between two probability measures on Hopfield network states. -Defined as supremum of |μ(A) - ν(A)| over all measurable sets A. --/ -noncomputable def totalVariation (μ ν : Measure ((HopfieldNetwork R U).State)) : ENNReal := - ⨆ (A : Set ((HopfieldNetwork R U).State)) (_ : MeasurableSet A), - ENNReal.ofReal (abs ((μ A).toReal - (ν A).toReal)) - -/-- -A state is aperiodic if there's a positive probability of returning to it in a single step. --/ -def IsAperiodic (wθ : Params (HopfieldNetwork R U)) (T : ℝ) - (s : (HopfieldNetwork R U).State) : Prop := (gibbsTransitionKernel wθ T s) {s} > 0 - -/-- -A Markov chain is irreducible if it's possible to get from any state to any other state -with positive probability in some finite number of steps. 
--/ -def IsIrreducible (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : Prop := - ∀ x y, ∃ n, (Kernel.pow (gibbsTransitionKernel wθ T) n x) {y} > 0 -- Use Kernel.pow correctly - -/-- The unnormalized Boltzmann density function -/ -noncomputable def boltzmannDensityFn (wθ : Params (HopfieldNetwork R U)) (T : ℝ) - (s : (HopfieldNetwork R U).State) : ENNReal := - ENNReal.ofReal (Real.exp (-(Coe.coe (NeuralNetwork.State.E wθ s) : ℝ) / T)) - -/-- The Boltzmann partition function (normalizing constant) -/ -noncomputable def boltzmannPartitionFn (wθ : Params (HopfieldNetwork R U)) (T : ℝ) : ENNReal := - ∑ s ∈ Finset.univ, boltzmannDensityFn wθ T s - -/-- Helper lemma: A finite sum of ENNReal values is positive if the set is - nonempty and all terms are positive -/ -lemma ENNReal.sum_pos {α : Type*} (s : Finset α) (f : α → ENNReal) - (h_nonempty : s.Nonempty) (h_pos : ∀ i ∈ s, 0 < f i) : 0 < ∑ i ∈ s, f i := by - rcases h_nonempty with ⟨i, hi⟩ - have h_pos_i : 0 < f i := h_pos i hi - have h_le : f i ≤ ∑ j ∈ s, f j := Finset.single_le_sum (fun j _ => zero_le (f j)) hi - exact lt_of_lt_of_le h_pos_i h_le - -/-- The Boltzmann partition function is positive and finite -/ -lemma boltzmannPartitionFn_pos_finite (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (_hT : T ≠ 0) : - 0 < boltzmannPartitionFn wθ T ∧ boltzmannPartitionFn wθ T < ⊤ := by - simp only [boltzmannPartitionFn] - have h_pos_term : ∀ s : (HopfieldNetwork R U).State, 0 < boltzmannDensityFn wθ T s := by - intro s - simp only [boltzmannDensityFn] - exact ENNReal.ofReal_pos.mpr (Real.exp_pos _) - have h_finite_term : ∀ s : (HopfieldNetwork R U).State, boltzmannDensityFn wθ T s ≠ ⊤ := by - intro s - simp only [boltzmannDensityFn] - exact ENNReal.ofReal_ne_top - constructor - · -- Proves positivity: sum of positive terms is positive - apply ENNReal.sum_pos - · exact Finset.univ_nonempty - · intro s _hs_in_univ - exact h_pos_term s - · -- Proves finiteness: sum is finite if all terms are finite - rw [sum_lt_top] - intro s 
_hs_in_univ - rw [lt_top_iff_ne_top] - exact h_finite_term s -/-- -The Boltzmann distribution over Hopfield network states at temperature T. --/ -noncomputable def boltzmannDistribution (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : - Measure ((HopfieldNetwork R U).State) := - let densityFn := boltzmannDensityFn wθ T - let partitionFn := boltzmannPartitionFn wθ T - let _h_part_pos_finite := boltzmannPartitionFn_pos_finite wθ T hT - let countMeasure : Measure ((HopfieldNetwork R U).State) := MeasureTheory.Measure.count - if h_part : partitionFn = 0 ∨ partitionFn = ⊤ then - 0 - else - let partitionFn_ne_zero : partitionFn ≠ 0 := by - intro h_zero - exact h_part (Or.inl h_zero) - let partitionFn_ne_top : partitionFn ≠ ⊤ := by - intro h_top - exact h_part (Or.inr h_top) - Measure.withDensity countMeasure (fun s => densityFn s / partitionFn) - --- Helper lemma to handle the 'if' in boltzmannDistribution definition -lemma boltzmannDistribution_def_of_pos_finite (wθ : Params (HopfieldNetwork R U)) - (T : ℝ) (hT : T ≠ 0) : - boltzmannDistribution wθ T hT = - let densityFn := boltzmannDensityFn wθ T - let partitionFn := boltzmannPartitionFn wθ T - let countMeasure : Measure ((HopfieldNetwork R U).State) := MeasureTheory.Measure.count - Measure.withDensity countMeasure (fun s => densityFn s / partitionFn) := by - let h_part := boltzmannPartitionFn_pos_finite wθ T hT - simp [boltzmannDistribution, h_part.1.ne', h_part.2.ne] - -- Use the fact that partitionFn is > 0 and < ⊤ - -/-- The Boltzmann distribution measure of the universe equals the integral of density/partition -/ -lemma boltzmannDistribution_measure_univ (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : - boltzmannDistribution wθ T hT Set.univ = - ∫⁻ s in Set.univ, (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) ∂Measure.count := by - rw [boltzmannDistribution_def_of_pos_finite wθ T hT] - simp only [withDensity_apply _ MeasurableSet.univ] - -/-- The integral over the universe equals 
the sum over all states -/ -lemma boltzmannDistribution_integral_eq_sum (wθ : Params (HopfieldNetwork R U)) - (T : ℝ) (_hT : T ≠ 0) : - ∫⁻ s in Set.univ, (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) ∂Measure.count = - ∑ s ∈ Finset.univ, (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) := by - rw [Measure.restrict_univ] - trans ∑' (s : (HopfieldNetwork R U).State), - (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) - · exact MeasureTheory.lintegral_count - (fun s => (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T)) - · exact tsum_fintype fun b ↦ boltzmannDensityFn wθ T b / boltzmannPartitionFn wθ T - -/-- Division can be distributed over the sum in the Boltzmann distribution -/ -lemma boltzmannDistribution_div_sum (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : - ∑ s ∈ Finset.univ, (boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) = - (∑ s ∈ Finset.univ, boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) := by - let Z := boltzmannPartitionFn wθ T - let h_part := boltzmannPartitionFn_pos_finite wθ T hT - have h_Z_pos : Z > 0 := h_part.1 - have h_Z_lt_top : Z < ⊤ := h_part.2 - have h_div_def : ∀ (a b : ENNReal), a / b = a * b⁻¹ := fun a b => by - rw [ENNReal.div_eq_inv_mul] - rw [mul_comm b⁻¹ a] - simp only [h_div_def] - rw [Finset.sum_mul] - - -/-- The sum of Boltzmann probabilities equals 1 -/ -lemma boltzmannDistribution_sum_one (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : - (∑ s ∈ Finset.univ, boltzmannDensityFn wθ T s) / (boltzmannPartitionFn wθ T) = 1 := by - simp only [boltzmannPartitionFn] - let h_part := boltzmannPartitionFn_pos_finite wθ T hT - exact ENNReal.div_self h_part.1.ne' h_part.2.ne - -/-- -Proves that the Boltzmann distribution for a Hopfield network forms a valid probability measure. 
--/ -theorem boltzmannDistribution_isProbability {R U : Type} - [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] - [Fintype U] [Nonempty U] [Coe R ℝ] - (wθ : Params (HopfieldNetwork R U)) (T : ℝ) (hT : T ≠ 0) : - IsProbabilityMeasure (boltzmannDistribution wθ T hT) := by - constructor - rw [boltzmannDistribution_measure_univ wθ T hT] - rw [boltzmannDistribution_integral_eq_sum wθ T hT] - rw [boltzmannDistribution_div_sum wθ T hT] - exact boltzmannDistribution_sum_one wθ T hT - -end HopfieldMarkovChain - -end MarkovChain diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean deleted file mode 100644 index 3b2d71744..000000000 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/Stochastic.lean +++ /dev/null @@ -1,814 +0,0 @@ -/- -Copyright (c) 2025 Matteo Cipollina. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Matteo Cipollina --/ - -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NNStochastic -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.StochasticAux -import PhysLean.Thermodynamics.Temperature.Basic -import Mathlib.Analysis.RCLike.Basic -import Mathlib.LinearAlgebra.AffineSpace.AffineMap -import Mathlib.LinearAlgebra.Dual.Lemmas - -set_option linter.unusedSectionVars false -set_option linter.unusedVariables false - -/- -# Stochastic Hopfield Network Implementation - -This file defines and proves properties related to a stochastic Hopfield network. -It includes definitions for states, neural network parameters, energy computations, -and stochastic updates using both Gibbs sampling and Metropolis-Hastings algorithms. -- Functions (`StatePMF`, `StochasticDynamics`) representing probability measures over states. 
-- Key stochastic update operations, including a single-neuron Gibbs update - (`gibbsUpdateNeuron`, `gibbsUpdateSingleNeuron`) and full-network sampling steps - (`gibbsSamplingStep`, `gibbsSamplingSteps`) that iterate these updates. -- Definitions (`metropolisDecision`, `metropolisHastingsStep`, `metropolisHastingsSteps`) for - implementing a Metropolis-Hastings update rule in a Hopfield network. -- A simulated annealing procedure (`simulatedAnnealing`) that adaptively lowers the temperature - to guide the network into a low-energy configuration. -- Various lemmas (such as `single_site_difference`, `updateNeuron_preserves`, and - `gibbs_probs_sum_one`) ensuring correctness and consistency of the update schemes. -- Utility definitions and proofs, including creation of valid parameters - (`mkArray_creates_valid_hopfield_params`), - verification of adjacency (`all_nodes_adjacent`), total variation distance - (`total_variation_distance`), partition function (`partitionFunction`), and more. --/ -open Finset Matrix NeuralNetwork State ENNReal Real - -variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] - [DecidableEq U] [Fintype U] [Nonempty U] (wθ : Params (HopfieldNetwork R U)) - (s : (HopfieldNetwork R U).State) [Coe R ℝ] (T : ℝ) - -/-- Performs a Gibbs update on a single neuron `u` of the state `s`. - The update probability depends on the energy change associated with flipping the neuron's state, - parameterized by the temperature `T`. 
-/ -noncomputable def NN.State.gibbsUpdateNeuron (u : U) : - PMF ((HopfieldNetwork R U).State) := - let hᵤ := s.net wθ u - let ΔE := 2 * hᵤ * s.act u - let p_flip := ENNReal.ofReal (exp (-(↑ΔE) / T)) / (1 + ENNReal.ofReal (exp (-(↑ΔE) / T))) - let p_flip_le_one : p_flip ≤ 1 := by - let a := ENNReal.ofReal (exp (-(↑ΔE) / T)) - have h_sum_ne_top : (1 + a) ≠ ⊤ := add_ne_top.2 ⟨one_ne_top, ofReal_ne_top⟩ - rw [ENNReal.div_le_iff _ h_sum_ne_top, one_mul] - · exact le_add_self - · intro H; rw [add_eq_zero] at H; simp only [one_ne_zero] at H; exact H.1 - PMF.bind (PMF.bernoulli p_flip p_flip_le_one) $ fun should_flip => - PMF.pure $ if should_flip then s.Up wθ u else s - - -- Calculate probabilities based on Boltzmann distribution -noncomputable def probs (u : U) (local_field : R) : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (exp (local_field * new_act_val / T)) - -noncomputable def total (u : U) (local_field : R) : ENNReal := - probs T u local_field true + probs T u local_field false - -noncomputable def norm_probs (u : U) (local_field : R) : Bool → ENNReal := fun b => - probs T u local_field b / total T u local_field - -noncomputable def Z (local_field : R) := - ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp (-local_field / T)) - -omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] in -lemma h_total_eq_Z (local_field : R) : total T u local_field = (Z T local_field) := by - simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, - probs, Z] - -omit [Field R] [LinearOrder R] [IsStrictOrderedRing R] [DecidableEq U] [Fintype U] [Nonempty U] in -lemma h_total_ne_zero (u : U) (local_field : R) : total T u local_field ≠ 0 := by - simp only [total, probs, ne_eq, add_eq_zero] - intro h - have h_exp_pos : ENNReal.ofReal (exp (local_field * 1 / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr; apply exp_pos - exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h - -lemma h_sum (u : U) (local_field : R) : ∑ b : 
Bool, (norm_probs T u local_field b ) = - (probs T u local_field true + probs T u local_field false) / total T u local_field := by - simp only [Fintype.univ_bool, mem_singleton, Bool.true_eq_false, - not_false_eq_true,sum_insert, sum_singleton, total, probs, Z] - exact ENNReal.div_add_div_same - -lemma h_total_ne_top (u : U) (local_field : R) : total T u local_field ≠ ⊤ := by simp [total, probs] - -/-- Update a single neuron according to Gibbs sampling rule -/ -noncomputable def NN.State.gibbsUpdateSingleNeuron (u : U) : PMF ((HopfieldNetwork R U).State) := - -- Calculate local field for the neuron - let local_field := s.net wθ u - -- Convert Bool to State - (PMF.map (fun b => if b then - NN.State.updateNeuron s u 1 (mul_self_eq_mul_self_iff.mp rfl) - else - NN.State.updateNeuron s u (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl)) - (PMF.ofFintype (norm_probs T u local_field) (by - have h_sum : Finset.sum Finset.univ (norm_probs T u local_field) = 1 := by - calc Finset.sum Finset.univ (norm_probs T u local_field) - = (probs T u local_field true)/total T u local_field + - (probs T u local_field false)/total T u local_field := - Fintype.sum_bool fun b ↦ probs T u local_field b / total T u local_field - _ = (probs T u local_field true + probs T u local_field false)/total T u local_field := - ENNReal.div_add_div_same - _ = total T u local_field /total T u local_field := by rfl - _ = 1 := ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field) - exact h_sum))) - -@[inherit_doc] -scoped[ENNReal] notation "ℝ≥0∞" => ENNReal - -open Fintype - -theorem NN.State.gibbsSamplingStep.extracted_1 {U : Type} [inst : Fintype U] [inst_1 : Nonempty U] : - ∑ a : U, (fun _ => 1 / ((Fintype.card U) : ENNReal)) a = 1 := by { - exact uniform_neuron_selection_prob_valid - } - -/-- Given a Hopfield Network's parameters, temperature, and current state, performs a single step -of Gibbs sampling by: -1. Uniformly selecting a random neuron -2. 
Updating that neuron's state according to the Gibbs distribution --/ -noncomputable def NN.State.gibbsSamplingStep : PMF ((HopfieldNetwork R U).State) := - -- Uniform random selection of neuron - let neuron_pmf : PMF U := - PMF.ofFintype (fun _ => (1 : ENNReal) / (card U : ENNReal)) - (NN.State.gibbsSamplingStep.extracted_1 (U:=U)) - -- Bind neuron selection with conditional update - PMF.bind neuron_pmf $ fun u => NN.State.gibbsUpdateSingleNeuron wθ s T u - -instance : Coe ℝ ℝ := ⟨id⟩ - -/-- Perform a stochastic update on a Pattern representation -/ -noncomputable def patternStochasticUpdate {n : ℕ} [Nonempty (Fin n)] (weights : Fin n → Fin n → ℝ) - (h_diag_zero : ∀ i : Fin n, weights i i = 0) (h_sym : ∀ i j : Fin n, weights i j = weights j i) - (pattern : State (HopfieldNetwork ℝ (Fin n))) (i : Fin n) : - PMF (State (HopfieldNetwork ℝ (Fin n))) := - let wθ : Params (HopfieldNetwork ℝ (Fin n)) := { - w := weights, - hw := fun u v h => by - if h_eq : u = v then - rw [h_eq] - exact h_diag_zero v - else - contradiction - hw' := by - exact IsSymm.ext_iff.mpr fun i j ↦ h_sym j i - σ := fun u => Vector.mk (Array.replicate - ((HopfieldNetwork ℝ (Fin n)).κ1 u) (0 : ℝ)) rfl, - θ := fun u => Vector.mk (Array.replicate - ((HopfieldNetwork ℝ (Fin n)).κ2 u) (0 : ℝ)) rfl - } - NN.State.gibbsUpdateSingleNeuron wθ pattern T i - -/-- Performs multiple steps of Gibbs sampling in a Hopfield network, starting from - an initial state. Each step involves: - 1. First recursively applying previous steps (if any) - 2. Then performing a single Gibbs sampling step on the resulting state - The temperature parameter T controls the randomness of the updates. -/ -noncomputable def NN.State.gibbsSamplingSteps (steps : ℕ) : PMF ((HopfieldNetwork R U).State) := - match steps with - | 0 => PMF.pure s - | steps + 1 => PMF.bind (gibbsSamplingSteps steps) $ fun s' => - NN.State.gibbsSamplingStep wθ s' T - -/-- Temperature schedule for simulated annealing that decreases exponentially with each step. 
-/ -noncomputable def temperatureSchedule (initial_temp : ℝ) (cooling_rate : ℝ) (step : ℕ) : ℝ := - initial_temp * exp (-cooling_rate * step) - - --initial_temp * exp (-cooling_rate * step) - -/-- Recursively applies Gibbs sampling steps with decreasing temperature according to - the cooling schedule, terminating when the step count reaches the target number of steps. -/ -noncomputable def applyAnnealingSteps (temp_schedule : ℕ → ℝ) (steps : ℕ) - (step : ℕ) (state : (HopfieldNetwork R U).State) : PMF ((HopfieldNetwork R U).State) := - if h : step ≥ steps then - PMF.pure state - else - PMF.bind (NN.State.gibbsSamplingStep wθ state (temp_schedule step)) - (applyAnnealingSteps temp_schedule steps (step + 1)) -termination_by steps - step -decreasing_by { - rw [Nat.sub_succ] - simp only [Nat.pred_eq_sub_one, tsub_lt_self_iff, tsub_pos_iff_lt, Nat.lt_one_iff] - rw [and_true] - exact not_le.mp h} - -/-- `NN.State.simulatedAnnealing` implements the simulated annealing optimization -algorithm for a Hopfield Network. This function performs simulated annealing by starting -from an initial state and gradually reducing the temperature according to an exponential -cooling schedule, allowing the system to explore the state space and eventually settle into a -low-energy configuration. --/ -noncomputable def NN.State.simulatedAnnealing (initial_temp : ℝ) (cooling_rate : ℝ) (steps : ℕ) - (initial_state : (HopfieldNetwork R U).State) : PMF ((HopfieldNetwork R U).State) := - let temp_schedule := temperatureSchedule initial_temp cooling_rate - applyAnnealingSteps wθ temp_schedule steps 0 initial_state - -/-- Given a HopfieldNetwork with parameters `wθ` and temperature `T`, computes the -acceptance probability for transitioning from a `current` state to a `proposed` state according -to the Metropolis-Hastings algorithm. 
- -* If the energy difference (ΔE) is negative or zero, returns 1.0 (always accepts the transition) -* If the energy difference is positive, returns exp(-ΔE/T) following the Boltzmann distribution --/ -noncomputable def NN.State.acceptanceProbability - (current : (HopfieldNetwork R U).State) (proposed : (HopfieldNetwork R U).State) : ℝ := - let energy_diff := proposed.E wθ - current.E wθ - if energy_diff ≤ 0 then - 1.0 -- Always accept if energy decreases - else - exp (-energy_diff / T) -- Accept with probability e^(-ΔE/T) if energy increases - -/-- The partition function for a Hopfield network, defined as the sum over all possible states -of the Boltzmann factor `exp(-E/T)`. --/ -noncomputable def NN.State.partitionFunction : ℝ := - ∑ s : (HopfieldNetwork R U).State, exp (-s.E wθ / T) - -/-- Metropolis-Hastings single step for Hopfield networks -/ -noncomputable def NN.State.metropolisHastingsStep : PMF ((HopfieldNetwork R U).State) := - -- Uniform random selection of neuron - let neuron_pmf : PMF U := - PMF.ofFintype (fun _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) - (gibbsSamplingStep.extracted_1) - -- Create proposed state by flipping a randomly selected neuron - let propose : U → PMF ((HopfieldNetwork R U).State) := fun u => - let flipped_state := - if s.act u = 1 then -- Assuming 1 and -1 as valid activation values - NN.State.updateNeuron s u (-1) (Or.inr rfl) - else - NN.State.updateNeuron s u 1 (Or.inl rfl) - let p := NN.State.acceptanceProbability wθ T s flipped_state - -- Make acceptance decision - PMF.bind (NN.State.metropolisDecision p) (fun (accept : Bool) => - if accept then PMF.pure flipped_state else PMF.pure s) - -- Combine neuron selection with state proposal - PMF.bind neuron_pmf propose - -/-- Multiple steps of Metropolis-Hastings algorithm for Hopfield networks -/ -noncomputable def NN.State.metropolisHastingsSteps (steps : ℕ) - : PMF ((HopfieldNetwork R U).State) := - match steps with - | 0 => PMF.pure s - | steps+1 => PMF.bind 
(metropolisHastingsSteps steps) $ fun s' => - NN.State.metropolisHastingsStep wθ s' T - -/-- The Boltzmann (Gibbs) distribution over neural network states -/ -noncomputable def boltzmannDistribution : ((HopfieldNetwork R U).State → ℝ) := - fun s => exp (-s.E wθ / T) / NN.State.partitionFunction wθ T - -/-- The transition probability matrix for Gibbs sampling -/ -noncomputable def gibbsTransitionProb (s s' : (HopfieldNetwork R U).State) : ℝ := - ENNReal.toReal ((NN.State.gibbsSamplingStep wθ s) T s') - -/-- The transition probability matrix for Metropolis-Hastings -/ -noncomputable def metropolisTransitionProb (s s' : (HopfieldNetwork R U).State) : ℝ := - ENNReal.toReal ((NN.State.metropolisHastingsStep wθ s) T s') - -/-- Total variation distance between probability distributions -/ -noncomputable def total_variation_distance - (μ ν : (HopfieldNetwork R U).State → ℝ) : ℝ := - (1/2) * ∑ s : (HopfieldNetwork R U).State, |μ s - ν s| - -/-- For Gibbs updates, given the normalization and probabilities, the sum of - normalized probabilities equals 1 -/ -lemma gibbs_probs_sum_one (v : U) : - let local_field := s.net wθ v - let norm_probs := fun b => probs T v local_field b / total T v local_field - ∑ b : Bool, norm_probs b = 1 := by - intro local_field norm_probs - have h_sum : ∑ b : Bool, probs T v local_field b / total T v local_field = - (probs T v local_field true + probs T v local_field false) / total T v local_field := by - rw [Fintype.sum_bool, ENNReal.div_add_div_same] - rw [h_sum] - have h_total_eq : probs T v local_field true + - probs T v local_field false = total T v local_field := by rfl - rw [h_total_eq] - exact ENNReal.div_self (h_total_ne_zero T v local_field) (h_total_ne_top T v local_field) - -/-- The function that maps boolean values to states in Gibbs sampling -/ -def gibbs_bool_to_state_map - (s : (HopfieldNetwork R U).State) (v : U) : Bool → (HopfieldNetwork R U).State := - fun b => if b then - NN.State.updateNeuron s v 1 (mul_self_eq_mul_self_iff.mp 
rfl) - else - NN.State.updateNeuron s v (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) - -/-- The total normalization constant for Gibbs sampling is positive -/ -lemma gibbs_total_positive (local_field : ℝ) (T : ℝ) : - let probs : Bool → ENNReal := fun b => - let new_act_val := if b then 1 else -1 - ENNReal.ofReal (exp (local_field * new_act_val / T)) - probs true + probs false ≠ 0 := by - intro probs h_zero - have h1 : ENNReal.ofReal (exp (local_field * 1 / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr - apply exp_pos - have h_sum_zero : ENNReal.ofReal (exp (local_field * 1 / T)) + - ENNReal.ofReal (exp (local_field * (-1) / T)) = 0 := by { - exact h_zero - } - have h_both_zero : ENNReal.ofReal (exp (local_field * 1 / T)) = 0 ∧ - ENNReal.ofReal (exp (local_field * (-1) / T)) = 0 := - add_eq_zero.mp h_sum_zero - exact h1.ne' h_both_zero.1 - -/-- The total normalization constant for Gibbs sampling is not infinity -/ -lemma gibbs_total_not_top (local_field : ℝ) (T : ℝ) : - probs T u local_field true + probs T u local_field false ≠ ⊤ := by - simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, ne_eq, ENNReal.add_eq_top, - ENNReal.ofReal_ne_top, or_self, not_false_eq_true, probs] - -/-- For a positive PMF.map application, there exists a preimage with positive probability -/ -lemma pmf_map_pos_implies_preimage {α β : Type} [Fintype α] [DecidableEq β] - {p : α → ENNReal} (h_pmf : ∑ a, p a = 1) (f : α → β) (y : β) : - (PMF.map f (PMF.ofFintype p h_pmf)) y > 0 → ∃ x : α, p x > 0 ∧ f x = y := by - intro h_pos - simp only [PMF.map_apply] at h_pos - simp_all only [PMF.ofFintype_apply, tsum_eq_filter_sum, gt_iff_lt, filter_sum_pos_iff_exists_pos, - pmf_map_pos_iff_exists_pos] - -/-- For states with positive Gibbs update probability, there exists a boolean variable that - determines whether the state has activation 1 or -1 at the updated neuron -/ -lemma gibbsUpdate_exists_bool (v : U) (s_next : (HopfieldNetwork R U).State) : - (NN.State.gibbsUpdateSingleNeuron 
wθ s T v) s_next > 0 → - ∃ b : Bool, s_next = gibbs_bool_to_state_map s v b := by - intro h_prob_pos - unfold NN.State.gibbsUpdateSingleNeuron at h_prob_pos - let local_field_R := s.net wθ v - let local_field : ℝ := ↑local_field_R - --let total := probs T v local_field true + probs T v local_field false - let norm_probs : Bool → ENNReal := fun b => probs T v local_field b / total T v local_field - let map_fn : Bool → (HopfieldNetwork R U).State := gibbs_bool_to_state_map s v - have h_sum_eq_1 : ∑ b : Bool, norm_probs b = 1 := by - have h_total_ne_zero : total T v local_field ≠ 0 := gibbs_total_positive local_field T - have h_total_ne_top : total T v local_field ≠ ⊤ := gibbs_total_not_top local_field T - calc Finset.sum Finset.univ norm_probs - = (probs T v local_field true) /total T v local_field + - (probs T v local_field false)/total T v local_field := - Fintype.sum_bool fun b ↦ probs T v local_field b / total T v local_field - _ = (probs T v local_field true + - probs T v local_field false)/total T v local_field:= ENNReal.div_add_div_same - _ = total T v local_field /total T v local_field := by rfl - _ = 1 := ENNReal.div_self h_total_ne_zero h_total_ne_top - let base_pmf := PMF.ofFintype norm_probs h_sum_eq_1 - have ⟨b, _, h_map_eq⟩ := pmf_map_pos_implies_preimage h_sum_eq_1 map_fn s_next h_prob_pos - use b - exact Eq.symm h_map_eq - -/-- For states with positive probability under gibbsUpdateSingleNeuron, - they must be one of exactly two possible states (with neuron v set to 1 or -1) -/ -@[simp] -lemma gibbsUpdate_possible_states (v : U) (s_next : (HopfieldNetwork R U).State) : - (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next > 0 → - s_next = NN.State.updateNeuron s v 1 (mul_self_eq_mul_self_iff.mp rfl) ∨ - s_next = NN.State.updateNeuron s v (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) := by - intro h_prob_pos - obtain ⟨b, h_eq⟩ := gibbsUpdate_exists_bool wθ s T v s_next h_prob_pos - cases b with - | false => - right - unfold gibbs_bool_to_state_map at 
h_eq - rw [@Std.Tactic.BVDecide.Normalize.if_eq_cond] at h_eq - exact h_eq - | true => - left - unfold gibbs_bool_to_state_map at h_eq - rw [@Std.Tactic.BVDecide.Normalize.if_eq_cond] at h_eq - exact h_eq - -/-- Gibbs update preserves states at non-updated sites -/ -@[simp] -lemma gibbsUpdate_preserves_other_neurons - (v w : U) (h_neq : w ≠ v) : - ∀ s_next, (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next > 0 → - s_next.act w = s.act w := by - intro s_next h_prob_pos - have h_structure := gibbsUpdate_possible_states wθ s T v s_next h_prob_pos - cases h_structure with - | inl h_pos => - rw [h_pos] - exact updateNeuron_preserves s v w 1 (mul_self_eq_mul_self_iff.mp rfl) h_neq - | inr h_neg => - rw [h_neg] - exact updateNeuron_preserves s v w (-1) - (AffineMap.lineMap_eq_lineMap_iff.mp rfl) h_neq - -/-- The probability mass function for a binary choice (true/false) - has sum 1 when properly normalized -/ -lemma pmf_binary_norm_sum_one (local_field : ℝ) (T : ℝ) : - let total := probs T u local_field true + probs T u local_field false - let norm_probs := fun b => probs T u local_field b / total - ∑ b : Bool, norm_probs b = 1 := by - intro total norm_probs - have h_sum : ∑ b : Bool, probs T u local_field b / total = - (probs T u local_field true + probs T u local_field false) / total := by - simp only [Fintype.sum_bool] - exact ENNReal.div_add_div_same - rw [h_sum] - have h_total_ne_zero : total ≠ 0 := by - simp only [total, probs, ne_eq] - intro h_zero - have h1 : ENNReal.ofReal (exp (local_field * 1 / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr - apply exp_pos - have h_sum_zero : ENNReal.ofReal (exp (local_field * 1 / T)) + - ENNReal.ofReal (exp (local_field * (-1) / T)) = 0 := h_zero - have h_both_zero : ENNReal.ofReal (exp (local_field * 1 / T)) = 0 ∧ - ENNReal.ofReal (exp (local_field * (-1) / T)) = 0 := by - exact add_eq_zero.mp h_sum_zero - exact h1.ne' h_both_zero.1 - have h_total_ne_top : total ≠ ⊤ := by - simp [total, probs] - exact ENNReal.div_self 
h_total_ne_zero h_total_ne_top - -/-- The normalization factor in Gibbs sampling is the sum of Boltzmann - factors for both possible states -/ -lemma gibbs_normalization_factor (local_field : ℝ) (T : ℝ) : - let total := probs T u local_field true + probs T u local_field false - total = ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal - (exp (-local_field / T)) := by - intro total - simp only [probs, total] - simp only [↓reduceIte, mul_one, Bool.false_eq_true, mul_neg, total, probs] - rfl - -/-- The probability mass assigned to true when using Gibbs sampling -/ -lemma gibbs_prob_true (local_field : ℝ) (T : ℝ) : - norm_probs T u local_field true = ENNReal.ofReal (exp (local_field / T)) / - (ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal - (exp (-local_field / T))) := by - --intro total --norm_probs - simp only [norm_probs, probs] - have h_total : total T u local_field = ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T)) := by - simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs, norm_probs] - rfl - rw [h_total] - congr - simp only [↓reduceIte, mul_one, total, norm_probs, probs] - rfl - -/-- The probability mass assigned to false when using Gibbs sampling -/ -lemma gibbs_prob_false (local_field : ℝ) (T : ℝ) : - norm_probs T u local_field false = ENNReal.ofReal (exp (-local_field / T)) / - (ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp (-local_field / T))) := by - simp only [norm_probs, probs] - have h_total : total T u local_field = ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T)) := by - simp [total, probs] - rfl - rw [h_total] - congr - simp only [Bool.false_eq_true, ↓reduceIte, mul_neg, mul_one, norm_probs, probs, total] - rfl - -/-- Converts the ratio of Boltzmann factors to ENNReal sigmoid form. 
-/ -@[simp] -lemma ENNReal_exp_ratio_to_sigmoid (x : ℝ) : - ENNReal.ofReal (exp x) / (ENNReal.ofReal (exp x) + ENNReal.ofReal (exp (-x))) = - ENNReal.ofReal (1 / (1 + exp (-2 * x))) := by - have num_pos : 0 ≤ exp x := le_of_lt (exp_pos x) - have denom_pos : 0 < exp x + exp (-x) := by - apply add_pos - · exact exp_pos x - · exact exp_pos (-x) - have h1 : ENNReal.ofReal (exp x) / - (ENNReal.ofReal (exp x) + ENNReal.ofReal (exp (-x))) = - ENNReal.ofReal (exp x / (exp x + exp (-x))) := by - have h_sum : ENNReal.ofReal (exp x) + ENNReal.ofReal (exp (-x)) = - ENNReal.ofReal (exp x + exp (-x)) := by - have exp_neg_pos : 0 ≤ exp (-x) := le_of_lt (exp_pos (-x)) - exact Eq.symm (ofReal_add num_pos exp_neg_pos) - rw [h_sum] - exact Eq.symm (ofReal_div_of_pos denom_pos) - have h2 : exp x / (exp x + exp (-x)) = 1 / (1 + exp (-2 * x)) := by - have h_denom : exp x + exp (-x) = exp x * (1 + exp (-2 * x)) := by - have h_exp_diff : exp (-x) = exp x * exp (-2 * x) := by - rw [← exp_add]; congr; ring - calc exp x + exp (-x) - = exp x + exp x * exp (-2 * x) := by rw [h_exp_diff] - _ = exp x * (1 + exp (-2 * x)) := by rw [mul_add, mul_one] - rw [h_denom, div_mul_eq_div_div] - have h_exp_ne_zero : exp x ≠ 0 := ne_of_gt (exp_pos x) - field_simp - rw [h1, h2] - -@[simp] -lemma ENNReal.div_ne_top' {a b : ENNReal} (ha : a ≠ ⊤) (hb : b ≠ 0) : a / b ≠ ⊤ := by - intro h_top - rw [div_eq_top] at h_top - rcases h_top with (⟨_, h_right⟩ | ⟨h_left, _⟩); - exact hb h_right; exact ha h_left - -lemma gibbs_prob_positive (local_field : ℝ) (T : ℝ) : - let total := probs T u local_field true + probs T u local_field false - ENNReal.ofReal (exp (local_field / T)) / total = - ENNReal.ofReal (1 / (1 + exp (-2 * local_field / T))) := by - intro total - have h_total : total = ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T)) := by - simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs] - rfl - rw [h_total] - have h_temp : ∀ x, exp (x / T) = exp (x 
* (1/T)) := by - intro x; congr; field_simp - rw [h_temp local_field, h_temp (-local_field)] - have h_direct : - ENNReal.ofReal (exp (local_field * (1 / T))) / - (ENNReal.ofReal (exp (local_field * (1 / T))) + - ENNReal.ofReal (exp (-local_field * (1 / T)))) = - ENNReal.ofReal (1 / (1 + exp (-2 * local_field / T))) := by - have h := ENNReal_exp_ratio_to_sigmoid (local_field * (1 / T)) - have h_rhs : -2 * (local_field * (1 / T)) = -2 * local_field / T := by - field_simp - rw [h_rhs] at h - have neg_equiv : ENNReal.ofReal (exp (-(local_field * (1 / T)))) = - ENNReal.ofReal (exp (-local_field * (1 / T))) := by - congr; ring - rw [neg_equiv] at h - exact h - exact h_direct - -/-- The probability of setting a neuron to -1 under Gibbs sampling -/ -lemma gibbs_prob_negative (local_field : ℝ) (T : ℝ) : - let total := probs T u local_field true + probs T u local_field false - ENNReal.ofReal (exp (-local_field / T)) / total = - ENNReal.ofReal (1 / (1 + exp (2 * local_field / T))) := by - intro total - have h_total : total = ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T)) := by - simp only [mul_ite, mul_one, mul_neg, ↓reduceIte, Bool.false_eq_true, total, probs] - rfl - rw [h_total] - have h_neg2_neg : -2 * (-local_field / T) = 2 * local_field / T := by ring - have h_neg_neg : -(-local_field / T) = local_field / T := by ring - have h_ratio_final : ENNReal.ofReal (exp (-local_field / T)) / - (ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T))) = - ENNReal.ofReal (1 / (1 + exp (2 * local_field / T))) := by - have h := ENNReal_exp_ratio_to_sigmoid (-local_field / T) - have h_exp_neg_neg : ENNReal.ofReal (exp (-(-local_field / T))) = - ENNReal.ofReal (exp (local_field / T)) := by congr - rw [h_exp_neg_neg] at h - have h_comm : ENNReal.ofReal (exp (-local_field / T)) + - ENNReal.ofReal (exp (local_field / T)) = - ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T)) := by - rw 
[add_comm] - rw [h_neg2_neg, h_comm] at h - exact h - exact h_ratio_final - --- Lemma for the probability calculation in the positive case -lemma gibbs_prob_positive_case - (u : U) : - let local_field := s.net wθ u - let Z := ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp (-local_field / T)) - let norm_probs := fun b => if b then - ENNReal.ofReal (exp (local_field / T)) / Z - else - ENNReal.ofReal (exp (-local_field / T)) / Z - (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by - have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by - exact Fintype.sum_bool (fun b => norm_probs b) - rw [h_sum] - simp only [norm_probs] - have h_ratio_sum : ENNReal.ofReal (exp (local_field / T)) / Z + - ENNReal.ofReal (exp (-local_field / T)) / Z = - (ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T))) / Z := by - exact ENNReal.div_add_div_same - simp only [Bool.false_eq_true] - have h_if_true : (if True then ENNReal.ofReal (exp (local_field / T)) / Z - else ENNReal.ofReal (exp (-local_field / T)) / Z) = - ENNReal.ofReal (exp (local_field / T)) / Z := by simp - - have h_if_false : (if False then ENNReal.ofReal (exp (local_field / T)) / Z - else ENNReal.ofReal (exp (-local_field / T)) / Z) = - ENNReal.ofReal (exp (-local_field / T)) / Z := by simp - rw [h_if_true, h_if_false] - rw [h_ratio_sum] - have h_Z_ne_zero : Z ≠ 0 := by - simp only [ne_eq, add_eq_zero, ENNReal.ofReal_eq_zero, not_and, not_le, Z, norm_probs] - intros - exact exp_pos (-Coe.coe local_field / T) - have h_Z_ne_top : Z ≠ ⊤ := by simp [Z] - exact ENNReal.div_self h_Z_ne_zero h_Z_ne_top - ))) (NN.State.updateNeuron s u 1 (Or.inl rfl)) = norm_probs true := by - intro - apply pmf_map_update_one - --- Lemma for the probability calculation in the negative case -lemma gibbs_prob_negative_case (u : U) : - let local_field := s.net wθ u - let Z := ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T)) - let 
norm_probs := fun b => if b then - ENNReal.ofReal (exp (local_field / T)) / Z - else - ENNReal.ofReal (exp (-local_field / T)) / Z - (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by - have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by - exact Fintype.sum_bool (fun b => norm_probs b) - rw [h_sum] - simp only [norm_probs] - have h_ratio_sum : ENNReal.ofReal (exp (local_field / T)) / Z + - ENNReal.ofReal (exp (-local_field / T)) / Z = - (ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T))) / Z := by - exact ENNReal.div_add_div_same - simp only [Bool.false_eq_true] - simp only [↓reduceIte, norm_probs] - rw [h_ratio_sum] - have h_Z_ne_zero : Z ≠ 0 := by - simp only [Z, ne_eq, add_eq_zero] - intro h - have h_exp_pos : ENNReal.ofReal (exp (local_field / T)) > 0 := by - apply ENNReal.ofReal_pos.mpr - apply exp_pos - exact (not_and_or.mpr (Or.inl h_exp_pos.ne')) h - have h_Z_ne_top : Z ≠ ⊤ := by - simp only [ne_eq, ENNReal.add_eq_top, ENNReal.ofReal_ne_top, or_self, not_false_eq_true, Z, - norm_probs] - exact ENNReal.div_self h_Z_ne_zero h_Z_ne_top))) - (NN.State.updateNeuron s u (-1) (Or.inr rfl)) = norm_probs false := by - intro - apply pmf_map_update_neg_one - -/-- PMF map from boolean values to updated states preserves probability structure -/ -lemma gibbsUpdate_pmf_structure - (u : U) : - let local_field := s.net wθ u - let total := probs T u local_field true + probs T u local_field false - let norm_probs := fun b => probs T u local_field b / total - ∀ b : Bool, (PMF.map (gibbs_bool_to_state_map s u) (PMF.ofFintype norm_probs (by - have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by - exact Fintype.sum_bool (fun b => norm_probs b) - rw [h_sum] - have h_ratio_sum : probs T u local_field true / total + probs T u local_field false / total = - (probs T u local_field true + probs T u local_field false) / total := by - exact ENNReal.div_add_div_same - rw 
[h_ratio_sum] - exact ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field) - ))) (gibbs_bool_to_state_map s u b) = norm_probs b := by - intro local_field total norm_probs b_bool - exact pmf_map_binary_state s u b_bool (fun b => norm_probs b) (by - have h_sum : ∑ b : Bool, norm_probs b = norm_probs true + norm_probs false := by - exact Fintype.sum_bool (fun b => norm_probs b) - rw [h_sum] - have h_ratio_sum : probs T u local_field true / total + probs T u local_field false / total = - (probs T u local_field true + probs T u local_field false) / total := by - exact ENNReal.div_add_div_same - rw [h_ratio_sum] - exact ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field)) - -def h_result_update_one (u : U) (local_field : R) := - pmf_map_update_one s u (norm_probs T u local_field ) (by - rw [h_sum] - exact ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field)) - -def h_result_neg_one (u : U) (local_field : R) := - pmf_map_update_neg_one s u (norm_probs T u local_field ) (by - rw [h_sum] - exact ENNReal.div_self (h_total_ne_zero T u local_field) (h_total_ne_top T u local_field)) - -/-- The probability of updating a neuron to 1 using Gibbs sampling -/ -lemma gibbsUpdate_prob_positive (u : U) : - let local_field := s.net wθ u - --let Z := ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal (exp (-local_field / T)) - (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u 1 (Or.inl rfl)) = - ENNReal.ofReal (exp (local_field / T)) / (Z T local_field) := by - intro local_field --Z - unfold NN.State.gibbsUpdateSingleNeuron - rw [h_result_update_one] - simp only [probs, mul_one_div, norm_probs] - rw [h_total_eq_Z] - simp only [if_true, mul_one,local_field] - -/-- The probability of updating a neuron to -1 using Gibbs sampling -/ -lemma gibbsUpdate_prob_negative (u : U) : - let local_field := s.net wθ u - --let Z := ENNReal.ofReal (exp (local_field / T)) + ENNReal.ofReal 
(exp (-local_field / T)) - (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u (-1) (Or.inr rfl)) = - ENNReal.ofReal (exp (-local_field / T)) / (Z T local_field) := by - intro local_field - unfold NN.State.gibbsUpdateSingleNeuron - rw [h_result_neg_one] - simp only [probs, one_div_neg_one_eq_neg_one, one_div_neg_one_eq_neg_one, norm_probs] - rw [h_total_eq_Z] - simp only [Bool.false_eq_true, ↓reduceIte, mul_neg, mul_one, probs, Z, total, local_field] - -/-- Computes the probability of updating a neuron to a specific value using Gibbs sampling. -- If new_val = 1: probability = exp(local_field/T)/Z -- If new_val = -1: probability = exp(-local_field/T)/Z -where Z is the normalization constant (partition function). --/ -@[simp] -lemma gibbs_update_single_neuron_prob (u : U) (new_val : R) - (hval : (HopfieldNetwork R U).pact new_val) : - let local_field := s.net wθ u - let Z := ENNReal.ofReal (exp (local_field / T)) + - ENNReal.ofReal (exp (-local_field / T)) - (NN.State.gibbsUpdateSingleNeuron wθ s T u) (NN.State.updateNeuron s u new_val hval) = - if new_val = 1 then - ENNReal.ofReal (exp (local_field / T)) / Z - else - ENNReal.ofReal (exp (-local_field / T)) / Z := by - intro local_field Z - by_cases h_val : new_val = 1 - · rw [if_pos h_val] - have h_update_equiv := gibbs_bool_to_state_map_positive s u new_val hval h_val - rw [h_update_equiv] - exact gibbsUpdate_prob_positive wθ s T u - · rw [if_neg h_val] - have h_neg_val : new_val = -1 := hopfield_value_dichotomy new_val hval h_val - have h_update_equiv := gibbs_bool_to_state_map_negative s u new_val hval h_neg_val - rw [h_update_equiv] - exact gibbsUpdate_prob_negative wθ s T u - -/-- When states differ at site u, the probability of transitioning to s' by updating - any other site v is zero -/ -lemma gibbs_update_zero_other_sites (s s' : (HopfieldNetwork R U).State) - (u v : U) (h : ∀ w : U, w ≠ u → s.act w = s'.act w) (h_diff : s.act u ≠ s'.act u) : - v ≠ u → (NN.State.gibbsUpdateSingleNeuron wθ 
s T v) s' = 0 := by - intro hv - have h_act_diff : s'.act u ≠ s.act u := by - exact Ne.symm h_diff - have h_s'_diff_update : ∀ new_val hval, - s' ≠ NN.State.updateNeuron s v new_val hval := by - intro new_val hval - by_contra h_eq - have h_u_eq : s'.act u = (NN.State.updateNeuron s v new_val hval).act u := by - rw [←h_eq] - have h_u_preserved : (NN.State.updateNeuron s v new_val hval).act u = s.act u := by - exact updateNeuron_preserves s v u new_val hval (id (Ne.symm hv)) - rw [h_u_preserved] at h_u_eq - -- Use h to show contradiction - have h_s'_neq_s : s' ≠ s := by - by_contra h_s_eq - rw [h_s_eq] at h_diff - exact h_diff rfl - have h_same_elsewhere := h v hv - -- Now we have a contradiction: s' differs from s at u but also equals s.act u there - exact h_act_diff h_u_eq - by_contra h_pmf_nonzero - have h_pos_gt_zero : (NN.State.gibbsUpdateSingleNeuron wθ s T v) s' > 0 := by - exact (PMF.apply_pos_iff (NN.State.gibbsUpdateSingleNeuron wθ s T v) s').mpr h_pmf_nonzero - have h_structure := gibbsUpdate_possible_states wθ s T v s' h_pos_gt_zero - cases h_structure with - | inl h_pos_case => - apply h_s'_diff_update 1 (mul_self_eq_mul_self_iff.mp rfl) - exact h_pos_case - | inr h_neg_case => - apply h_s'_diff_update (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) - exact h_neg_case - -/-- When calculating the transition probability sum, only the term for the - differing site contributes -/ -lemma gibbs_transition_sum_simplification (s s' : (HopfieldNetwork R U).State) - (u : U) (h : ∀ v : U, v ≠ u → s.act v = s'.act v) (h_diff : s.act u ≠ s'.act u) : - let neuron_pmf : PMF U := PMF.ofFintype - (fun _ => (1 : ENNReal) / (Fintype.card U : ENNReal)) - (NN.State.gibbsSamplingStep.extracted_1) - let update_prob (v : U) : ENNReal := (NN.State.gibbsUpdateSingleNeuron wθ s T v) s' - ∑ v ∈ Finset.univ, neuron_pmf v * update_prob v = neuron_pmf u * update_prob u := by - intro neuron_pmf update_prob - have h_zero : ∀ v ∈ Finset.univ, v ≠ u → update_prob v = 0 := by - intro v _ hv 
- exact gibbs_update_zero_other_sites wθ T s s' u v h h_diff hv - apply Finset.sum_eq_single u - · intro v hv hvu - rw [h_zero v hv hvu] - simp only [mul_zero] - · intro hu - exfalso - apply hu - simp only [mem_univ] - -@[simp] -lemma gibbs_update_preserves_other_sites (v u : U) (hvu : v ≠ u) : - ∀ s_next, (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next > 0 → s_next.act u = s.act u := by - intro s_next h_pos - have h_supp : s_next ∈ PMF.support (NN.State.gibbsUpdateSingleNeuron wθ s T v) := by - exact (PMF.apply_pos_iff (NN.State.gibbsUpdateSingleNeuron wθ s T v) s_next).mp h_pos - have h_structure := gibbsUpdate_possible_states wθ s T v s_next h_pos - cases h_structure with - | inl h_pos => - -- Case s_next = updateNeuron s v 1 - rw [h_pos] - exact updateNeuron_preserves s v u 1 (mul_self_eq_mul_self_iff.mp rfl) (id (Ne.symm hvu)) - | inr h_neg => - -- Case s_next = updateNeuron s v (-1) - rw [h_neg] - exact - updateNeuron_preserves s v u (-1) (AffineMap.lineMap_eq_lineMap_iff.mp rfl) (id (Ne.symm hvu)) - -@[simp] -lemma uniform_neuron_prob {U : Type} [Fintype U] [Nonempty U] (u : U) : - (1 : ENNReal) / (Fintype.card U : ENNReal) = - PMF.ofFintype (fun _ : U => (1 : ENNReal) / (Fintype.card U : ENNReal)) - (by exact NN.State.gibbsSamplingStep.extracted_1 - ) u := by - simp only [one_div, PMF.ofFintype_apply] diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean deleted file mode 100644 index 57f800743..000000000 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/StochasticAux.lean +++ /dev/null @@ -1,471 +0,0 @@ -import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.Core -import Mathlib.Analysis.Normed.Ring.Basic -import Mathlib.Data.Complex.Exponential - -/- Several helper lemmas to support proofs of correctness, such as: -Lemmas (`energy_decomposition`, `weight_symmetry`, `energy_sum_split`) connecting the local 
-parameters (weights, biases) to the global energy function. -/ - -open Finset Matrix NeuralNetwork State - -variable {R U : Type} [Field R] [LinearOrder R] [IsStrictOrderedRing R] - [Fintype U] [Nonempty U] - -/-- The probability of selecting a specific neuron in the uniform distribution is 1/|U| -/ -lemma uniform_neuron_selection_prob (u : U) : - let p := λ _ => (1 : ENNReal) / (Fintype.card U : ENNReal) - let neuron_pmf := PMF.ofFintype p (by - rw [Finset.sum_const, Finset.card_univ] - rw [ENNReal.div_eq_inv_mul] - simp only [mul_one] - have h_card_ne_zero : (Fintype.card U : ENNReal) ≠ 0 := by - simp only [ne_eq, Nat.cast_eq_zero] - exact Fintype.card_ne_zero - have h_card_ne_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top (Fintype.card U) - rw [← ENNReal.mul_inv_cancel h_card_ne_zero h_card_ne_top] - simp only [nsmul_eq_mul]) - neuron_pmf u = (1 : ENNReal) / (Fintype.card U : ENNReal) := by - intro p neuron_pmf - simp only [PMF.ofFintype_apply, one_div, neuron_pmf, p] - -/-- Uniform neuron selection gives a valid PMF -/ -lemma uniform_neuron_selection_prob_valid : - let p := λ (_ : U) => (1 : ENNReal) / (Fintype.card U : ENNReal) - ∑ a ∈ Finset.univ, p a = 1 := by - intro p - rw [Finset.sum_const, Finset.card_univ] - have h_card_pos : 0 < Fintype.card U := Fintype.card_pos_iff.mpr inferInstance - have h_card_ne_zero : (Fintype.card U : ENNReal) ≠ 0 := by - simp only [ne_eq, Nat.cast_eq_zero] - exact ne_of_gt h_card_pos - have h_card_top : (Fintype.card U : ENNReal) ≠ ⊤ := ENNReal.natCast_ne_top (Fintype.card U) - rw [ENNReal.div_eq_inv_mul, nsmul_eq_mul] - simp only [mul_one] - rw [ENNReal.mul_inv_cancel h_card_ne_zero h_card_top] - -variable [DecidableEq U] (wθ : Params (HopfieldNetwork R U)) - (s : (HopfieldNetwork R U).State) -/-- Decompose energy into weight component and bias component -/ -@[simp] -lemma energy_decomposition : - s.E wθ = s.Ew wθ + s.Eθ wθ := by - rw [NeuralNetwork.State.E] - --rw [← @add_neg_eq_iff_eq_add]; exact 
add_neg_eq_of_eq_add rfl - -/-- Weight matrix is symmetric in a Hopfield network -/ -lemma weight_symmetry (v1 v2 : U) : - wθ.w v1 v2 = wθ.w v2 v1 := (congrFun (congrFun (id (wθ.hw').symm) v1) v2) - -/-- Energy sum can be split into terms with u and terms without u -/ -lemma energy_sum_split (u : U): - ∑ v : U, ∑ v2 ∈ {v2 | v2 ≠ v}, wθ.w v v2 * s.act v * s.act v2 = - (∑ v2 ∈ {v2 | v2 ≠ u}, wθ.w u v2 * s.act u * s.act v2) + - (∑ v ∈ univ.erase u, ∑ v2 ∈ {v2 | v2 ≠ v}, wθ.w v v2 * s.act v * s.act v2) := by - rw [← sum_erase_add _ _ (mem_univ u)] - simp only [ne_eq, mem_univ, sum_erase_eq_sub, sub_add_cancel, add_sub_cancel] - -/-- When states differ at exactly one site, we can identify that site -/ -@[simp] -lemma single_site_difference (s s' : (HopfieldNetwork R U).State) - (h : s ≠ s' ∧ ∃ u : U, ∀ v : U, v ≠ u → s.act v = s'.act v) : - ∃! u : U, s.act u ≠ s'.act u := by - obtain ⟨s_neq, hu_all⟩ := h - obtain ⟨u, hu⟩ := hu_all - have diff_at_u : s.act u ≠ s'.act u := by { - by_contra h_eq - have all_same : ∀ v, s.act v = s'.act v := by { - intro v - by_cases hv : v = u - { rw [hv, h_eq]} - { exact hu v hv }} - have s_eq : s = s' := ext all_same - exact s_neq s_eq} - use u - constructor - { exact diff_at_u } - { intros v h_diff - by_contra h_v - have eq_v : s.act v = s'.act v := by { - by_cases hv : v = u - { rw [hv]; exact hu u fun a ↦ h_diff (hu v h_v) } - { exact hu v hv }} - exact h_diff eq_v } - -/-- States that are equal at all sites are equal -/ -lemma state_equality_from_sites - (s s' : (HopfieldNetwork R U).State) - (h : ∀ u : U, s.act u = s'.act u) : s = s' := by - apply NeuralNetwork.ext - exact h - -/-- Function to set a specific neuron state -/ -def NN.State.updateNeuron (u : U) (val : R) - (hval : (HopfieldNetwork R U).pact val) : (HopfieldNetwork R U).State := -{ act := fun u' => if u' = u then val else s.act u', - hp := by - intro u' - by_cases h : u' = u - · simp only [h, ↓reduceIte] - exact hval - · simp only [h, ↓reduceIte] - exact s.hp u' } - -/-- 
The updateNeuron function only changes the specified neuron, leaving others unchanged -/ -@[simp] -lemma updateNeuron_preserves - (s : (HopfieldNetwork R U).State) (v w : U) (val : R) (hval : (HopfieldNetwork R U).pact val) : - w ≠ v → (NN.State.updateNeuron s v val hval).act w = s.act w := by - intro h_neq - unfold NN.State.updateNeuron - exact if_neg h_neq - -/-- For states differing at only one site, that site must be u -/ -@[simp] -lemma single_site_difference_unique - (s s' : (HopfieldNetwork R U).State) - (u : U) (h : ∀ v : U, v ≠ u → s.act v = s'.act v) (h_diff : s ≠ s') : - ∃! v : U, s.act v ≠ s'.act v := by - use u - constructor - · by_contra h_eq - have all_equal : ∀ v, s.act v = s'.act v := by - intro v - by_cases hv : v = u - · rw [hv] - exact h_eq - · exact h v hv - exact h_diff (ext all_equal) - · intro v hv - by_contra h_neq - have v_diff_u : v ≠ u := by - by_contra h_eq - rw [h_eq] at hv - contradiction - exact hv (h v v_diff_u) - -/-- Given a single-site difference, the destination state is - an update of the source state -/ -lemma single_site_is_update - (s s' : (HopfieldNetwork R U).State) (u : U) - (h : ∀ v : U, v ≠ u → s.act v = s'.act v) : - s' = NN.State.updateNeuron s u (s'.act u) (s'.hp u) := by - apply NeuralNetwork.ext - intro v - by_cases hv : v = u - · rw [hv] - exact Eq.symm (if_pos rfl) - · rw [← h v hv] - exact Eq.symm (if_neg hv) - -/-- When updating a neuron with a value that equals one of the - standard values (1 or -1), the result equals the standard update -/ -@[simp] -lemma update_neuron_equiv - (s : (HopfieldNetwork R U).State) (u : U) (val : R) - (hval : (HopfieldNetwork R U).pact val) : - val = 1 → NN.State.updateNeuron s u val hval = - NN.State.updateNeuron s u 1 (Or.inl rfl) := by - intro h_val - apply NeuralNetwork.ext - intro v - unfold NN.State.updateNeuron - by_cases h_v : v = u - · exact ite_congr rfl (fun a ↦ h_val) (congrFun rfl) - · exact ite_congr rfl (fun a ↦ h_val) (congrFun rfl) - -/-- Updates with different 
activation values produce different states -/ -@[simp] -lemma different_activation_different_state - (s : (HopfieldNetwork R U).State) (u : U) : - NN.State.updateNeuron s u 1 (Or.inl rfl) ≠ - NN.State.updateNeuron s u (-1) (Or.inr rfl) := by - intro h_contra - have h_values : - (NN.State.updateNeuron s u 1 (Or.inl rfl)).act u = - (NN.State.updateNeuron s u (-1) (Or.inr rfl)).act u := by - congr - unfold NN.State.updateNeuron at h_values - simp at h_values - have : (1 : R) ≠ -1 := by - simp only [ne_eq] - norm_num - exact this h_values - -/-- Two neuron updates at the same site are equal if and only if - their new values are equal -/ -lemma update_neuron_eq_iff - (s : (HopfieldNetwork R U).State) (u : U) (val₁ val₂ : R) - (hval₁ : (HopfieldNetwork R U).pact val₁) (hval₂ : (HopfieldNetwork R U).pact val₂) : - NN.State.updateNeuron s u val₁ hval₁ = NN.State.updateNeuron s u val₂ hval₂ ↔ val₁ = val₂ := by - constructor - · intro h - have h_act : (NN.State.updateNeuron s u val₁ hval₁).act u = (NN.State.updateNeuron s u val₂ hval₂).act u := by - rw [h] - unfold NN.State.updateNeuron at h_act - simp at h_act - exact h_act - · intro h_val - subst h_val - apply NeuralNetwork.ext - intro v - by_cases hv : v = u - · subst hv; unfold NN.State.updateNeuron; exact rfl - · unfold NN.State.updateNeuron; exact rfl - -/-- Determines when a boolean-indexed update equals a specific update -/ -@[simp] -lemma bool_update_eq_iff - (s : (HopfieldNetwork R U).State) (u : U) (b : Bool) (val : R) - (hval : (HopfieldNetwork R U).pact val) : - (if b then NN.State.updateNeuron s u 1 (Or.inl rfl) - else NN.State.updateNeuron s u (-1) (Or.inr rfl)) = - NN.State.updateNeuron s u val hval ↔ - (b = true ∧ val = 1) ∨ (b = false ∧ val = -1) := by - cases b - · simp only [Bool.false_eq_true, ↓reduceIte, update_neuron_eq_iff, - false_and, true_and, false_or] - constructor - · intro h - exact id (Eq.symm h) - · intro h_cases - cases h_cases - trivial - · simp only [↓reduceIte, update_neuron_eq_iff, 
true_and, Bool.true_eq_false, - false_and, or_false] - constructor - · intro h - exact id (Eq.symm h) - · intro h_cases - cases h_cases - ·exact rfl - -/-- When filtering a PMF with binary support to states matching a given state's update, - the result reduces to a singleton if the update site matches -/ -lemma pmf_filter_update_neuron - (s : (HopfieldNetwork R U).State) (u : U) (val : R) - (hval : (HopfieldNetwork R U).pact val) : - let f : Bool → (HopfieldNetwork R U).State := λ b => - if b then NN.State.updateNeuron s u 1 (Or.inl rfl) - else NN.State.updateNeuron s u (-1) (Or.inr rfl) - filter (fun b => f b = NN.State.updateNeuron s u val hval) univ = - if val = 1 then {true} else - if val = -1 then {false} else ∅ := by - intro f - by_cases h1 : val = 1 - · simp only [h1] - ext b - simp only [mem_filter, mem_univ, true_and, mem_singleton] - rw [@bool_update_eq_iff] - simp only [and_true, ↓reduceIte, mem_singleton, or_iff_left_iff_imp, and_imp] - cases b - · simp only [Bool.false_eq_true, imp_false, forall_const] - norm_num - · simp only [Bool.true_eq_false, implies_true] - · by_cases h2 : val = -1 - · simp only [h1, h2] - ext b - simp only [mem_filter, mem_univ, true_and, mem_singleton] - rw [@bool_update_eq_iff] - simp only [and_true, ↓reduceIte] - cases b - · simp only [Bool.false_eq_true, false_and, or_true, true_iff] - norm_num - · simp only [true_and, Bool.true_eq_false, or_false] - norm_num - · simp only [h1, h2] - ext b - simp only [mem_filter, mem_univ, true_and] - rw [@bool_update_eq_iff] - simp only [h1, and_false, h2, or_self, ↓reduceIte, Finset.notMem_empty] - -/-- For a PMF over binary values mapped to states, the probability of a specific state - equals the probability of its corresponding binary value -/ -lemma pmf_map_binary_state - (s : (HopfieldNetwork R U).State) (u : U) (b : Bool) (p : Bool → ENNReal) (h_sum : ∑ b, p b = 1) : - let f : Bool → (HopfieldNetwork R U).State := λ b => - if b then NN.State.updateNeuron s u 1 (Or.inl rfl) - else 
NN.State.updateNeuron s u (-1) (Or.inr rfl) - PMF.map f (PMF.ofFintype p h_sum) (f b) = p b := by - intro f - simp only [PMF.map_apply] - have h_inj : ∀ b₁ b₂ : Bool, b₁ ≠ b₂ → f b₁ ≠ f b₂ := by - intro b₁ b₂ hneq - unfold f - cases b₁ <;> cases b₂ - · contradiction - · simp only [Bool.false_eq_true, ↓reduceIte, ne_eq] - apply Ne.symm - exact different_activation_different_state s u - · dsimp only [↓dreduceIte, Bool.false_eq_true, ne_eq] - have h_values_diff : (1 : R) ≠ (-1 : R) := by - simp only [ne_eq] - norm_num - exact (update_neuron_eq_iff s u 1 (-1) - (Or.inl rfl) (Or.inr rfl)).not.mpr h_values_diff - · contradiction - have h_unique : ∀ b' : Bool, f b' = f b ↔ b' = b := by - intro b' - by_cases h : b' = b - · constructor - · intro _ - exact h - · intro _ - rw [h] - · have : f b' ≠ f b := h_inj b' b h - constructor - · intro h_eq - contradiction - · intro h_eq - contradiction - have h_filter : (∑' (b' : Bool), if f b = f b' then (PMF.ofFintype p h_sum) b' else 0) = - (PMF.ofFintype p h_sum) b := by - rw [tsum_fintype] - have h_iff : ∀ b' : Bool, f b = f b' ↔ b = b' := by - intro b' - constructor - · intro h_eq - by_contra h_neq - have h_different : f b ≠ f b' := by - apply h_inj - exact h_neq - contradiction - · intro h_eq - rw [h_eq] - have h_eq : ∑ b' : Bool, ite (f b = f b') ((PMF.ofFintype p h_sum) b') 0 = - ∑ b' : Bool, ite (b = b') ((PMF.ofFintype p h_sum) b') 0 := by - apply Finset.sum_congr rfl - intro b' _ - have hcond : (f b = f b') ↔ (b = b') := h_iff b' - simp only [hcond] - rw [h_eq] - simp [h_eq, Finset.sum_ite_eq] - rw [@tsum_bool] - simp only [PMF.ofFintype_apply] - cases b - · have h_true_neq : f false ≠ f true := by - apply h_inj - simp only [ne_eq, Bool.false_eq_true, not_false_eq_true] - simp only [h_true_neq, if_true, if_false, add_zero] - · have h_false_neq : f true ≠ f false := by - apply h_inj - simp only [ne_eq, Bool.true_eq_false, not_false_eq_true] - simp only [h_false_neq, if_true, if_false, zero_add] - -/-- A specialized version of 
the previous lemma for the case where the state - is an update with new_val = 1 -/ -lemma pmf_map_update_one (s : (HopfieldNetwork R U).State) (u : U) - (p : Bool → ENNReal) (h_sum : ∑ b, p b = 1) : - let f : Bool → (HopfieldNetwork R U).State := λ b => - if b then NN.State.updateNeuron s u 1 (Or.inl rfl) - else NN.State.updateNeuron s u (-1) (Or.inr rfl) - PMF.map f (PMF.ofFintype p h_sum) (NN.State.updateNeuron s u 1 (Or.inl rfl)) = p true := by - intro f - apply pmf_map_binary_state s u true p h_sum - -/-- A specialized version for the case where the state is an update with new_val = -1 -/ -lemma pmf_map_update_neg_one - (s : (HopfieldNetwork R U).State) (u : U) (p : Bool → ENNReal) (h_sum : ∑ b, p b = 1) : - let f : Bool → (HopfieldNetwork R U).State := λ b => - if b then NN.State.updateNeuron s u 1 (Or.inl rfl) - else NN.State.updateNeuron s u (-1) (Or.inr rfl) - PMF.map f (PMF.ofFintype p h_sum) (NN.State.updateNeuron s u (-1) (Or.inr rfl)) = p false := by - intro f - apply pmf_map_binary_state s u false p h_sum - -/-- Expresses a ratio of exponentials in terms of the sigmoid function format. 
--/ -@[simp] -lemma exp_ratio_to_sigmoid (x : ℝ) : - Real.exp x / (Real.exp x + Real.exp (-x)) = 1 / (1 + Real.exp (-2 * x)) := by - have h_denom : Real.exp x + Real.exp (-x) = Real.exp x * (1 + Real.exp (-2 * x)) := by - have rhs_expanded : Real.exp x * (1 + Real.exp (-2 * x)) = - Real.exp x + Real.exp x * Real.exp (-2 * x) := by - rw [mul_add, mul_one] - have exp_identity : Real.exp x * Real.exp (-2 * x) = Real.exp (-x) := by - rw [← Real.exp_add] - congr - ring - rw [rhs_expanded, exp_identity] - rw [h_denom, div_mul_eq_div_div] - have h_exp_ne_zero : Real.exp x ≠ 0 := ne_of_gt (Real.exp_pos x) - field_simp - -/-- Local field is the weighted sum of incoming activations -/ -lemma local_field_eq_weighted_sum - (wθ : Params (HopfieldNetwork R U)) (s : (HopfieldNetwork R U).State) (u : U) : - s.net wθ u = ∑ v ∈ univ.erase u, wθ.w u v * s.act v := by - unfold NeuralNetwork.State.net - unfold NeuralNetwork.fnet HopfieldNetwork - simp only [ne_eq] - have sum_filter_eq : ∑ v ∈ filter (fun v => v ≠ u) univ, wθ.w u v * s.act v = - ∑ v ∈ univ.erase u, wθ.w u v * s.act v := by - apply Finset.sum_congr - · ext v - simp only [mem_filter, mem_erase, mem_univ, true_and, and_true] - · intro v _ - simp_all only [mem_erase, ne_eq, mem_univ, and_true] - --rw [@OrderedCommSemiring.mul_comm] - exact sum_filter_eq - -@[simp] -lemma gibbs_bool_to_state_map_positive - (s : (HopfieldNetwork R U).State) (u : U) (val : R) (hval : (HopfieldNetwork R U).pact val) : - val = 1 → NN.State.updateNeuron s u val hval = - NN.State.updateNeuron s u 1 (Or.inl rfl) := by - intro h_val - apply NeuralNetwork.ext - intro v - unfold NN.State.updateNeuron - by_cases h_v : v = u - · rw [h_v] - exact ite_congr rfl (fun a ↦ h_val) (congrFun rfl) - · simp only [h_v, if_neg] - exact rfl - -@[simp] -lemma gibbs_bool_to_state_map_negative - (s : (HopfieldNetwork R U).State) (u : U) (val : R) (hval : (HopfieldNetwork R U).pact val) : - val = -1 → NN.State.updateNeuron s u val hval = - NN.State.updateNeuron s u 
(-1) (Or.inr rfl) := by - intro h_val - apply NeuralNetwork.ext - intro v - unfold NN.State.updateNeuron - by_cases h_v : v = u - · rw [h_v] - dsimp only; exact congrFun (congrArg (ite (u = u)) h_val) (s.act u) - · dsimp only [h_v]; exact congrFun (congrArg (ite (v = u)) h_val) (s.act v) - -/-- When states differ at exactly one site, the later state can be expressed as - an update of the first state at that site -/ -lemma single_site_transition_as_update - (s s' : (HopfieldNetwork R U).State) (u : U) - (h : ∀ v : U, v ≠ u → s.act v = s'.act v) : - s' = NN.State.updateNeuron s u (s'.act u) (s'.hp u) := by - apply NeuralNetwork.ext - intro v - by_cases hv : v = u - · rw [hv] - unfold NN.State.updateNeuron - exact Eq.symm (if_pos rfl) - · unfold NN.State.updateNeuron - rw [← h v hv] - exact Eq.symm (if_neg hv) - -/-- When states differ at exactly one site, the later state can be expressed as - an update of the first state at that site -/ -@[simp] -lemma single_site_difference_as_update (s s' : (HopfieldNetwork R U).State) (u : U) - (h_diff_at_u : s.act u ≠ s'.act u) (h_same_elsewhere : ∀ v : U, v ≠ u → s.act v = s'.act v) : - s' = NN.State.updateNeuron s u (s'.act u) (s'.hp u) := by - apply NeuralNetwork.ext - intro v - by_cases hv : v = u - · rw [hv] - unfold NN.State.updateNeuron - simp only [if_pos rfl] - have _ := h_diff_at_u - exact rfl - · unfold NN.State.updateNeuron - simp only [if_neg hv] - exact Eq.symm (h_same_elsewhere v hv) From b067452b0ce23c43f1e34f4c346d4d2add87ea2a Mon Sep 17 00:00:00 2001 From: Matteo Cipollina Date: Wed, 27 Aug 2025 02:05:15 +0200 Subject: [PATCH 15/15] refactor NNStochastic --- .../HopfieldNetwork/NNStochastic.lean | 24 ++++++++++++------- .../SpinGlasses/HopfieldNetwork/TwoState.lean | 24 ++++++------------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NNStochastic.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NNStochastic.lean index 
bcff7235c..34ff84b23 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NNStochastic.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/NNStochastic.lean @@ -1,16 +1,24 @@ import PhysLean.StatisticalMechanics.SpinGlasses.HopfieldNetwork.NeuralNetwork import Mathlib.Probability.ProbabilityMassFunction.Constructions +universe uR uU uσ + +open NeuralNetwork State + /-- Probability Mass Function over Neural Network States -/ -def NeuralNetwork.StatePMF {R U : Type} [Zero R] - (NN : NeuralNetwork R U) := PMF (NN.State) +def NeuralNetwork.StatePMF + {R : Type uR} {U : Type uU} {σ : Type uσ} [Zero R] + (NN : NeuralNetwork R U σ) : Type _ := + PMF NN.State /-- Temperature-parameterized stochastic dynamics for neural networks -/ -def NeuralNetwork.StochasticDynamics {R U : Type} [Zero R] - (NN : NeuralNetwork R U) := - ∀ (_ : ℝ), NN.State → NeuralNetwork.StatePMF NN +def NeuralNetwork.StochasticDynamics + {R : Type uR} {U : Type uU} {σ : Type uσ} [Zero R] + (NN : NeuralNetwork R U σ) := + ℝ → NN.State → NeuralNetwork.StatePMF NN /-- Metropolis acceptance decision as a probability mass function over Boolean outcomes -/ -def NN.State.metropolisDecision (p : ℝ) : PMF Bool := - PMF.bernoulli (ENNReal.ofReal (min p 1)) - (mod_cast min_le_right p 1) +def NeuralNetwork.State.metropolisDecision (p : ℝ) : PMF Bool := + PMF.bernoulli (ENNReal.ofReal (min p 1)) (by + have : min p 1 ≤ 1 := min_le_right _ _ + simp) diff --git a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean index 6659aec77..80985fe69 100644 --- a/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean +++ b/PhysLean/StatisticalMechanics/SpinGlasses/HopfieldNetwork/TwoState.lean @@ -1446,9 +1446,7 @@ lemma updPos_eq_self_of_act_pos · subst hv; simp [updPos, Function.update, h] · simp [updPos, Function.update, hv] -/-- Helper: if the current activation at `u` is already 
`σ_neg`, `updNeg` is identity. - -(Version with fully explicit implicit parameters to avoid universe inference issues.) -/ +/-- Helper: if the current activation at `u` is already `σ_neg`, `updNeg` is identity. -/ lemma updNeg_eq_self_of_act_neg {R U σ} [Field R] [LinearOrder R] [IsStrictOrderedRing R] [Fintype U] [DecidableEq U] @@ -1492,8 +1490,7 @@ lemma Up_eq_updPos_or_updNeg simpa [NeuralNetwork.State.net] using hpos simp [updPos, Function.update, this, hθle] aesop - · - have hlt : net < θ := lt_of_not_ge hθle + · have hlt : net < θ := lt_of_not_ge hθle have hneg := TwoStateNeuralNetwork.h_fact_neg (NN:=NN) v (s.act v) net (p.θ v) hlt have : NN.fact v (s.act v) @@ -1502,8 +1499,7 @@ lemma Up_eq_updPos_or_updNeg simpa [NeuralNetwork.State.net] using hneg simp [updNeg, Function.update, this, hθle] aesop - · - unfold NeuralNetwork.State.Up + · unfold NeuralNetwork.State.Up simp_rw [hv, updPos, updNeg]; simp [Function.update] aesop end TwoState @@ -1545,7 +1541,6 @@ variable {U : Type} [Fintype U] [DecidableEq U] [Nonempty U] -- Note: The following lemmas are specialized for `SymmetricBinary ℝ U` -- to simplify the proof development by avoiding universe polymorphism issues. --- The original polymorphic versions can be restored later from this template. lemma updPos_eq_self_of_act_pos_binary (s : (SymmetricBinary ℝ U).State) (u : U) @@ -1592,16 +1587,13 @@ lemma energy_order_from_flip_id_binary constructor · intro hL have hκL : 0 ≤ κ * L := mul_nonneg hκ hL - -- Convert desired inequality to `sub ≤ 0`. have hsub : spec.E p sPos - spec.E p sNeg ≤ 0 := by rw [hdiff]; aesop exact sub_nonpos.mp hsub · intro hL have hκL : κ * L ≤ 0 := mul_nonpos_of_nonneg_of_nonpos hκ hL - -- Derive the reversed difference. 
have hrev : spec.E p sNeg - spec.E p sPos = κ * L := by have := congrArg Neg.neg hdiff - -- -(E sPos - E sNeg) = κ * L simpa [neg_sub, neg_mul, neg_neg] using this have hsub : spec.E p sNeg - spec.E p sPos ≤ 0 := by rw [hrev]; exact hκL @@ -1613,7 +1605,6 @@ lemma energy_is_lyapunov_at_site_binary (p : Params (SymmetricBinary ℝ U)) (s : (SymmetricBinary ℝ U).State) (u : U) (hcur : s.act u = 1 ∨ s.act u = -1) : spec.E p (s.Up p u) ≤ spec.E p s := by - -- Shorthand states and values set sPos := updPos (NN:=SymmetricBinary ℝ U) s u set sNeg := updNeg (NN:=SymmetricBinary ℝ U) s u set net := s.net p u @@ -1640,11 +1631,10 @@ lemma energy_is_lyapunov_at_site_binary simp [net, θ, hθle, sPos, sNeg] at hUp_cases_eval rw [hUp_cases_eval] cases hcur with - | inl h_is_pos => -- s is already sPos + | inl h_is_pos => rw [updPos_eq_self_of_act_pos_binary s u h_is_pos] - | inr h_is_neg => -- s = sNeg, need E(sPos) ≤ E(sNeg) + | inr h_is_neg => have hL_nonneg : 0 ≤ L := by simpa [L] using sub_nonneg.mpr hθle - -- Rewrite only the right-hand occurrence of s using hs have hs : s = sNeg := (updNeg_eq_self_of_act_neg_binary s u h_is_neg).symm have hLE : spec.E p sPos ≤ spec.E p sNeg := hOrder.left hL_nonneg simp_rw [hs] @@ -1655,14 +1645,14 @@ lemma energy_is_lyapunov_at_site_binary simp [net, θ, hθle, sPos, sNeg] at hUp_cases_eval rw [hUp_cases_eval] cases hcur with - | inl h_is_pos => -- s = sPos, need E(sNeg) ≤ E(sPos) + | inl h_is_pos => have hL_nonpos : L ≤ 0 := by simpa [L] using (lt_of_not_ge hθle).le have hs : s = sPos := (updPos_eq_self_of_act_pos_binary s u h_is_pos).symm have hLE : spec.E p sNeg ≤ spec.E p sPos := hOrder.2 hL_nonpos simp_rw [hs] exact le_of_eq_of_le (congrArg (spec.E p) (congrFun (congrArg updNeg (id (Eq.symm hs))) u)) hLE - | inr h_is_neg => -- s already sNeg + | inr h_is_neg => rw [updNeg_eq_self_of_act_neg_binary s u h_is_neg] end ConcreteLyapunov