diff --git a/Lean4Lean/TypeChecker.lean b/Lean4Lean/TypeChecker.lean index 5ac0f38..d0d50cc 100644 --- a/Lean4Lean/TypeChecker.lean +++ b/Lean4Lean/TypeChecker.lean @@ -11,7 +11,8 @@ namespace Lean abbrev InferCache := ExprMap Expr structure TypeChecker.State where - ngen : NameGenerator := { namePrefix := `_kernel_fresh, idx := 0 } + nid : Nat := 0 + fvarTypeToReusedNamePrefix : Std.HashMap Expr Name := {} inferTypeI : InferCache := {} inferTypeC : InferCache := {} whnfCoreCache : ExprMap Expr := {} @@ -40,13 +41,25 @@ instance : MonadEnv M where instance : MonadLCtx M where getLCtx := return (← read).lctx -instance [Monad m] : MonadNameGenerator (StateT State m) where - getNGen := return (← get).ngen - setNGen ngen := modify fun s => { s with ngen } - instance (priority := low) : MonadWithReaderOf LocalContext M where withReader f := withReader fun s => { s with lctx := f s.lctx } +def mkNewId : M Name := do + let nid := (← get).nid + modify fun st => { st with nid := st.nid + 1 } + pure $ .mkNum `_kernel_fresh nid + +def mkId (dom : Expr) : M Name := do + if let some np := (← get).fvarTypeToReusedNamePrefix[dom]? then + let mut count := 0 + while (← getLCtx).findFVar? (Expr.fvar $ .mk (Name.mkNum np count)) |>.isSome do + count := count + 1 + pure $ Name.mkNum np count + else + let np ← mkNewId + modify fun st => { st with fvarTypeToReusedNamePrefix := st.fvarTypeToReusedNamePrefix.insert dom np } + pure $ Name.mkNum np 0 + structure Methods where isDefEqCore : Expr → Expr → M Bool whnfCore (e : Expr) (cheapRec := false) (cheapProj := false) : M Expr @@ -113,7 +126,7 @@ def inferLambda (e : Expr) (inferOnly : Bool) : RecM Expr := loop #[] e where loop fvars : Expr → RecM Expr | .lam name dom body bi => do let d := dom.instantiateRev fvars - let id := ⟨← mkFreshId⟩ + let id := ⟨← mkId d⟩ withLCtx ((← getLCtx).mkLocalDecl id name d bi) do let fvars := fvars.push (.fvar id) if !inferOnly then @@ -130,7 +143,7 @@ def inferForall (e : Expr) (inferOnly : Bool) : RecM Expr := loop #[] #[] e wher let d := dom.instantiateRev fvars let t1 ← ensureSortCore (← inferType d inferOnly) d let us := us.push t1.sortLevel! - let id := ⟨← mkFreshId⟩ + let id := ⟨← mkId d⟩ withLCtx ((← getLCtx).mkLocalDecl id name d bi) do let fvars := fvars.push (.fvar id) loop fvars us body @@ -178,7 +191,7 @@ def inferLet (e : Expr) (inferOnly : Bool) : RecM Expr := loop #[] #[] e where | .letE name type val body _ => do let type := type.instantiateRev fvars let val := val.instantiateRev fvars - let id := ⟨← mkFreshId⟩ + let id := ⟨← mkNewId⟩ withLCtx ((← getLCtx).mkLetDecl id name type val) do let fvars := fvars.push (.fvar id) let vals := vals.push val @@ -455,7 +468,7 @@ def isDefEqLambda (t s : Expr) (subst : Array Expr := #[]) : RecM Bool := else pure none if tBody.hasLooseBVars || sBody.hasLooseBVars then let sType := sType.getD (sDom.instantiateRev subst) - let id := ⟨← mkFreshId⟩ + let id := ⟨← mkId sType⟩ withLCtx ((← getLCtx).mkLocalDecl id name sType bi) do isDefEqLambda tBody sBody (subst.push (.fvar id)) else @@ -473,7 +486,7 @@ def isDefEqForall (t s : Expr) (subst : Array Expr := #[]) : RecM Bool := else pure none if tBody.hasLooseBVars || sBody.hasLooseBVars then let sType := sType.getD (sDom.instantiateRev subst) - let id := ⟨← mkFreshId⟩ + let id := ⟨← mkId sType⟩ withLCtx ((← getLCtx).mkLocalDecl id name sType bi) do isDefEqForall tBody sBody (subst.push (.fvar id)) else @@ -481,9 +494,9 @@ def isDefEqForall (t s : Expr) (subst : Array Expr := #[]) : RecM Bool := | t, s => isDefEq (t.instantiateRev subst) (s.instantiateRev subst) def quickIsDefEq (t s : Expr) (useHash := false) : RecM LBool := do - if ← modifyGet fun (.mk a1 a2 a3 a4 a5 a6 (eqvManager := m)) => + if ← modifyGet fun (.mk a1 a2 a3 a4 a5 a6 a7 (eqvManager := m)) => let (b, m) := m.isEquiv useHash t s - (b, .mk a1 a2 a3 a4 a5 a6 (eqvManager := m)) + (b, .mk a1 a2 a3 a4 a5 a6 a7 (eqvManager := m)) then return .true match t, s with | .lam .., .lam .. => toLBoolM <| isDefEqLambda t s @@ -733,7 +746,7 @@ def etaExpand (e : Expr) : M Expr := let rec loop fvars | .lam name dom body bi => do let d := dom.instantiateRev fvars - let id := ⟨← mkFreshId⟩ + let id := ⟨← mkId d⟩ withLCtx ((← getLCtx).mkLocalDecl id name d bi) do let fvars := fvars.push (.fvar id) loop fvars body @@ -744,7 +757,7 @@ def etaExpand (e : Expr) : M Expr := | 0, _ => throw .deepRecursion | fuel + 1, .forallE name dom body bi => do let d := dom.instantiateRev fvars - let id := ⟨← mkFreshId⟩ + let id := ⟨← mkId d⟩ withLCtx ((← getLCtx).mkLocalDecl id name d bi) do let arg := .fvar id let fvars := fvars.push arg