Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions Lean4Lean/TypeChecker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace Lean
abbrev InferCache := ExprMap Expr

structure TypeChecker.State where
ngen : NameGenerator := { namePrefix := `_kernel_fresh, idx := 0 }
nid : Nat := 0
fvarTypeToReusedNamePrefix : Std.HashMap Expr Name := {}
inferTypeI : InferCache := {}
inferTypeC : InferCache := {}
whnfCoreCache : ExprMap Expr := {}
Expand Down Expand Up @@ -40,13 +41,25 @@ instance : MonadEnv M where
instance : MonadLCtx M where
getLCtx := return (← read).lctx

instance [Monad m] : MonadNameGenerator (StateT State m) where
getNGen := return (← get).ngen
setNGen ngen := modify fun s => { s with ngen }

instance (priority := low) : MonadWithReaderOf LocalContext M where
withReader f := withReader fun s => { s with lctx := f s.lctx }

def mkNewId : M Name := do
let nid := (← get).nid
modify fun st => { st with nid := st.nid + 1 }
pure $ .mkNum `_kernel_fresh nid

def mkId (dom : Expr) : M Name := do
if let some np := (← get).fvarTypeToReusedNamePrefix[dom]? then
let mut count := 0
while (← getLCtx).findFVar? (Expr.fvar $ .mk (Name.mkNum np count)) |>.isSome do
count := count + 1
pure $ Name.mkNum np count
else
let np ← mkNewId
modify fun st => { st with fvarTypeToReusedNamePrefix := st.fvarTypeToReusedNamePrefix.insert dom np }
pure $ Name.mkNum np 0

structure Methods where
isDefEqCore : Expr → Expr → M Bool
whnfCore (e : Expr) (cheapRec := false) (cheapProj := false) : M Expr
Expand Down Expand Up @@ -113,7 +126,7 @@ def inferLambda (e : Expr) (inferOnly : Bool) : RecM Expr := loop #[] e where
loop fvars : Expr → RecM Expr
| .lam name dom body bi => do
let d := dom.instantiateRev fvars
let id := ⟨← mkFreshId
let id := ⟨← mkId d
withLCtx ((← getLCtx).mkLocalDecl id name d bi) do
let fvars := fvars.push (.fvar id)
if !inferOnly then
Expand All @@ -130,7 +143,7 @@ def inferForall (e : Expr) (inferOnly : Bool) : RecM Expr := loop #[] #[] e wher
let d := dom.instantiateRev fvars
let t1 ← ensureSortCore (← inferType d inferOnly) d
let us := us.push t1.sortLevel!
let id := ⟨← mkFreshId
let id := ⟨← mkId d
withLCtx ((← getLCtx).mkLocalDecl id name d bi) do
let fvars := fvars.push (.fvar id)
loop fvars us body
Expand Down Expand Up @@ -178,7 +191,7 @@ def inferLet (e : Expr) (inferOnly : Bool) : RecM Expr := loop #[] #[] e where
| .letE name type val body _ => do
let type := type.instantiateRev fvars
let val := val.instantiateRev fvars
let id := ⟨← mkFreshId
let id := ⟨← mkNewId
withLCtx ((← getLCtx).mkLetDecl id name type val) do
let fvars := fvars.push (.fvar id)
let vals := vals.push val
Expand Down Expand Up @@ -455,7 +468,7 @@ def isDefEqLambda (t s : Expr) (subst : Array Expr := #[]) : RecM Bool :=
else pure none
if tBody.hasLooseBVars || sBody.hasLooseBVars then
let sType := sType.getD (sDom.instantiateRev subst)
let id := ⟨← mkFreshId
let id := ⟨← mkId sType
withLCtx ((← getLCtx).mkLocalDecl id name sType bi) do
isDefEqLambda tBody sBody (subst.push (.fvar id))
else
Expand All @@ -473,17 +486,17 @@ def isDefEqForall (t s : Expr) (subst : Array Expr := #[]) : RecM Bool :=
else pure none
if tBody.hasLooseBVars || sBody.hasLooseBVars then
let sType := sType.getD (sDom.instantiateRev subst)
let id := ⟨← mkFreshId
let id := ⟨← mkId sType
withLCtx ((← getLCtx).mkLocalDecl id name sType bi) do
isDefEqForall tBody sBody (subst.push (.fvar id))
else
isDefEqForall tBody sBody (subst.push default)
| t, s => isDefEq (t.instantiateRev subst) (s.instantiateRev subst)

def quickIsDefEq (t s : Expr) (useHash := false) : RecM LBool := do
if ← modifyGet fun (.mk a1 a2 a3 a4 a5 a6 (eqvManager := m)) =>
if ← modifyGet fun (.mk a1 a2 a3 a4 a5 a6 a7 (eqvManager := m)) =>
let (b, m) := m.isEquiv useHash t s
(b, .mk a1 a2 a3 a4 a5 a6 (eqvManager := m))
(b, .mk a1 a2 a3 a4 a5 a6 a7 (eqvManager := m))
then return .true
match t, s with
| .lam .., .lam .. => toLBoolM <| isDefEqLambda t s
Expand Down Expand Up @@ -733,7 +746,7 @@ def etaExpand (e : Expr) : M Expr :=
let rec loop fvars
| .lam name dom body bi => do
let d := dom.instantiateRev fvars
let id := ⟨← mkFreshId
let id := ⟨← mkId d
withLCtx ((← getLCtx).mkLocalDecl id name d bi) do
let fvars := fvars.push (.fvar id)
loop fvars body
Expand All @@ -744,7 +757,7 @@ def etaExpand (e : Expr) : M Expr :=
| 0, _ => throw .deepRecursion
| fuel + 1, .forallE name dom body bi => do
let d := dom.instantiateRev fvars
let id := ⟨← mkFreshId
let id := ⟨← mkId d
withLCtx ((← getLCtx).mkLocalDecl id name d bi) do
let arg := .fvar id
let fvars := fvars.push arg
Expand Down