diff --git a/Ix/Aiur/Bytecode.lean b/Ix/Aiur/Bytecode.lean index 8051a9e7..69abff49 100644 --- a/Ix/Aiur/Bytecode.lean +++ b/Ix/Aiur/Bytecode.lean @@ -27,6 +27,8 @@ inductive Op | u8ShiftRight : ValIdx → Op | u8Xor : ValIdx → ValIdx → Op | u8Add : ValIdx → ValIdx → Op + | u8And : ValIdx → ValIdx → Op + | u8Or : ValIdx → ValIdx → Op | debug : String → Option (Array ValIdx) → Op deriving Repr diff --git a/Ix/Aiur/Check.lean b/Ix/Aiur/Check.lean index 53b59321..83e33a51 100644 --- a/Ix/Aiur/Check.lean +++ b/Ix/Aiur/Check.lean @@ -78,282 +78,297 @@ def refLookup (global : Global) : CheckM Typ := do def bindIdents (bindings : List (Local × Typ)) (ctx : CheckContext) : CheckContext := { ctx with varTypes := ctx.varTypes.insertMany bindings } +def fieldTerm (t : TypedTermInner) : TypedTerm := ⟨.field, t, false⟩ + mutual -partial def inferTerm : Term → CheckM TypedTerm - | .unit => pure $ .mk (.evaluates .unit) .unit - | .var x => do - -- Retrieves and returns the variable type from the context. - let ctx ← read - match ctx.varTypes[x]? with - | some t => pure $ .mk (.evaluates t) (.var x) - | none => - let Local.str localName := x | unreachable! - let typ := .evaluates (← refLookup (Global.init localName)) - pure $ .mk typ (.var x) +partial def inferTerm (t : Term) : CheckM TypedTerm := match t with + | .unit => pure ⟨.unit, .unit, false⟩ + | .var x => inferVariable x | .ref x => do - let typ := .evaluates (← refLookup x) - pure $ .mk typ (.ref x) - | .ret term => do - -- Ensures that the type of the returned term matches the expected return type. - -- The term is not allowed to have a (nested) return. - -- Returning the type of the term is not necessary because it's already in the context. - let ctx ← read - let inner ← checkNoEscape term ctx.returnType - pure $ .mk .escapes inner - | .data data => do - let (typ, inner) ← inferData data - pure $ .mk (.evaluates typ) inner - | .let pat expr body => do - -- Returns the type of the body, inferred in the context extended with the bound variable type. - -- The bound variable is ensured not to escape. - let (exprTyp, exprInner) ← inferNoEscape expr - let expr' := .mk (.evaluates exprTyp) exprInner - let bindings ← checkPattern pat exprTyp - let body' ← withReader (bindIdents bindings) (inferTerm body) - pure $ .mk body'.typ (.let pat expr' body') + pure ⟨← refLookup x, .ref x, false⟩ + | .ret term => inferReturn term + | .data data => inferData data + | .let pat expr body => inferLet pat expr body | .match term branches => inferMatch term branches - | .app func@(⟨.str .anonymous unqualifiedFunc⟩) args => do - -- Ensures the function exists in the context and that the arguments, which aren't allowed to - -- escape, match the function's input types. Returns the function's output type. - let ctx ← read - match ctx.varTypes[Local.str unqualifiedFunc]? with - | some (.function inputs output) => do - let args ← checkArgsAndInputs func args inputs - pure $ .mk (.evaluates output) (.app func args) - | some _ => throw $ .notAFunction func - | none => match ctx.decls.getByKey func with - | some (.function function) => do - let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) - pure $ .mk (.evaluates function.output) (.app func args) - | some (.constructor dataType constr) => do - let args ← checkArgsAndInputs func args constr.argTypes - pure $ .mk (.evaluates (.dataType dataType.name)) (.app func args) - | _ => throw $ .cannotApply func - | .app func args => do - -- Only checks global map if it is not unqualified - let ctx ← read - match ctx.decls.getByKey func with - | some (.function function) => - let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) - pure $ .mk (.evaluates function.output) (.app func args) - | some (.constructor dataType constr) => - let args ← checkArgsAndInputs func args constr.argTypes - pure $ .mk (.evaluates (.dataType dataType.name)) (.app func args) - | _ => throw $ .cannotApply func - | .add a b => do - let (ctxTyp, a, b) ← checkArith a b - pure $ .mk ctxTyp (.add a b) - | .sub a b => do - let (ctxTyp, a, b) ← checkArith a b - pure $ .mk ctxTyp (.sub a b) - | .mul a b => do - let (ctxTyp, a, b) ← checkArith a b - pure $ .mk ctxTyp (.mul a b) - | .eqZero a => do - let a ← fieldTerm <$> checkNoEscape a .field - pure $ .mk (.evaluates .field) (.eqZero a) - | .proj tup i => do - let (typs, tupInner) ← inferTuple tup - if h : i < typs.size then - let typ := typs[i] - let tup := .mk (.evaluates (.tuple typs)) tupInner - pure $ .mk (.evaluates typ) (.proj tup i) - else - throw $ .indexOoB i - | .get arr i => do - let (typ, n, inner) ← inferArray arr - if i ≥ n then - throw $ .indexOoB i - else - let arr := .mk (.evaluates (.array typ n)) inner - pure $ .mk (.evaluates typ) (.get arr i) - | .slice arr i j => if j < i then throw $ .negativeRange i j else do + | .app func args => inferApplication func args + | .ann typ term => do + pure ⟨typ, ← checkNoEscape term typ, false⟩ + | .proj tup i => inferProj tup i + | .get arr i => inferGet arr i + | .slice arr i j => inferSlice arr i j + | .set arr i val => inferSet arr i val + | .store term => inferStore term + | .load term => inferLoad term + | .ptrVal term => inferPtrVal term + | .eqZero a => inferUnop a .eqZero .field + | .add a b => inferBinop a b .add .field + | .sub a b => inferBinop a b .sub .field + | .mul a b => inferBinop a b .mul .field + | .u8ShiftLeft a => inferUnop a .u8ShiftLeft .field + | .u8ShiftRight a => inferUnop a .u8ShiftRight .field + | .u8BitDecomposition a => inferUnop a .u8BitDecomposition (.array .field 8) + | .u8Xor a b => inferBinop a b .u8Xor .field + | .u8And a b => inferBinop a b .u8And .field + | .u8Or a b => inferBinop a b .u8Or .field + | .u8Add a b => inferBinop a b .u8Add (.tuple #[.field, .field]) + | .ioGetInfo key => inferIoGetInfo key + | .ioSetInfo key idx len ret => inferIoSetInfo key idx len ret + | .ioRead idx len => inferIoRead idx len + | .ioWrite data ret => inferIoWrite data ret + | .assertEq a b ret => inferAssertEq a b ret + | .debug label term ret => inferDebug label term ret + +partial def checkNoEscape (term : Term) (typ : Typ) : CheckM TypedTermInner := do + let (typ', inner) ← inferNoEscape term + unless typ == typ' do throw $ .typeMismatch typ typ' + pure inner + +partial def inferNoEscape (term : Term) : CheckM (Typ × TypedTermInner) := do + let typedTerm ← inferTerm term + if typedTerm.escapes then throw .illegalReturn + pure (typedTerm.typ, typedTerm.inner) + +partial def inferUnop + (a : Term) + (op : TypedTerm → TypedTermInner) + (typ : Typ) : + CheckM TypedTerm := do + let a ← fieldTerm <$> checkNoEscape a .field + pure ⟨typ, op a, false⟩ + +partial def inferBinop + (a : Term) + (b : Term) + (op : TypedTerm → TypedTerm → TypedTermInner) + (typ : Typ) : + CheckM TypedTerm := do + let a ← fieldTerm <$> checkNoEscape a .field + let b ← fieldTerm <$> checkNoEscape b .field + pure ⟨typ, op a b, false⟩ + +partial def inferProj (tup : Term) (i : Nat) : CheckM TypedTerm := do + let (typs, tupInner) ← inferTuple tup + if h : i < typs.size then + let typ := typs[i] + let tup := ⟨.tuple typs, tupInner, false⟩ + pure ⟨typ, .proj tup i, false⟩ + else + throw $ .indexOoB i + +partial def inferGet (arr : Term) (i : Nat) : CheckM TypedTerm := do + let (typ, n, inner) ← inferArray arr + if i ≥ n then + throw $ .indexOoB i + else + let arr := ⟨.array typ n, inner, false⟩ + pure ⟨typ, .get arr i, false⟩ + +partial def inferSlice (arr : Term) (i j : Nat) : CheckM TypedTerm := + if j < i then throw $ .negativeRange i j else do let (typ, n, inner) ← inferArray arr if j ≤ n then - let arr := .mk (.evaluates (.array typ n)) inner - pure $ .mk (.evaluates (.array typ (j - i))) (.slice arr i j) + let arr := ⟨.array typ n, inner, false⟩ + pure ⟨.array typ (j - i), .slice arr i j, false⟩ else throw $ .rangeOoB i j - | .set arr i val => do - let (typ, n, inner) ← inferArray arr - if i ≥ n then - throw $ .indexOoB i - else - let val ← checkNoEscape val typ - let arrTyp := .array typ n - let arr := .mk (.evaluates arrTyp) inner - let val := .mk (.evaluates typ) val - pure $ .mk (.evaluates arrTyp) (.set arr i val) - | .store term => do - -- Infers the type of the term and returns it, wrapped by a pointer type. - -- The term is not allowed to early return. - let (typ, inner) ← inferNoEscape term - let store := .store (.mk (.evaluates typ) inner) - pure $ .mk (.evaluates (.pointer typ)) store - | .load term => do - -- Ensures that the the type of the term is a pointer type and returns the unwrapped type. - -- The term is not allowed to early return. - let (typ, inner) ← inferNoEscape term - match typ with - | .pointer innerTyp => - let load := .load (.mk (.evaluates typ) inner) - pure $ .mk (.evaluates innerTyp) load - | _ => throw $ .notAPointer typ - | .ptrVal term => do - -- Infers the type of the term, which must be a pointer, but returns `.u64`, as in a cast. - -- The term is not allowed to early return. - let (typ, inner) ← inferNoEscape term - match typ with - | .pointer _ => pure $ fieldTerm (.ptrVal (.mk (.evaluates typ) inner)) - | _ => throw $ .notAPointer typ - | .ann typ term => do - let inner ← checkNoEscape term typ - pure $ .mk (.evaluates typ) inner - -- | .unsafeCast term castTyp => do - -- let (typ, inner) ← inferNoEscape term - -- pure $ .mk (.evaluates typ) (.unsafeCast inner castTyp) - | .assertEq a b ret => do - -- `a` and `b` must have the same type. - let (typ, a) ← inferNoEscape a - let b ← checkNoEscape b typ - let ret ← inferTerm ret - let assertEq := .assertEq (.mk (.evaluates typ) a) (.mk (.evaluates typ) b) ret - pure $ .mk (.evaluates ret.typ.unwrap) assertEq - | .ioGetInfo key => do - let (typ, keyInner) ← inferNoEscape key - match typ with - | .array .. => - let ioGetInfo := .ioGetInfo (.mk (.evaluates typ) keyInner) - pure $ .mk (.evaluates (.tuple #[.field, .field])) ioGetInfo - | _ => throw $ .notAnArray typ - | .ioSetInfo key idx len ret => do - let (keyTyp, keyInner) ← inferNoEscape key - match keyTyp with - | .array keyEltTyp _ => - if keyEltTyp != .field then throw $ .typeMismatch .field keyEltTyp - let idx ← fieldTerm <$> checkNoEscape idx .field - let len ← fieldTerm <$> checkNoEscape len .field - let ret ← inferTerm ret - let ioSetInfo := .ioSetInfo (.mk (.evaluates keyTyp) keyInner) idx len ret - pure $ .mk (.evaluates ret.typ.unwrap) ioSetInfo - | _ => throw $ .notAnArray keyTyp - | .ioRead idx len => do - if len = 0 then throw .emptyArray + +partial def inferSet (arr : Term) (i : Nat) (val : Term) : CheckM TypedTerm := do + let (typ, n, inner) ← inferArray arr + if i ≥ n then + throw $ .indexOoB i + else + let val ← checkNoEscape val typ + let arrTyp := .array typ n + let arr := ⟨arrTyp, inner, false⟩ + let val := ⟨typ, val, false⟩ + pure ⟨arrTyp, .set arr i val, false⟩ + +partial def inferStore (term : Term) : CheckM TypedTerm := do + let (typ, inner) ← inferNoEscape term + let store := .store ⟨typ, inner, false⟩ + pure ⟨.pointer typ, store, false⟩ + +partial def inferLoad (term : Term) : CheckM TypedTerm := do + let (typ, inner) ← inferNoEscape term + match typ with + | .pointer innerTyp => + let load := .load ⟨typ, inner, false⟩ + pure ⟨innerTyp, load, false⟩ + | _ => throw $ .notAPointer typ + +partial def inferPtrVal (term : Term) : CheckM TypedTerm := do + let (typ, inner) ← inferNoEscape term + match typ with + | .pointer _ => pure $ fieldTerm (.ptrVal ⟨typ, inner, false⟩) + | _ => throw $ .notAPointer typ + +partial def inferIoGetInfo (key : Term) : CheckM TypedTerm := do + let (typ, keyInner) ← inferNoEscape key + match typ with + | .array .. => + let ioGetInfo := .ioGetInfo ⟨typ, keyInner, false⟩ + pure ⟨.tuple #[.field, .field], ioGetInfo, false⟩ + | _ => throw $ .notAnArray typ + +partial def inferIoSetInfo (key idx len ret : Term) : CheckM TypedTerm := do + let (keyTyp, keyInner) ← inferNoEscape key + match keyTyp with + | .array keyEltTyp _ => + if keyEltTyp != .field then throw $ .typeMismatch .field keyEltTyp let idx ← fieldTerm <$> checkNoEscape idx .field - let ioRead := .ioRead idx len - pure $ .mk (.evaluates (.array .field len)) ioRead - | .ioWrite data ret => do - let (dataTyp, dataInner) ← inferNoEscape data - match dataTyp with - | .array dataEltTyp _ => - if dataEltTyp != .field then throw $ .typeMismatch .field dataEltTyp - let ret ← inferTerm ret - let ioWrite := .ioWrite (.mk (.evaluates dataTyp) dataInner) ret - pure $ .mk (.evaluates ret.typ.unwrap) ioWrite - | _ => throw $ .notAnArray dataTyp - | .u8BitDecomposition byte => do - let byte ← fieldTerm <$> checkNoEscape byte .field - let u8BitDecomposition := .u8BitDecomposition byte - pure $ .mk (.evaluates (.array .field 8)) u8BitDecomposition - | .u8ShiftLeft byte => do - let byte ← fieldTerm <$> checkNoEscape byte .field - let u8ShiftLeft := .u8ShiftLeft byte - pure $ .mk (.evaluates .field) u8ShiftLeft - | .u8ShiftRight byte => do - let byte ← fieldTerm <$> checkNoEscape byte .field - let u8ShiftRight := .u8ShiftRight byte - pure $ .mk (.evaluates .field) u8ShiftRight - | .u8Xor i j => do - let i ← fieldTerm <$> checkNoEscape i .field - let j ← fieldTerm <$> checkNoEscape j .field - let u8Xor := .u8Xor i j - pure $ .mk (.evaluates .field) u8Xor - | .u8Add i j => do - let i ← fieldTerm <$> checkNoEscape i .field - let j ← fieldTerm <$> checkNoEscape j .field - let u8Add := .u8Add i j - pure $ .mk (.evaluates (.tuple #[.field, .field])) u8Add - | .debug label term ret => do - let term ← term.mapM inferTerm + let len ← fieldTerm <$> checkNoEscape len .field + let ret ← inferTerm ret + let ioSetInfo := .ioSetInfo ⟨keyTyp, keyInner, false⟩ idx len ret + pure ⟨ret.typ, ioSetInfo, ret.escapes⟩ + | _ => throw $ .notAnArray keyTyp + +partial def inferIoRead (idx : Term) (len : Nat) : CheckM TypedTerm := do + if len = 0 then throw .emptyArray + let idx ← fieldTerm <$> checkNoEscape idx .field + let ioRead := .ioRead idx len + pure ⟨.array .field len, ioRead, false⟩ + +partial def inferIoWrite (data ret : Term) : CheckM TypedTerm := do + let (dataTyp, dataInner) ← inferNoEscape data + match dataTyp with + | .array dataEltTyp _ => + if dataEltTyp != .field then throw $ .typeMismatch .field dataEltTyp let ret ← inferTerm ret - let debug := .debug label term ret - pure $ .mk (.evaluates ret.typ.unwrap) debug + let ioWrite := .ioWrite ⟨dataTyp, dataInner, false⟩ ret + pure ⟨ret.typ, ioWrite, ret.escapes⟩ + | _ => throw $ .notAnArray dataTyp + +partial def inferVariable (x : Local) : CheckM TypedTerm := do + let ctx ← read + match ctx.varTypes[x]? with + | some t => pure ⟨t, .var x, false⟩ + | none => + let Local.str localName := x | unreachable! + let typ ← refLookup (Global.init localName) + pure ⟨typ, .var x, false⟩ + +partial def inferReturn (term : Term) : CheckM TypedTerm := do + let ctx ← read + let inner ← checkNoEscape term ctx.returnType + pure ⟨ctx.returnType, inner, true⟩ + +partial def inferLet (pat : Pattern) (expr : Term) (body : Term) : CheckM TypedTerm := do + let (exprTyp, exprInner) ← inferNoEscape expr + let expr' := ⟨exprTyp, exprInner, false⟩ + let bindings ← checkPattern pat exprTyp + let body' ← withReader (bindIdents bindings) (inferTerm body) + pure ⟨body'.typ, .let pat expr' body', body'.escapes⟩ + +partial def inferUnqualifiedApp (func : Global) (unqualifiedFunc : String) (args : List Term) : CheckM TypedTerm := do + let ctx ← read + match ctx.varTypes[Local.str unqualifiedFunc]? with + | some (.function inputs output) => do + let args ← checkArgsAndInputs func args inputs + pure ⟨output, .app func args, false⟩ + | some _ => throw $ .notAFunction func + | none => match ctx.decls.getByKey func with + | some (.function function) => do + let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) + pure ⟨function.output, .app func args, false⟩ + | some (.constructor dataType constr) => do + let args ← checkArgsAndInputs func args constr.argTypes + pure ⟨.dataType dataType.name, .app func args, false⟩ + | _ => throw $ .cannotApply func where - /-- - Ensures that there are as many arguments and as expected types and that - the types of the arguments are precisely those expected. - -/ checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do let lenArgs := args.length let lenInputs := inputs.length unless lenArgs == lenInputs do throw $ .wrongNumArgs func lenArgs lenInputs let pass := fun (arg, input) => do let inner ← checkNoEscape arg input - pure $ .mk (.evaluates input) inner + pure ⟨input, inner, false⟩ args.zip inputs |>.mapM pass - checkArith a b := do - let aInner ← checkNoEscape a .field - let bInner ← checkNoEscape b .field - pure (.evaluates .field, fieldTerm aInner, fieldTerm bInner) - fieldTerm := (.mk (.evaluates .field) ·) -partial def checkNoEscape (term : Term) (typ : Typ) : CheckM TypedTermInner := do - let (typ', inner) ← inferNoEscape term - unless typ == typ' do throw $ .typeMismatch typ typ' - pure inner +partial def inferQualifiedApp (func : Global) (args : List Term) : CheckM TypedTerm := do + let ctx ← read + match ctx.decls.getByKey func with + | some (.function function) => + let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) + pure ⟨function.output, .app func args, false⟩ + | some (.constructor dataType constr) => + let args ← checkArgsAndInputs func args constr.argTypes + pure ⟨.dataType dataType.name, .app func args, false⟩ + | _ => throw $ .cannotApply func +where + checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do + let lenArgs := args.length + let lenInputs := inputs.length + unless lenArgs == lenInputs do throw $ .wrongNumArgs func lenArgs lenInputs + let pass := fun (arg, input) => do + let inner ← checkNoEscape arg input + pure ⟨input, inner, false⟩ + args.zip inputs |>.mapM pass -partial def inferNoEscape (term : Term) : CheckM (Typ × TypedTermInner) := do - let typedTerm ← inferTerm term - match typedTerm.typ with - | .escapes => throw .illegalReturn - | .evaluates type => pure (type, typedTerm.inner) +partial def inferApplication (func : Global) (args : List Term) : CheckM TypedTerm := + match func.toName with + | .str .anonymous unqualifiedFunc => inferUnqualifiedApp func unqualifiedFunc args + | _ => inferQualifiedApp func args + +partial def inferAssertEq (a b ret : Term) : CheckM TypedTerm := do + let (typ, a) ← inferNoEscape a + let b ← checkNoEscape b typ + let ret ← inferTerm ret + let assertEq := .assertEq ⟨typ, a, false⟩ ⟨typ, b, false⟩ ret + pure ⟨ret.typ, assertEq, ret.escapes⟩ + +partial def inferDebug (label : String) (term : Option Term) (ret : Term) : CheckM TypedTerm := do + let term ← term.mapM inferTerm + let ret ← inferTerm ret + let debug := .debug label term ret + pure ⟨ret.typ, debug, ret.escapes⟩ -partial def inferData : Data → CheckM (Typ × TypedTermInner) - | .field g => pure (.field, .data (.field g)) +partial def inferData : Data → CheckM TypedTerm + | .field g => pure ⟨.field, .data (.field g), false⟩ | .tuple terms => do let typsAndInners ← terms.mapM inferNoEscape let typs := typsAndInners.map Prod.fst - let terms := typsAndInners.map fun (typ, inner) => .mk (.evaluates typ) inner - pure (.tuple typs, .data (.tuple terms)) + let terms := typsAndInners.map fun (typ, inner) => ⟨typ, inner, false⟩ + pure ⟨.tuple typs, .data (.tuple terms), false⟩ | .array terms => do if h : terms.size > 0 then let (typ, firstInner) ← inferNoEscape terms[0] let mut typedTerms := Array.mkEmpty terms.size - |>.push (.mk (.evaluates typ) firstInner) + |>.push ⟨typ, firstInner, false⟩ for term in terms[1:] do let inner ← checkNoEscape term typ - typedTerms := typedTerms.push (.mk (.evaluates typ) inner) - pure (.array typ terms.size, .data (.array typedTerms)) + typedTerms := typedTerms.push ⟨typ, inner, false⟩ + pure ⟨.array typ terms.size, .data (.array typedTerms), false⟩ else throw .emptyArray /-- Infers the type of a 'match' expression and ensures its patterns and branches are valid. -/ partial def inferMatch (term : Term) (branches : List (Pattern × Term)) : CheckM TypedTerm := do - if branches.isEmpty then throw .emptyMatch let (termTyp, termInner) ← inferNoEscape term - let term := .mk (.evaluates termTyp) termInner - let init := ([], .escapes) - let (branches, typ) ← branches.foldrM (init := init) (checkBranch termTyp) - pure $ .mk typ (.match term branches) + let term := ⟨termTyp, termInner, false⟩ + let init := ([], none) + let (branches, typOpt) ← branches.foldrM (init := init) (checkBranch termTyp) + match typOpt with + | some (typ, escapes) => pure ⟨typ, .match term branches, escapes⟩ + | none => throw .emptyMatch where checkBranch patTyp branchData acc := do let (pat, branch) := branchData - let (typedBranches, currentTyp) := acc + let (typedBranches, currentTypOpt) := acc let bindings ← checkPattern pat patTyp - withReader (bindIdents bindings) (match currentTyp with - | .escapes => do + withReader (bindIdents bindings) (match currentTypOpt with + | none => do let typedBranch ← inferTerm branch - pure (typedBranches.cons (pat, typedBranch), typedBranch.typ) - | .evaluates matchTyp => do - -- Some branch didn't escape, so if this branch doesn't escape it must have the same type - -- as the previous non-escaping branch. + pure (typedBranches.cons (pat, typedBranch), some (typedBranch.typ, typedBranch.escapes)) + | some (matchTyp, matchEscapes) => do let typedBranch ← inferTerm branch let typedBranches := typedBranches.cons (pat, typedBranch) - match typedBranch.typ with - | .escapes => pure (typedBranches, currentTyp) - | .evaluates branchTyp => - -- This branch doesn't escape so its type must match the type of the previous non-escaping branch. - unless (matchTyp == branchTyp) do throw $ .branchMismatch matchTyp branchTyp - pure (typedBranches, currentTyp)) + if typedBranch.escapes then + pure (typedBranches, currentTypOpt) + else if matchEscapes then + pure (typedBranches, some (typedBranch.typ, false)) + else + -- Neither branch escapes so their types must match + unless (matchTyp == typedBranch.typ) do throw $ .branchMismatch matchTyp typedBranch.typ + pure (typedBranches, currentTypOpt)) /-- Checks that a pattern matches a given type and collects its bindings. -/ partial def checkPattern (pat : Pattern) (typ : Typ) : CheckM $ List (Local × Typ) := do @@ -451,8 +466,8 @@ where /-- Checks a function to ensure its body's type matches its declared output type. -/ def checkFunction (function : Function) : CheckM TypedFunction := do let body ← inferTerm function.body - if let .evaluates typ := body.typ then - unless typ == function.output do throw $ .typeMismatch typ function.output + unless body.escapes do + unless body.typ == function.output do throw $ .typeMismatch body.typ function.output pure ⟨function.name, function.inputs, function.output, body, function.unconstrained⟩ end Aiur diff --git a/Ix/Aiur/Compile.lean b/Ix/Aiur/Compile.lean index ef62d0e0..ffe6a1c4 100644 --- a/Ix/Aiur/Compile.lean +++ b/Ix/Aiur/Compile.lean @@ -120,7 +120,7 @@ def opLayout : Bytecode.Op → LayoutM Unit pushDegrees $ .replicate 8 1 bumpAuxiliaries 8 bumpLookups - | .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. => do + | .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. | .u8And .. | .u8Or .. => do pushDegree 1 bumpAuxiliaries 1 bumpLookups @@ -246,7 +246,7 @@ partial def toIndex (term : TypedTerm) : StateM CompilerState (Array Bytecode.ValIdx) := match term.inner with -- | .unsafeCast inner castTyp => - -- if typSize layoutMap castTyp != typSize layoutMap term.typ.unwrap then + -- if typSize layoutMap castTyp != typSize layoutMap term.typ then -- panic! "Impossible cast" -- else -- toIndex layoutMap bindings (.mk term.typ inner) @@ -335,8 +335,8 @@ partial def toIndex -- pushOp (Bytecode.Op.preimg layout.index out layout.inputSize) layout.inputSize -- | _ => panic! "should not happen after typechecking" | .proj arg i => do - let typs := match arg.typ with - | .evaluates (.tuple typs) => typs + let typs := match (arg.typ, arg.escapes) with + | (.tuple typs, false) => typs | _ => panic! "Should not happen after typechecking" let offset := (typs.extract 0 i).foldl (init := 0) fun acc typ => typSize layoutMap typ + acc @@ -344,23 +344,23 @@ partial def toIndex let length := typSize layoutMap typs[i]! pure $ arg.extract offset (offset + length) | .get arr i => do - let eltTyp := match arr.typ with - | .evaluates (.array typ _) => typ + let eltTyp := match (arr.typ, arr.escapes) with + | (.array typ _, false) => typ | _ => panic! "Should not happen after typechecking" let eltSize := typSize layoutMap eltTyp let offset := i * eltSize let arr ← toIndex layoutMap bindings arr pure $ arr.extract offset (offset + eltSize) | .slice arr i j => do - let eltTyp := match arr.typ with - | .evaluates (.array typ _) => typ + let eltTyp := match (arr.typ, arr.escapes) with + | (.array typ _, false) => typ | _ => panic! "Should not happen after typechecking" let eltSize := typSize layoutMap eltTyp let arr ← toIndex layoutMap bindings arr pure $ arr.extract (i * eltSize) (j * eltSize) | .set arr i val => do - let eltTyp := match arr.typ with - | .evaluates (.array typ _) => typ + let eltTyp := match (arr.typ, arr.escapes) with + | (.array typ _, false) => typ | _ => panic! "Should not happen after typechecking" let eltSize := typSize layoutMap eltTyp let arr ← toIndex layoutMap bindings arr @@ -372,8 +372,8 @@ partial def toIndex let args ← toIndex layoutMap bindings arg pushOp (.store args) | .load ptr => do - let size := match ptr.typ.unwrap with - | .pointer typ => typSize layoutMap typ + let size := match (ptr.typ, ptr.escapes) with + | (.pointer typ, false) => typSize layoutMap typ | _ => unreachable! let ptr ← expectIdx ptr pushOp (.load size ptr) size @@ -421,6 +421,14 @@ partial def toIndex let i ← expectIdx i let j ← expectIdx j pushOp (.u8Add i j) 2 + | .u8And i j => do + let i ← expectIdx i + let j ← expectIdx j + pushOp (.u8And i j) + | .u8Or i j => do + let i ← expectIdx i + let j ← expectIdx j + pushOp (.u8Or i j) | .debug label term ret => do let term ← term.mapM (toIndex layoutMap bindings) modify fun stt => { stt with ops := stt.ops.push (.debug label term) } @@ -454,7 +462,7 @@ partial def TypedTerm.compile modify fun stt => { stt with ops := stt.ops.push (.debug label term) } ret.compile returnTyp layoutMap bindings | .match term cases => - match term.typ.unwrapOr returnTyp with + match term.typ with -- Also do this for tuple-like and array-like (one constructor only) datatypes | .tuple typs => match cases with | [(.tuple vars, branch)] => do diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index 84c77a1c..48ad07c5 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -123,6 +123,8 @@ syntax "u8_shift_left" "(" trm ")" : trm syntax "u8_shift_right" "(" trm ")" : trm syntax "u8_xor" "(" trm ", " trm ")" : trm syntax "u8_add" "(" trm ", " trm ")" : trm +syntax "u8_and" "(" trm ", " trm ")" : trm +syntax "u8_or" "(" trm ", " trm ")" : trm syntax "dbg!" "(" str (", " trm)? ")" ";" (trm)? : trm syntax trm "[" "@" noWs ident "]" : trm @@ -219,6 +221,10 @@ partial def elabTrm : ElabStxCat `trm mkAppM ``Term.u8Xor #[← elabTrm i, ← elabTrm j] | `(trm| u8_add($i:trm, $j:trm)) => do mkAppM ``Term.u8Add #[← elabTrm i, ← elabTrm j] + | `(trm| u8_and($i:trm, $j:trm)) => do + mkAppM ``Term.u8And #[← elabTrm i, ← elabTrm j] + | `(trm| u8_or($i:trm, $j:trm)) => do + mkAppM ``Term.u8Or #[← elabTrm i, ← elabTrm j] | `(trm| dbg!($label:str $[, $t:trm]?); $[$ret:trm]?) => do let t ← match t with | none => mkAppOptM ``Option.none #[some (mkConst ``Term)] @@ -360,6 +366,14 @@ where let i ← replaceToken old new i let j ← replaceToken old new j `(trm| u8_add($i, $j)) + | `(trm| u8_and($i:trm, $j:trm)) => do + let i ← replaceToken old new i + let j ← replaceToken old new j + `(trm| u8_and($i, $j)) + | `(trm| u8_or($i:trm, $j:trm)) => do + let i ← replaceToken old new i + let j ← replaceToken old new j + `(trm| u8_or($i, $j)) | `(trm| dbg!($label:str $[, $t:trm]?); $[$ret:trm]?) => do let t' ← t.mapM $ replaceToken old new let ret' ← ret.mapM $ replaceToken old new diff --git a/Ix/Aiur/Term.lean b/Ix/Aiur/Term.lean index a279f86b..83302c79 100644 --- a/Ix/Aiur/Term.lean +++ b/Ix/Aiur/Term.lean @@ -93,6 +93,8 @@ inductive Term | u8ShiftRight : Term → Term | u8Xor : Term → Term → Term | u8Add : Term → Term → Term + | u8And : Term → Term → Term + | u8Or : Term → Term → Term | debug : String → Option Term → Term → Term deriving Repr, BEq, Hashable, Inhabited @@ -104,19 +106,6 @@ inductive Data end -inductive ContextualType - | evaluates : Typ → ContextualType - | escapes : ContextualType - deriving Repr, BEq, Inhabited - -def ContextualType.unwrap : ContextualType → Typ -| .escapes => panic! "term should not escape" -| .evaluates typ => typ - -def ContextualType.unwrapOr : ContextualType → Typ → Typ -| .escapes => fun typ => typ -| .evaluates typ => fun _ => typ - mutual inductive TypedTermInner | unit @@ -149,12 +138,15 @@ inductive TypedTermInner | u8ShiftRight : TypedTerm → TypedTermInner | u8Xor : TypedTerm → TypedTerm → TypedTermInner | u8Add : TypedTerm → TypedTerm → TypedTermInner + | u8And : TypedTerm → TypedTerm → TypedTermInner + | u8Or : TypedTerm → TypedTerm → TypedTermInner | debug : String → Option TypedTerm → TypedTerm → TypedTermInner deriving Repr, Inhabited structure TypedTerm where - typ : ContextualType + typ : Typ inner : TypedTermInner + escapes : Bool deriving Repr, Inhabited inductive TypedData diff --git a/Tests/Aiur.lean b/Tests/Aiur.lean index a8b647ba..7a1aa741 100644 --- a/Tests/Aiur.lean +++ b/Tests/Aiur.lean @@ -199,6 +199,14 @@ def toplevel := ⟦ (u8_add(i_xor_j, i), u8_add(i_xor_j, j)) } + fn u8_add_and(i: G, j: G) -> G { + u8_and(i, j) + } + + fn u8_add_or(i: G, j: G) -> G { + u8_or(i, j) + } + fn fold_matrix_sum(m: [[G; 2]; 2]) -> G { fold(0 .. 2, 0, |acc_outer, @i| fold(0 .. 2, acc_outer, |acc_inner, @j| @@ -245,6 +253,8 @@ def aiurTestCases : List AiurTestCase := [ ⟨#[1, 2, 3, 4, 1, 2, 3, 4], .ofList [(#[0], ⟨0, 4⟩), (#[1], ⟨0, 8⟩)]⟩⟩, .noIO `shr_shr_shl_decompose #[87] #[0, 1, 0, 1, 0, 1, 0, 0], .noIO `u8_add_xor #[45, 131] #[219, 0, 49, 1], + .noIO `u8_add_and #[45, 131] #[1], + .noIO `u8_add_or #[45, 131] #[175], .noIO `fold_matrix_sum #[1, 2, 3, 4] #[10], ] diff --git a/deny.toml b/deny.toml index 1cb67022..6519dae3 100644 --- a/deny.toml +++ b/deny.toml @@ -74,6 +74,8 @@ ignore = [ "RUSTSEC-2024-0384", # `instant` crate is unmaintained "RUSTSEC-2024-0370", # `proc-macro-error` crate is unmaintained "RUSTSEC-2023-0089", # `atomic-polyfill` crate is unmaintained + "RUSTSEC-2025-0141", # `bincode` crate is unmaintained + #{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" }, #"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish #{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" }, diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 51c3f8cb..43331947 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -53,6 +53,8 @@ pub enum Op { U8ShiftRight(ValIdx), U8Xor(ValIdx, ValIdx), U8Add(ValIdx, ValIdx), + U8And(ValIdx, ValIdx), + U8Or(ValIdx, ValIdx), Debug(String, Option>), } diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index 48993716..9c1252dd 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -16,8 +16,8 @@ use crate::aiur::{ bytes1::{Bytes1, Bytes1Op}, bytes2::{Bytes2, Bytes2Op}, }, - memory_channel, u8_add_channel, u8_bit_decomposition_channel, - u8_shift_left_channel, u8_shift_right_channel, u8_xor_channel, + memory_channel, u8_add_channel, u8_and_channel, u8_bit_decomposition_channel, + u8_or_channel, u8_shift_left_channel, u8_shift_right_channel, u8_xor_channel, }; type Expr = SymbolicExpression; @@ -415,6 +415,22 @@ impl Op { sel.clone(), state, ), + Op::U8And(i, j) => bytes2_constraints( + *i, + *j, + &Bytes2Op::And, + u8_and_channel(), + sel.clone(), + state, + ), + Op::U8Or(i, j) => bytes2_constraints( + *i, + *j, + &Bytes2Op::Or, + u8_or_channel(), + sel.clone(), + state, + ), Op::IOSetInfo(..) | Op::IOWrite(_) | Op::Debug(..) => (), } } diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index 6d25c99d..8030f5b0 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -271,6 +271,20 @@ impl Function { bytes2_execute(*i, *j, &Bytes2Op::Add, &mut map, record); } }, + ExecEntry::Op(Op::U8And(i, j)) => { + if unconstrained { + map.push(Bytes2::and(&map[*i], &map[*j])); + } else { + bytes2_execute(*i, *j, &Bytes2Op::And, &mut map, record); + } + }, + ExecEntry::Op(Op::U8Or(i, j)) => { + if unconstrained { + map.push(Bytes2::or(&map[*i], &map[*j])); + } else { + bytes2_execute(*i, *j, &Bytes2Op::Or, &mut map, record); + } + }, ExecEntry::Op(Op::Debug(label, idxs)) => match idxs { None => println!("{label}"), Some(idxs) => { diff --git a/src/aiur/gadgets/bytes2.rs b/src/aiur/gadgets/bytes2.rs index fc7e8044..f928af62 100644 --- a/src/aiur/gadgets/bytes2.rs +++ b/src/aiur/gadgets/bytes2.rs @@ -7,13 +7,16 @@ use multi_stark::{ }; use crate::aiur::{ - G, execute::QueryRecord, gadgets::AiurGadget, u8_add_channel, u8_xor_channel, + G, execute::QueryRecord, gadgets::AiurGadget, u8_add_channel, u8_and_channel, + u8_or_channel, u8_xor_channel, }; /// Number of columns in the trace with multiplicities for /// - xor /// - overflowing add -const TRACE_WIDTH: usize = 2; +/// - and +/// - or +const TRACE_WIDTH: usize = 4; /// Number of columns in the preprocessed trace: /// - first raw byte value @@ -21,7 +24,9 @@ const TRACE_WIDTH: usize = 2; /// - xor result /// - add result /// - add overflow -const PREPROCESSED_TRACE_WIDTH: usize = 5; +/// - and result +/// - or result +const PREPROCESSED_TRACE_WIDTH: usize = 7; /// AIR implementer for arity 2 byte-related lookups. pub(crate) struct Bytes2; @@ -29,6 +34,8 @@ pub(crate) struct Bytes2; pub(crate) enum Bytes2Op { Xor, Add, + And, + Or, } impl BaseAir for Bytes2 { @@ -53,6 +60,12 @@ impl BaseAir for Bytes2 { let (r, o) = i.overflowing_add(j); trace_values.push(G::from_u8(r)); trace_values.push(G::from_bool(o)); + + // And + trace_values.push(G::from_u8(i & j)); + + // Or + trace_values.push(G::from_u8(i | j)); } } Some(RowMajorMatrix::new(trace_values, PREPROCESSED_TRACE_WIDTH)) @@ -69,7 +82,7 @@ impl AiurGadget for Bytes2 { fn output_size(&self, op: &Bytes2Op) -> usize { match op { - Bytes2Op::Xor => 1, + Bytes2Op::Xor | Bytes2Op::And | Bytes2Op::Or => 1, Bytes2Op::Add => 2, } } @@ -92,6 +105,14 @@ impl AiurGadget for Bytes2 { let (r, o) = Self::add(i, j); vec![r, o] }, + Bytes2Op::And => { + record.bytes2_queries.bump_and(i, j); + vec![Self::and(i, j)] + }, + Bytes2Op::Or => { + record.bytes2_queries.bump_or(i, j); + vec![Self::or(i, j)] + }, } } @@ -99,10 +120,14 @@ impl AiurGadget for Bytes2 { // Channels let xor_channel = u8_xor_channel().into(); let add_channel = u8_add_channel().into(); + let and_channel = u8_and_channel().into(); + let or_channel = u8_or_channel().into(); // Multiplicity columns let xor_multiplicity = var(0); let add_multiplicity = var(1); + let and_multiplicity = var(2); + let or_multiplicity = var(3); // Preprocessed columns let i = preprocessed_var(0); @@ -110,16 +135,27 @@ impl AiurGadget for Bytes2 { let xor = preprocessed_var(2); let add_r = preprocessed_var(3); let add_o = preprocessed_var(4); + let and = preprocessed_var(5); + let or = preprocessed_var(6); let pull_xor = Lookup::pull( xor_multiplicity, vec![xor_channel, i.clone(), j.clone(), xor], ); - let pull_add = - Lookup::pull(add_multiplicity, vec![add_channel, i, j, add_r, add_o]); + let pull_add = Lookup::pull( + add_multiplicity, + vec![add_channel, i.clone(), j.clone(), add_r, add_o], + ); + + let pull_and = Lookup::pull( + and_multiplicity, + vec![and_channel, i.clone(), j.clone(), and], + ); + + let pull_or = Lookup::pull(or_multiplicity, vec![or_channel, i, j, or]); - vec![pull_xor, pull_add] + vec![pull_xor, pull_add, pull_and, pull_or] } fn witness_data( @@ -133,18 +169,22 @@ impl AiurGadget for Bytes2 { let xor_channel = u8_xor_channel(); let add_channel = u8_add_channel(); + let and_channel = u8_and_channel(); + let or_channel = u8_or_channel(); rows .chunks_exact_mut(TRACE_WIDTH) .enumerate() .zip(&record.bytes2_queries.0) .zip(&mut lookups) - .for_each(|(((row_idx, row), &[xor, add]), row_lookups)| { + .for_each(|(((row_idx, row), &[xor, add, and, or]), row_lookups)| { let i = G::from_usize(row_idx / 256); let j = G::from_usize(row_idx % 256); row[0] = xor; row[1] = add; + row[2] = and; + row[3] = or; // Pull xor. row_lookups[0] = @@ -153,6 +193,14 @@ impl AiurGadget for Bytes2 { // Pull add. let (r, o) = Self::add(&i, &j); row_lookups[1] = Lookup::pull(add, vec![add_channel, i, j, r, o]); + + // Pull and. + row_lookups[2] = + Lookup::pull(and, vec![and_channel, i, j, Self::and(&i, &j)]); + + // Pull or. + row_lookups[3] = + Lookup::pull(or, vec![or_channel, i, j, Self::or(&i, &j)]); }); (RowMajorMatrix::new(rows, TRACE_WIDTH), lookups) } @@ -175,6 +223,14 @@ impl Bytes2Queries { self.bump_multiplicity_for(i, j, 1) } + fn bump_and(&mut self, i: &G, j: &G) { + self.bump_multiplicity_for(i, j, 2) + } + + fn bump_or(&mut self, i: &G, j: &G) { + self.bump_multiplicity_for(i, j, 3) + } + fn bump_multiplicity_for(&mut self, i: &G, j: &G, col: usize) { let i = usize::try_from(i.as_canonical_u64()).unwrap(); let j = usize::try_from(j.as_canonical_u64()).unwrap(); @@ -198,4 +254,18 @@ impl Bytes2 { let (r, o) = i.overflowing_add(j); (G::from_u8(r), G::from_bool(o)) } + + #[inline] + pub(crate) fn and(i: &G, j: &G) -> G { + let i: u8 = i.as_canonical_u64().try_into().unwrap(); + let j: u8 = j.as_canonical_u64().try_into().unwrap(); + G::from_u8(i & j) + } + + #[inline] + pub(crate) fn or(i: &G, j: &G) -> G { + let i: u8 = i.as_canonical_u64().try_into().unwrap(); + let j: u8 = j.as_canonical_u64().try_into().unwrap(); + G::from_u8(i | j) + } } diff --git a/src/aiur/mod.rs b/src/aiur/mod.rs index df09e145..26e49bb9 100644 --- a/src/aiur/mod.rs +++ b/src/aiur/mod.rs @@ -44,3 +44,13 @@ pub fn u8_xor_channel() -> G { pub fn u8_add_channel() -> G { G::from_u8(6) } + +#[inline] +pub fn u8_and_channel() -> G { + G::from_u8(7) +} + +#[inline] +pub fn u8_or_channel() -> G { + G::from_u8(8) +} diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 4b0e17c7..a87f4c5f 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -17,8 +17,8 @@ use crate::aiur::{ function_channel, gadgets::{bytes1::Bytes1, bytes2::Bytes2}, memory::Memory, - u8_add_channel, u8_bit_decomposition_channel, u8_shift_left_channel, - u8_shift_right_channel, u8_xor_channel, + u8_add_channel, u8_and_channel, u8_bit_decomposition_channel, u8_or_channel, + u8_shift_left_channel, u8_shift_right_channel, u8_xor_channel, }; struct ColumnIndex { @@ -374,6 +374,24 @@ impl Op { let lookup_args = vec![u8_add_channel(), i, j, r, o]; slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); }, + Op::U8And(i, j) => { + let (i, _) = map[*i]; + let (j, _) = map[*j]; + let and = Bytes2::and(&i, &j); + map.push((and, 1)); + slice.push_auxiliary(index, and); + let lookup_args = vec![u8_and_channel(), i, j, and]; + slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); + }, + Op::U8Or(i, j) => { + let (i, _) = map[*i]; + let (j, _) = map[*j]; + let or = Bytes2::or(&i, &j); + map.push((or, 1)); + slice.push_auxiliary(index, or); + let lookup_args = vec![u8_or_channel(), i, j, or]; + slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); + }, Op::AssertEq(..) | Op::IOSetInfo(..) | Op::IOWrite(_) | Op::Debug(..) => { }, } diff --git a/src/lean/ffi/aiur/toplevel.rs b/src/lean/ffi/aiur/toplevel.rs index 1b4ebd84..8789665e 100644 --- a/src/lean/ffi/aiur/toplevel.rs +++ b/src/lean/ffi/aiur/toplevel.rs @@ -117,6 +117,14 @@ fn lean_ptr_to_op(ptr: *const c_void) -> Op { Op::U8Add(i, j) }, 18 => { + let [i, j] = ctor.objs().map(lean_unbox_nat_as_usize); + Op::U8And(i, j) + }, + 19 => { + let [i, j] = ctor.objs().map(lean_unbox_nat_as_usize); + Op::U8Or(i, j) + }, + 20 => { let [label_ptr, idxs_ptr] = ctor.objs(); let label_str: &LeanStringObject = as_ref_unsafe(label_ptr.cast()); let label = label_str.as_string();