From c7e384c32cababb4bc32bc42e455523c7d6c6719 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Fri, 9 Jan 2026 20:12:55 -0300 Subject: [PATCH 1/5] Simplified types No more `ContextualType`. Instead, `TypedInner` carries an `escape` field --- Ix/Aiur/Check.lean | 168 ++++++++++++++++++++++--------------------- Ix/Aiur/Compile.lean | 24 +++---- Ix/Aiur/Term.lean | 16 +---- 3 files changed, 100 insertions(+), 108 deletions(-) diff --git a/Ix/Aiur/Check.lean b/Ix/Aiur/Check.lean index 53b59321..56562b31 100644 --- a/Ix/Aiur/Check.lean +++ b/Ix/Aiur/Check.lean @@ -80,37 +80,37 @@ def bindIdents (bindings : List (Local × Typ)) (ctx : CheckContext) : CheckCont mutual partial def inferTerm : Term → CheckM TypedTerm - | .unit => pure $ .mk (.evaluates .unit) .unit + | .unit => pure ⟨.unit, .unit, false⟩ | .var x => do -- Retrieves and returns the variable type from the context. let ctx ← read match ctx.varTypes[x]? with - | some t => pure $ .mk (.evaluates t) (.var x) + | some t => pure ⟨t, .var x, false⟩ | none => let Local.str localName := x | unreachable! - let typ := .evaluates (← refLookup (Global.init localName)) - pure $ .mk typ (.var x) + let typ ← refLookup (Global.init localName) + pure ⟨typ, .var x, false⟩ | .ref x => do - let typ := .evaluates (← refLookup x) - pure $ .mk typ (.ref x) + let typ ← refLookup x + pure ⟨typ, .ref x, false⟩ | .ret term => do -- Ensures that the type of the returned term matches the expected return type. -- The term is not allowed to have a (nested) return. -- Returning the type of the term is not necessary because it's already in the context. let ctx ← read let inner ← checkNoEscape term ctx.returnType - pure $ .mk .escapes inner + pure ⟨ctx.returnType, inner, true⟩ | .data data => do let (typ, inner) ← inferData data - pure $ .mk (.evaluates typ) inner + pure ⟨typ, inner, false⟩ | .let pat expr body => do -- Returns the type of the body, inferred in the context extended with the bound variable type. -- The bound variable is ensured not to escape. let (exprTyp, exprInner) ← inferNoEscape expr - let expr' := .mk (.evaluates exprTyp) exprInner + let expr' := ⟨exprTyp, exprInner, false⟩ let bindings ← checkPattern pat exprTyp let body' ← withReader (bindIdents bindings) (inferTerm body) - pure $ .mk body'.typ (.let pat expr' body') + pure ⟨body'.typ, .let pat expr' body', body'.escapes⟩ | .match term branches => inferMatch term branches | .app func@(⟨.str .anonymous unqualifiedFunc⟩) args => do -- Ensures the function exists in the context and that the arguments, which aren't allowed to @@ -119,15 +119,15 @@ partial def inferTerm : Term → CheckM TypedTerm match ctx.varTypes[Local.str unqualifiedFunc]? with | some (.function inputs output) => do let args ← checkArgsAndInputs func args inputs - pure $ .mk (.evaluates output) (.app func args) + pure ⟨output, .app func args, false⟩ | some _ => throw $ .notAFunction func | none => match ctx.decls.getByKey func with | some (.function function) => do let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) - pure $ .mk (.evaluates function.output) (.app func args) + pure ⟨function.output, .app func args, false⟩ | some (.constructor dataType constr) => do let args ← checkArgsAndInputs func args constr.argTypes - pure $ .mk (.evaluates (.dataType dataType.name)) (.app func args) + pure ⟨.dataType dataType.name, .app func args, false⟩ | _ => throw $ .cannotApply func | .app func args => do -- Only checks global map if it is not unqualified @@ -135,29 +135,29 @@ partial def inferTerm : Term → CheckM TypedTerm match ctx.decls.getByKey func with | some (.function function) => let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) - pure $ .mk (.evaluates function.output) (.app func args) + pure ⟨function.output, .app func args, false⟩ | some (.constructor dataType constr) => let args ← checkArgsAndInputs func args constr.argTypes - pure $ .mk (.evaluates (.dataType dataType.name)) (.app func args) + pure ⟨.dataType dataType.name, .app func args, false⟩ | _ => throw $ .cannotApply func | .add a b => do - let (ctxTyp, a, b) ← checkArith a b - pure $ .mk ctxTyp (.add a b) + let (a, b) ← checkArith a b + pure ⟨.field, .add a b, false⟩ | .sub a b => do - let (ctxTyp, a, b) ← checkArith a b - pure $ .mk ctxTyp (.sub a b) + let (a, b) ← checkArith a b + pure ⟨.field, .sub a b, false⟩ | .mul a b => do - let (ctxTyp, a, b) ← checkArith a b - pure $ .mk ctxTyp (.mul a b) + let (a, b) ← checkArith a b + pure ⟨.field, .mul a b, false⟩ | .eqZero a => do let a ← fieldTerm <$> checkNoEscape a .field - pure $ .mk (.evaluates .field) (.eqZero a) + pure ⟨.field, .eqZero a, false⟩ | .proj tup i => do let (typs, tupInner) ← inferTuple tup if h : i < typs.size then let typ := typs[i] - let tup := .mk (.evaluates (.tuple typs)) tupInner - pure $ .mk (.evaluates typ) (.proj tup i) + let tup := ⟨.tuple typs, tupInner, false⟩ + pure ⟨typ, .proj tup i, false⟩ else throw $ .indexOoB i | .get arr i => do @@ -165,13 +165,13 @@ partial def inferTerm : Term → CheckM TypedTerm if i ≥ n then throw $ .indexOoB i else - let arr := .mk (.evaluates (.array typ n)) inner - pure $ .mk (.evaluates typ) (.get arr i) + let arr := ⟨.array typ n, inner, false⟩ + pure ⟨typ, .get arr i, false⟩ | .slice arr i j => if j < i then throw $ .negativeRange i j else do let (typ, n, inner) ← inferArray arr if j ≤ n then - let arr := .mk (.evaluates (.array typ n)) inner - pure $ .mk (.evaluates (.array typ (j - i))) (.slice arr i j) + let arr := ⟨.array typ n, inner, false⟩ + pure ⟨.array typ (j - i), .slice arr i j, false⟩ else throw $ .rangeOoB i j | .set arr i val => do @@ -181,50 +181,51 @@ partial def inferTerm : Term → CheckM TypedTerm else let val ← checkNoEscape val typ let arrTyp := .array typ n - let arr := .mk (.evaluates arrTyp) inner - let val := .mk (.evaluates typ) val - pure $ .mk (.evaluates arrTyp) (.set arr i val) + let arr := ⟨arrTyp, inner, false⟩ + let val := ⟨typ, val, false⟩ + pure ⟨arrTyp, .set arr i val, false⟩ | .store term => do -- Infers the type of the term and returns it, wrapped by a pointer type. -- The term is not allowed to early return. let (typ, inner) ← inferNoEscape term - let store := .store (.mk (.evaluates typ) inner) - pure $ .mk (.evaluates (.pointer typ)) store + let store := .store ⟨typ, inner, false⟩ + pure ⟨.pointer typ, store, false⟩ | .load term => do -- Ensures that the the type of the term is a pointer type and returns the unwrapped type. -- The term is not allowed to early return. let (typ, inner) ← inferNoEscape term match typ with | .pointer innerTyp => - let load := .load (.mk (.evaluates typ) inner) - pure $ .mk (.evaluates innerTyp) load + let load := .load ⟨typ, inner, false⟩ + pure ⟨innerTyp, load, false⟩ | _ => throw $ .notAPointer typ | .ptrVal term => do -- Infers the type of the term, which must be a pointer, but returns `.u64`, as in a cast. -- The term is not allowed to early return. let (typ, inner) ← inferNoEscape term match typ with - | .pointer _ => pure $ fieldTerm (.ptrVal (.mk (.evaluates typ) inner)) + | .pointer _ => pure $ fieldTerm (.ptrVal ⟨typ, inner, false⟩) | _ => throw $ .notAPointer typ | .ann typ term => do let inner ← checkNoEscape term typ - pure $ .mk (.evaluates typ) inner + pure ⟨typ, inner, false⟩ -- | .unsafeCast term castTyp => do -- let (typ, inner) ← inferNoEscape term - -- pure $ .mk (.evaluates typ) (.unsafeCast inner castTyp) + -- pure ⟨typ, .unsafeCast inner castTyp, false⟩ | .assertEq a b ret => do -- `a` and `b` must have the same type. let (typ, a) ← inferNoEscape a let b ← checkNoEscape b typ let ret ← inferTerm ret - let assertEq := .assertEq (.mk (.evaluates typ) a) (.mk (.evaluates typ) b) ret - pure $ .mk (.evaluates ret.typ.unwrap) assertEq + let assertEq := .assertEq ⟨typ, a, false⟩ ⟨typ, b, false⟩ ret + assert! (ret.escapes == false) + pure ⟨ret.typ, assertEq, false⟩ | .ioGetInfo key => do let (typ, keyInner) ← inferNoEscape key match typ with | .array .. => - let ioGetInfo := .ioGetInfo (.mk (.evaluates typ) keyInner) - pure $ .mk (.evaluates (.tuple #[.field, .field])) ioGetInfo + let ioGetInfo := .ioGetInfo ⟨typ, keyInner, false⟩ + pure ⟨.tuple #[.field, .field], ioGetInfo, false⟩ | _ => throw $ .notAnArray typ | .ioSetInfo key idx len ret => do let (keyTyp, keyInner) ← inferNoEscape key @@ -234,50 +235,53 @@ partial def inferTerm : Term → CheckM TypedTerm let idx ← fieldTerm <$> checkNoEscape idx .field let len ← fieldTerm <$> checkNoEscape len .field let ret ← inferTerm ret - let ioSetInfo := .ioSetInfo (.mk (.evaluates keyTyp) keyInner) idx len ret - pure $ .mk (.evaluates ret.typ.unwrap) ioSetInfo + let ioSetInfo := .ioSetInfo ⟨keyTyp, keyInner, false⟩ idx len ret + assert! (ret.escapes == false) + pure ⟨ret.typ, ioSetInfo, false⟩ | _ => throw $ .notAnArray keyTyp | .ioRead idx len => do if len = 0 then throw .emptyArray let idx ← fieldTerm <$> checkNoEscape idx .field let ioRead := .ioRead idx len - pure $ .mk (.evaluates (.array .field len)) ioRead + pure ⟨.array .field len, ioRead, false⟩ | .ioWrite data ret => do let (dataTyp, dataInner) ← inferNoEscape data match dataTyp with | .array dataEltTyp _ => if dataEltTyp != .field then throw $ .typeMismatch .field dataEltTyp let ret ← inferTerm ret - let ioWrite := .ioWrite (.mk (.evaluates dataTyp) dataInner) ret - pure $ .mk (.evaluates ret.typ.unwrap) ioWrite + let ioWrite := .ioWrite ⟨dataTyp, dataInner, false⟩ ret + assert! (ret.escapes == false) + pure ⟨ret.typ, ioWrite, false⟩ | _ => throw $ .notAnArray dataTyp | .u8BitDecomposition byte => do let byte ← fieldTerm <$> checkNoEscape byte .field let u8BitDecomposition := .u8BitDecomposition byte - pure $ .mk (.evaluates (.array .field 8)) u8BitDecomposition + pure ⟨.array .field 8, u8BitDecomposition, false⟩ | .u8ShiftLeft byte => do let byte ← fieldTerm <$> checkNoEscape byte .field let u8ShiftLeft := .u8ShiftLeft byte - pure $ .mk (.evaluates .field) u8ShiftLeft + pure ⟨.field, u8ShiftLeft, false⟩ | .u8ShiftRight byte => do let byte ← fieldTerm <$> checkNoEscape byte .field let u8ShiftRight := .u8ShiftRight byte - pure $ .mk (.evaluates .field) u8ShiftRight + pure ⟨.field, u8ShiftRight, false⟩ | .u8Xor i j => do let i ← fieldTerm <$> checkNoEscape i .field let j ← fieldTerm <$> checkNoEscape j .field let u8Xor := .u8Xor i j - pure $ .mk (.evaluates .field) u8Xor + pure ⟨.field, u8Xor, false⟩ | .u8Add i j => do let i ← fieldTerm <$> checkNoEscape i .field let j ← fieldTerm <$> checkNoEscape j .field let u8Add := .u8Add i j - pure $ .mk (.evaluates (.tuple #[.field, .field])) u8Add + pure ⟨.tuple #[.field, .field], u8Add, false⟩ | .debug label term ret => do let term ← term.mapM inferTerm let ret ← inferTerm ret let debug := .debug label term ret - pure $ .mk (.evaluates ret.typ.unwrap) debug + assert! (ret.escapes == false) + pure ⟨ret.typ, debug, false⟩ where /-- Ensures that there are as many arguments and as expected types and that @@ -289,13 +293,13 @@ where unless lenArgs == lenInputs do throw $ .wrongNumArgs func lenArgs lenInputs let pass := fun (arg, input) => do let inner ← checkNoEscape arg input - pure $ .mk (.evaluates input) inner + pure ⟨input, inner, false⟩ args.zip inputs |>.mapM pass checkArith a b := do let aInner ← checkNoEscape a .field let bInner ← checkNoEscape b .field - pure (.evaluates .field, fieldTerm aInner, fieldTerm bInner) - fieldTerm := (.mk (.evaluates .field) ·) + pure (fieldTerm aInner, fieldTerm bInner) + fieldTerm inner := ⟨.field, inner, false⟩ partial def checkNoEscape (term : Term) (typ : Typ) : CheckM TypedTermInner := do let (typ', inner) ← inferNoEscape term @@ -304,56 +308,56 @@ partial def checkNoEscape (term : Term) (typ : Typ) : CheckM TypedTermInner := d partial def inferNoEscape (term : Term) : CheckM (Typ × TypedTermInner) := do let typedTerm ← inferTerm term - match typedTerm.typ with - | .escapes => throw .illegalReturn - | .evaluates type => pure (type, typedTerm.inner) + if typedTerm.escapes then throw .illegalReturn + pure (typedTerm.typ, typedTerm.inner) partial def inferData : Data → CheckM (Typ × TypedTermInner) | .field g => pure (.field, .data (.field g)) | .tuple terms => do let typsAndInners ← terms.mapM inferNoEscape let typs := typsAndInners.map Prod.fst - let terms := typsAndInners.map fun (typ, inner) => .mk (.evaluates typ) inner + let terms := typsAndInners.map fun (typ, inner) => ⟨typ, inner, false⟩ pure (.tuple typs, .data (.tuple terms)) | .array terms => do if h : terms.size > 0 then let (typ, firstInner) ← inferNoEscape terms[0] let mut typedTerms := Array.mkEmpty terms.size - |>.push (.mk (.evaluates typ) firstInner) + |>.push ⟨typ, firstInner, false⟩ for term in terms[1:] do let inner ← checkNoEscape term typ - typedTerms := typedTerms.push (.mk (.evaluates typ) inner) + typedTerms := typedTerms.push ⟨typ, inner, false⟩ pure (.array typ terms.size, .data (.array typedTerms)) else throw .emptyArray /-- Infers the type of a 'match' expression and ensures its patterns and branches are valid. -/ partial def inferMatch (term : Term) (branches : List (Pattern × Term)) : CheckM TypedTerm := do - if branches.isEmpty then throw .emptyMatch let (termTyp, termInner) ← inferNoEscape term - let term := .mk (.evaluates termTyp) termInner - let init := ([], .escapes) - let (branches, typ) ← branches.foldrM (init := init) (checkBranch termTyp) - pure $ .mk typ (.match term branches) + let term := ⟨termTyp, termInner, false⟩ + let init := ([], none) + let (branches, typOpt) ← branches.foldrM (init := init) (checkBranch termTyp) + match typOpt with + | some (typ, escapes) => pure ⟨typ, .match term branches, escapes⟩ + | none => throw .emptyMatch where checkBranch patTyp branchData acc := do let (pat, branch) := branchData - let (typedBranches, currentTyp) := acc + let (typedBranches, currentTypOpt) := acc let bindings ← checkPattern pat patTyp - withReader (bindIdents bindings) (match currentTyp with - | .escapes => do + withReader (bindIdents bindings) (match currentTypOpt with + | none => do let typedBranch ← inferTerm branch - pure (typedBranches.cons (pat, typedBranch), typedBranch.typ) - | .evaluates matchTyp => do - -- Some branch didn't escape, so if this branch doesn't escape it must have the same type - -- as the previous non-escaping branch. + pure (typedBranches.cons (pat, typedBranch), some (typedBranch.typ, typedBranch.escapes)) + | some (matchTyp, matchEscapes) => do let typedBranch ← inferTerm branch let typedBranches := typedBranches.cons (pat, typedBranch) - match typedBranch.typ with - | .escapes => pure (typedBranches, currentTyp) - | .evaluates branchTyp => - -- This branch doesn't escape so its type must match the type of the previous non-escaping branch. - unless (matchTyp == branchTyp) do throw $ .branchMismatch matchTyp branchTyp - pure (typedBranches, currentTyp)) + if typedBranch.escapes then + pure (typedBranches, currentTypOpt) + else if matchEscapes then + pure (typedBranches, some (typedBranch.typ, false)) + else + -- Neither branch escapes so their types must match + unless (matchTyp == typedBranch.typ) do throw $ .branchMismatch matchTyp typedBranch.typ + pure (typedBranches, currentTypOpt)) /-- Checks that a pattern matches a given type and collects its bindings. -/ partial def checkPattern (pat : Pattern) (typ : Typ) : CheckM $ List (Local × Typ) := do @@ -451,8 +455,8 @@ where /-- Checks a function to ensure its body's type matches its declared output type. -/ def checkFunction (function : Function) : CheckM TypedFunction := do let body ← inferTerm function.body - if let .evaluates typ := body.typ then - unless typ == function.output do throw $ .typeMismatch typ function.output + unless body.escapes do + unless body.typ == function.output do throw $ .typeMismatch body.typ function.output pure ⟨function.name, function.inputs, function.output, body, function.unconstrained⟩ end Aiur diff --git a/Ix/Aiur/Compile.lean b/Ix/Aiur/Compile.lean index ef62d0e0..1f8a3054 100644 --- a/Ix/Aiur/Compile.lean +++ b/Ix/Aiur/Compile.lean @@ -246,7 +246,7 @@ partial def toIndex (term : TypedTerm) : StateM CompilerState (Array Bytecode.ValIdx) := match term.inner with -- | .unsafeCast inner castTyp => - -- if typSize layoutMap castTyp != typSize layoutMap term.typ.unwrap then + -- if typSize layoutMap castTyp != typSize layoutMap term.typ then -- panic! "Impossible cast" -- else -- toIndex layoutMap bindings (.mk term.typ inner) @@ -335,8 +335,8 @@ partial def toIndex -- pushOp (Bytecode.Op.preimg layout.index out layout.inputSize) layout.inputSize -- | _ => panic! "should not happen after typechecking" | .proj arg i => do - let typs := match arg.typ with - | .evaluates (.tuple typs) => typs + let typs := match (arg.typ, arg.escapes) with + | (.tuple typs, false) => typs | _ => panic! "Should not happen after typechecking" let offset := (typs.extract 0 i).foldl (init := 0) fun acc typ => typSize layoutMap typ + acc @@ -344,23 +344,23 @@ partial def toIndex let length := typSize layoutMap typs[i]! pure $ arg.extract offset (offset + length) | .get arr i => do - let eltTyp := match arr.typ with - | .evaluates (.array typ _) => typ + let eltTyp := match (arr.typ, arr.escapes) with + | (.array typ _, false) => typ | _ => panic! "Should not happen after typechecking" let eltSize := typSize layoutMap eltTyp let offset := i * eltSize let arr ← toIndex layoutMap bindings arr pure $ arr.extract offset (offset + eltSize) | .slice arr i j => do - let eltTyp := match arr.typ with - | .evaluates (.array typ _) => typ + let eltTyp := match (arr.typ, arr.escapes) with + | (.array typ _, false) => typ | _ => panic! "Should not happen after typechecking" let eltSize := typSize layoutMap eltTyp let arr ← toIndex layoutMap bindings arr pure $ arr.extract (i * eltSize) (j * eltSize) | .set arr i val => do - let eltTyp := match arr.typ with - | .evaluates (.array typ _) => typ + let eltTyp := match (arr.typ, arr.escapes) with + | (.array typ _, false) => typ | _ => panic! "Should not happen after typechecking" let eltSize := typSize layoutMap eltTyp let arr ← toIndex layoutMap bindings arr @@ -372,8 +372,8 @@ partial def toIndex let args ← toIndex layoutMap bindings arg pushOp (.store args) | .load ptr => do - let size := match ptr.typ.unwrap with - | .pointer typ => typSize layoutMap typ + let size := match (ptr.typ, ptr.escapes) with + | (.pointer typ, false) => typSize layoutMap typ | _ => unreachable! let ptr ← expectIdx ptr pushOp (.load size ptr) size @@ -454,7 +454,7 @@ partial def TypedTerm.compile modify fun stt => { stt with ops := stt.ops.push (.debug label term) } ret.compile returnTyp layoutMap bindings | .match term cases => - match term.typ.unwrapOr returnTyp with + match term.typ with -- Also do this for tuple-like and array-like (one constructor only) datatypes | .tuple typs => match cases with | [(.tuple vars, branch)] => do diff --git a/Ix/Aiur/Term.lean b/Ix/Aiur/Term.lean index a279f86b..ac5670b1 100644 --- a/Ix/Aiur/Term.lean +++ b/Ix/Aiur/Term.lean @@ -104,19 +104,6 @@ inductive Data end -inductive ContextualType - | evaluates : Typ → ContextualType - | escapes : ContextualType - deriving Repr, BEq, Inhabited - -def ContextualType.unwrap : ContextualType → Typ -| .escapes => panic! "term should not escape" -| .evaluates typ => typ - -def ContextualType.unwrapOr : ContextualType → Typ → Typ -| .escapes => fun typ => typ -| .evaluates typ => fun _ => typ - mutual inductive TypedTermInner | unit @@ -153,8 +140,9 @@ inductive TypedTermInner deriving Repr, Inhabited structure TypedTerm where - typ : ContextualType + typ : Typ inner : TypedTermInner + escapes : Bool deriving Repr, Inhabited inductive TypedData From 317ea26b2537814037c4df30d68c786cbb8c7658 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Fri, 9 Jan 2026 22:38:42 -0300 Subject: [PATCH 2/5] Infer term refactor into many smaller functions --- Ix/Aiur/Check.lean | 435 +++++++++++++++++++++++---------------------- 1 file changed, 222 insertions(+), 213 deletions(-) diff --git a/Ix/Aiur/Check.lean b/Ix/Aiur/Check.lean index 56562b31..6be432fa 100644 --- a/Ix/Aiur/Check.lean +++ b/Ix/Aiur/Check.lean @@ -78,215 +78,201 @@ def refLookup (global : Global) : CheckM Typ := do def bindIdents (bindings : List (Local × Typ)) (ctx : CheckContext) : CheckContext := { ctx with varTypes := ctx.varTypes.insertMany bindings } +def fieldTerm (t : TypedTermInner) : TypedTerm := ⟨.field, t, false⟩ + mutual -partial def inferTerm : Term → CheckM TypedTerm +partial def inferTerm (t : Term) : CheckM TypedTerm := match t with | .unit => pure ⟨.unit, .unit, false⟩ - | .var x => do - -- Retrieves and returns the variable type from the context. - let ctx ← read - match ctx.varTypes[x]? with - | some t => pure ⟨t, .var x, false⟩ - | none => - let Local.str localName := x | unreachable! - let typ ← refLookup (Global.init localName) - pure ⟨typ, .var x, false⟩ + | .var x => inferVariable x | .ref x => do - let typ ← refLookup x - pure ⟨typ, .ref x, false⟩ - | .ret term => do - -- Ensures that the type of the returned term matches the expected return type. - -- The term is not allowed to have a (nested) return. - -- Returning the type of the term is not necessary because it's already in the context. - let ctx ← read - let inner ← checkNoEscape term ctx.returnType - pure ⟨ctx.returnType, inner, true⟩ - | .data data => do - let (typ, inner) ← inferData data - pure ⟨typ, inner, false⟩ - | .let pat expr body => do - -- Returns the type of the body, inferred in the context extended with the bound variable type. - -- The bound variable is ensured not to escape. - let (exprTyp, exprInner) ← inferNoEscape expr - let expr' := ⟨exprTyp, exprInner, false⟩ - let bindings ← checkPattern pat exprTyp - let body' ← withReader (bindIdents bindings) (inferTerm body) - pure ⟨body'.typ, .let pat expr' body', body'.escapes⟩ + pure ⟨← refLookup x, .ref x, false⟩ + | .ret term => inferReturn term + | .data data => inferData data + | .let pat expr body => inferLet pat expr body | .match term branches => inferMatch term branches - | .app func@(⟨.str .anonymous unqualifiedFunc⟩) args => do - -- Ensures the function exists in the context and that the arguments, which aren't allowed to - -- escape, match the function's input types. Returns the function's output type. - let ctx ← read - match ctx.varTypes[Local.str unqualifiedFunc]? with - | some (.function inputs output) => do - let args ← checkArgsAndInputs func args inputs - pure ⟨output, .app func args, false⟩ - | some _ => throw $ .notAFunction func - | none => match ctx.decls.getByKey func with - | some (.function function) => do - let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) - pure ⟨function.output, .app func args, false⟩ - | some (.constructor dataType constr) => do - let args ← checkArgsAndInputs func args constr.argTypes - pure ⟨.dataType dataType.name, .app func args, false⟩ - | _ => throw $ .cannotApply func - | .app func args => do - -- Only checks global map if it is not unqualified - let ctx ← read - match ctx.decls.getByKey func with - | some (.function function) => - let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) - pure ⟨function.output, .app func args, false⟩ - | some (.constructor dataType constr) => - let args ← checkArgsAndInputs func args constr.argTypes - pure ⟨.dataType dataType.name, .app func args, false⟩ - | _ => throw $ .cannotApply func - | .add a b => do - let (a, b) ← checkArith a b - pure ⟨.field, .add a b, false⟩ - | .sub a b => do - let (a, b) ← checkArith a b - pure ⟨.field, .sub a b, false⟩ - | .mul a b => do - let (a, b) ← checkArith a b - pure ⟨.field, .mul a b, false⟩ - | .eqZero a => do - let a ← fieldTerm <$> checkNoEscape a .field - pure ⟨.field, .eqZero a, false⟩ - | .proj tup i => do - let (typs, tupInner) ← inferTuple tup - if h : i < typs.size then - let typ := typs[i] - let tup := ⟨.tuple typs, tupInner, false⟩ - pure ⟨typ, .proj tup i, false⟩ - else - throw $ .indexOoB i - | .get arr i => do - let (typ, n, inner) ← inferArray arr - if i ≥ n then - throw $ .indexOoB i - else - let arr := ⟨.array typ n, inner, false⟩ - pure ⟨typ, .get arr i, false⟩ - | .slice arr i j => if j < i then throw $ .negativeRange i j else do + | .app func args => inferApplication func args + | .ann typ term => do + pure ⟨typ, ← checkNoEscape term typ, false⟩ + | .proj tup i => inferProj tup i + | .get arr i => inferGet arr i + | .slice arr i j => inferSlice arr i j + | .set arr i val => inferSet arr i val + | .store term => inferStore term + | .load term => inferLoad term + | .ptrVal term => inferPtrVal term + | .eqZero a => inferUnop a .eqZero .field + | .add a b => inferBinop a b .add .field + | .sub a b => inferBinop a b .sub .field + | .mul a b => inferBinop a b .mul .field + | .u8ShiftLeft a => inferUnop a .u8ShiftLeft .field + | .u8ShiftRight a => inferUnop a .u8ShiftRight .field + | .u8BitDecomposition a => inferUnop a .u8BitDecomposition (.array .field 8) + | .u8Xor a b => inferBinop a b .u8Xor .field + | .u8Add a b => inferBinop a b .u8Add (.tuple #[.field, .field]) + | .ioGetInfo key => inferIoGetInfo key + | .ioSetInfo key idx len ret => inferIoSetInfo key idx len ret + | .ioRead idx len => inferIoRead idx len + | .ioWrite data ret => inferIoWrite data ret + | .assertEq a b ret => inferAssertEq a b ret + | .debug label term ret => inferDebug label term ret + +partial def checkNoEscape (term : Term) (typ : Typ) : CheckM TypedTermInner := do + let (typ', inner) ← inferNoEscape term + unless typ == typ' do throw $ .typeMismatch typ typ' + pure inner + +partial def inferNoEscape (term : Term) : CheckM (Typ × TypedTermInner) := do + let typedTerm ← inferTerm term + if typedTerm.escapes then throw .illegalReturn + pure (typedTerm.typ, typedTerm.inner) + +partial def inferUnop + (a : Term) + (op : TypedTerm → TypedTermInner) + (typ : Typ) : + CheckM TypedTerm := do + let a ← fieldTerm <$> checkNoEscape a .field + pure ⟨typ, op a, false⟩ + +partial def inferBinop + (a : Term) + (b : Term) + (op : TypedTerm → TypedTerm → TypedTermInner) + (typ : Typ) : + CheckM TypedTerm := do + let a ← fieldTerm <$> checkNoEscape a .field + let b ← fieldTerm <$> checkNoEscape b .field + pure ⟨typ, op a b, false⟩ + +partial def inferProj (tup : Term) (i : Nat) : CheckM TypedTerm := do + let (typs, tupInner) ← inferTuple tup + if h : i < typs.size then + let typ := typs[i] + let tup := ⟨.tuple typs, tupInner, false⟩ + pure ⟨typ, .proj tup i, false⟩ + else + throw $ .indexOoB i + +partial def inferGet (arr : Term) (i : Nat) : CheckM TypedTerm := do + let (typ, n, inner) ← inferArray arr + if i ≥ n then + throw $ .indexOoB i + else + let arr := ⟨.array typ n, inner, false⟩ + pure ⟨typ, .get arr i, false⟩ + +partial def inferSlice (arr : Term) (i j : Nat) : CheckM TypedTerm := + if j < i then throw $ .negativeRange i j else do let (typ, n, inner) ← inferArray arr if j ≤ n then let arr := ⟨.array typ n, inner, false⟩ pure ⟨.array typ (j - i), .slice arr i j, false⟩ else throw $ .rangeOoB i j - | .set arr i val => do - let (typ, n, inner) ← inferArray arr - if i ≥ n then - throw $ .indexOoB i - else - let val ← checkNoEscape val typ - let arrTyp := .array typ n - let arr := ⟨arrTyp, inner, false⟩ - let val := ⟨typ, val, false⟩ - pure ⟨arrTyp, .set arr i val, false⟩ - | .store term => do - -- Infers the type of the term and returns it, wrapped by a pointer type. - -- The term is not allowed to early return. - let (typ, inner) ← inferNoEscape term - let store := .store ⟨typ, inner, false⟩ - pure ⟨.pointer typ, store, false⟩ - | .load term => do - -- Ensures that the the type of the term is a pointer type and returns the unwrapped type. - -- The term is not allowed to early return. - let (typ, inner) ← inferNoEscape term - match typ with - | .pointer innerTyp => - let load := .load ⟨typ, inner, false⟩ - pure ⟨innerTyp, load, false⟩ - | _ => throw $ .notAPointer typ - | .ptrVal term => do - -- Infers the type of the term, which must be a pointer, but returns `.u64`, as in a cast. - -- The term is not allowed to early return. - let (typ, inner) ← inferNoEscape term - match typ with - | .pointer _ => pure $ fieldTerm (.ptrVal ⟨typ, inner, false⟩) - | _ => throw $ .notAPointer typ - | .ann typ term => do - let inner ← checkNoEscape term typ - pure ⟨typ, inner, false⟩ - -- | .unsafeCast term castTyp => do - -- let (typ, inner) ← inferNoEscape term - -- pure ⟨typ, .unsafeCast inner castTyp, false⟩ - | .assertEq a b ret => do - -- `a` and `b` must have the same type. - let (typ, a) ← inferNoEscape a - let b ← checkNoEscape b typ - let ret ← inferTerm ret - let assertEq := .assertEq ⟨typ, a, false⟩ ⟨typ, b, false⟩ ret - assert! (ret.escapes == false) - pure ⟨ret.typ, assertEq, false⟩ - | .ioGetInfo key => do - let (typ, keyInner) ← inferNoEscape key - match typ with - | .array .. => - let ioGetInfo := .ioGetInfo ⟨typ, keyInner, false⟩ - pure ⟨.tuple #[.field, .field], ioGetInfo, false⟩ - | _ => throw $ .notAnArray typ - | .ioSetInfo key idx len ret => do - let (keyTyp, keyInner) ← inferNoEscape key - match keyTyp with - | .array keyEltTyp _ => - if keyEltTyp != .field then throw $ .typeMismatch .field keyEltTyp - let idx ← fieldTerm <$> checkNoEscape idx .field - let len ← fieldTerm <$> checkNoEscape len .field - let ret ← inferTerm ret - let ioSetInfo := .ioSetInfo ⟨keyTyp, keyInner, false⟩ idx len ret - assert! (ret.escapes == false) - pure ⟨ret.typ, ioSetInfo, false⟩ - | _ => throw $ .notAnArray keyTyp - | .ioRead idx len => do - if len = 0 then throw .emptyArray + +partial def inferSet (arr : Term) (i : Nat) (val : Term) : CheckM TypedTerm := do + let (typ, n, inner) ← inferArray arr + if i ≥ n then + throw $ .indexOoB i + else + let val ← checkNoEscape val typ + let arrTyp := .array typ n + let arr := ⟨arrTyp, inner, false⟩ + let val := ⟨typ, val, false⟩ + pure ⟨arrTyp, .set arr i val, false⟩ + +partial def inferStore (term : Term) : CheckM TypedTerm := do + let (typ, inner) ← inferNoEscape term + let store := .store ⟨typ, inner, false⟩ + pure ⟨.pointer typ, store, false⟩ + +partial def inferLoad (term : Term) : CheckM TypedTerm := do + let (typ, inner) ← inferNoEscape term + match typ with + | .pointer innerTyp => + let load := .load ⟨typ, inner, false⟩ + pure ⟨innerTyp, load, false⟩ + | _ => throw $ .notAPointer typ + +partial def inferPtrVal (term : Term) : CheckM TypedTerm := do + let (typ, inner) ← inferNoEscape term + match typ with + | .pointer _ => pure $ fieldTerm (.ptrVal ⟨typ, inner, false⟩) + | _ => throw $ .notAPointer typ + +partial def inferIoGetInfo (key : Term) : CheckM TypedTerm := do + let (typ, keyInner) ← inferNoEscape key + match typ with + | .array .. => + let ioGetInfo := .ioGetInfo ⟨typ, keyInner, false⟩ + pure ⟨.tuple #[.field, .field], ioGetInfo, false⟩ + | _ => throw $ .notAnArray typ + +partial def inferIoSetInfo (key idx len ret : Term) : CheckM TypedTerm := do + let (keyTyp, keyInner) ← inferNoEscape key + match keyTyp with + | .array keyEltTyp _ => + if keyEltTyp != .field then throw $ .typeMismatch .field keyEltTyp let idx ← fieldTerm <$> checkNoEscape idx .field - let ioRead := .ioRead idx len - pure ⟨.array .field len, ioRead, false⟩ - | .ioWrite data ret => do - let (dataTyp, dataInner) ← inferNoEscape data - match dataTyp with - | .array dataEltTyp _ => - if dataEltTyp != .field then throw $ .typeMismatch .field dataEltTyp - let ret ← inferTerm ret - let ioWrite := .ioWrite ⟨dataTyp, dataInner, false⟩ ret - assert! (ret.escapes == false) - pure ⟨ret.typ, ioWrite, false⟩ - | _ => throw $ .notAnArray dataTyp - | .u8BitDecomposition byte => do - let byte ← fieldTerm <$> checkNoEscape byte .field - let u8BitDecomposition := .u8BitDecomposition byte - pure ⟨.array .field 8, u8BitDecomposition, false⟩ - | .u8ShiftLeft byte => do - let byte ← fieldTerm <$> checkNoEscape byte .field - let u8ShiftLeft := .u8ShiftLeft byte - pure ⟨.field, u8ShiftLeft, false⟩ - | .u8ShiftRight byte => do - let byte ← fieldTerm <$> checkNoEscape byte .field - let u8ShiftRight := .u8ShiftRight byte - pure ⟨.field, u8ShiftRight, false⟩ - | .u8Xor i j => do - let i ← fieldTerm <$> checkNoEscape i .field - let j ← fieldTerm <$> checkNoEscape j .field - let u8Xor := .u8Xor i j - pure ⟨.field, u8Xor, false⟩ - | .u8Add i j => do - let i ← fieldTerm <$> checkNoEscape i .field - let j ← fieldTerm <$> checkNoEscape j .field - let u8Add := .u8Add i j - pure ⟨.tuple #[.field, .field], u8Add, false⟩ - | .debug label term ret => do - let term ← term.mapM inferTerm + let len ← fieldTerm <$> checkNoEscape len .field let ret ← inferTerm ret - let debug := .debug label term ret - assert! (ret.escapes == false) - pure ⟨ret.typ, debug, false⟩ + let ioSetInfo := .ioSetInfo ⟨keyTyp, keyInner, false⟩ idx len ret + pure ⟨ret.typ, ioSetInfo, ret.escapes⟩ + | _ => throw $ .notAnArray keyTyp + +partial def inferIoRead (idx : Term) (len : Nat) : CheckM TypedTerm := do + if len = 0 then throw .emptyArray + let idx ← fieldTerm <$> checkNoEscape idx .field + let ioRead := .ioRead idx len + pure ⟨.array .field len, ioRead, false⟩ + +partial def inferIoWrite (data ret : Term) : CheckM TypedTerm := do + let (dataTyp, dataInner) ← inferNoEscape data + match dataTyp with + | .array dataEltTyp _ => + if dataEltTyp != .field then throw $ .typeMismatch .field dataEltTyp + let ret ← inferTerm ret + let ioWrite := .ioWrite ⟨dataTyp, dataInner, false⟩ ret + pure ⟨ret.typ, ioWrite, ret.escapes⟩ + | _ => throw $ .notAnArray dataTyp + +partial def inferVariable (x : Local) : CheckM TypedTerm := do + let ctx ← read + match ctx.varTypes[x]? with + | some t => pure ⟨t, .var x, false⟩ + | none => + let Local.str localName := x | unreachable! + let typ ← refLookup (Global.init localName) + pure ⟨typ, .var x, false⟩ + +partial def inferReturn (term : Term) : CheckM TypedTerm := do + let ctx ← read + let inner ← checkNoEscape term ctx.returnType + pure ⟨ctx.returnType, inner, true⟩ + +partial def inferLet (pat : Pattern) (expr : Term) (body : Term) : CheckM TypedTerm := do + let (exprTyp, exprInner) ← inferNoEscape expr + let expr' := ⟨exprTyp, exprInner, false⟩ + let bindings ← checkPattern pat exprTyp + let body' ← withReader (bindIdents bindings) (inferTerm body) + pure ⟨body'.typ, .let pat expr' body', body'.escapes⟩ + +partial def inferUnqualifiedApp (func : Global) (unqualifiedFunc : String) (args : List Term) : CheckM TypedTerm := do + let ctx ← read + match ctx.varTypes[Local.str unqualifiedFunc]? with + | some (.function inputs output) => do + let args ← checkArgsAndInputs func args inputs + pure ⟨output, .app func args, false⟩ + | some _ => throw $ .notAFunction func + | none => match ctx.decls.getByKey func with + | some (.function function) => do + let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) + pure ⟨function.output, .app func args, false⟩ + | some (.constructor dataType constr) => do + let args ← checkArgsAndInputs func args constr.argTypes + pure ⟨.dataType dataType.name, .app func args, false⟩ + | _ => throw $ .cannotApply func where - /-- - Ensures that there are as many arguments and as expected types and that - the types of the arguments are precisely those expected. - -/ checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do let lenArgs := args.length let lenInputs := inputs.length @@ -295,29 +281,52 @@ where let inner ← checkNoEscape arg input pure ⟨input, inner, false⟩ args.zip inputs |>.mapM pass - checkArith a b := do - let aInner ← checkNoEscape a .field - let bInner ← checkNoEscape b .field - pure (fieldTerm aInner, fieldTerm bInner) - fieldTerm inner := ⟨.field, inner, false⟩ -partial def checkNoEscape (term : Term) (typ : Typ) : CheckM TypedTermInner := do - let (typ', inner) ← inferNoEscape term - unless typ == typ' do throw $ .typeMismatch typ typ' - pure inner +partial def inferQualifiedApp (func : Global) (args : List Term) : CheckM TypedTerm := do + let ctx ← read + match ctx.decls.getByKey func with + | some (.function function) => + let args ← checkArgsAndInputs func args (function.inputs.map Prod.snd) + pure ⟨function.output, .app func args, false⟩ + | some (.constructor dataType constr) => + let args ← checkArgsAndInputs func args constr.argTypes + pure ⟨.dataType dataType.name, .app func args, false⟩ + | _ => throw $ .cannotApply func +where + checkArgsAndInputs func args inputs : CheckM (List TypedTerm) := do + let lenArgs := args.length + let lenInputs := inputs.length + unless lenArgs == lenInputs do throw $ .wrongNumArgs func lenArgs lenInputs + let pass := fun (arg, input) => do + let inner ← checkNoEscape arg input + pure ⟨input, inner, false⟩ + args.zip inputs |>.mapM pass -partial def inferNoEscape (term : Term) : CheckM (Typ × TypedTermInner) := do - let typedTerm ← inferTerm term - if typedTerm.escapes then throw .illegalReturn - pure (typedTerm.typ, typedTerm.inner) +partial def inferApplication (func : Global) (args : List Term) : CheckM TypedTerm := + match func.toName with + | .str .anonymous unqualifiedFunc => inferUnqualifiedApp func unqualifiedFunc args + | _ => inferQualifiedApp func args + +partial def inferAssertEq (a b ret : Term) : CheckM TypedTerm := do + let (typ, a) ← inferNoEscape a + let b ← checkNoEscape b typ + let ret ← inferTerm ret + let assertEq := .assertEq ⟨typ, a, false⟩ ⟨typ, b, false⟩ ret + pure ⟨ret.typ, assertEq, ret.escapes⟩ + +partial def inferDebug (label : String) (term : Option Term) (ret : Term) : CheckM TypedTerm := do + let term ← term.mapM inferTerm + let ret ← inferTerm ret + let debug := .debug label term ret + pure ⟨ret.typ, debug, ret.escapes⟩ -partial def inferData : Data → CheckM (Typ × TypedTermInner) - | .field g => pure (.field, .data (.field g)) +partial def inferData : Data → CheckM TypedTerm + | .field g => pure ⟨.field, .data (.field g), false⟩ | .tuple terms => do let typsAndInners ← terms.mapM inferNoEscape let typs := typsAndInners.map Prod.fst let terms := typsAndInners.map fun (typ, inner) => ⟨typ, inner, false⟩ - pure (.tuple typs, .data (.tuple terms)) + pure ⟨.tuple typs, .data (.tuple terms), false⟩ | .array terms => do if h : terms.size > 0 then let (typ, firstInner) ← inferNoEscape terms[0] @@ -326,7 +335,7 @@ partial def inferData : Data → CheckM (Typ × TypedTermInner) for term in terms[1:] do let inner ← checkNoEscape term typ typedTerms := typedTerms.push ⟨typ, inner, false⟩ - pure (.array typ terms.size, .data (.array typedTerms)) + pure ⟨.array typ terms.size, .data (.array typedTerms), false⟩ else throw .emptyArray /-- Infers the type of a 'match' expression and ensures its patterns and branches are valid. -/ From 24d083949bead96703cacd4cf29170ddd5288003 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Tue, 13 Jan 2026 13:44:35 -0300 Subject: [PATCH 3/5] And and Or U8 operations --- Ix/Aiur/Bytecode.lean | 2 + Ix/Aiur/Check.lean | 2 + Ix/Aiur/Compile.lean | 10 +++- Ix/Aiur/Meta.lean | 14 ++++++ Ix/Aiur/Term.lean | 4 ++ Tests/Aiur.lean | 10 ++++ src/aiur/bytecode.rs | 2 + src/aiur/constraints.rs | 21 ++++++++- src/aiur/execute.rs | 14 ++++++ src/aiur/gadgets/bytes2.rs | 87 ++++++++++++++++++++++++++++++++--- src/aiur/mod.rs | 10 ++++ src/aiur/trace.rs | 23 ++++++++- src/lean/ffi/aiur/toplevel.rs | 8 ++++ 13 files changed, 195 insertions(+), 12 deletions(-) diff --git a/Ix/Aiur/Bytecode.lean b/Ix/Aiur/Bytecode.lean index 8051a9e7..69abff49 100644 --- a/Ix/Aiur/Bytecode.lean +++ b/Ix/Aiur/Bytecode.lean @@ -27,6 +27,8 @@ inductive Op | u8ShiftRight : ValIdx → Op | u8Xor : ValIdx → ValIdx → Op | u8Add : ValIdx → ValIdx → Op + | u8And : ValIdx → ValIdx → Op + | u8Or : ValIdx → ValIdx → Op | debug : String → Option (Array ValIdx) → Op deriving Repr diff --git a/Ix/Aiur/Check.lean b/Ix/Aiur/Check.lean index 6be432fa..83e33a51 100644 --- a/Ix/Aiur/Check.lean +++ b/Ix/Aiur/Check.lean @@ -108,6 +108,8 @@ partial def inferTerm (t : Term) : CheckM TypedTerm := match t with | .u8ShiftRight a => inferUnop a .u8ShiftRight .field | .u8BitDecomposition a => inferUnop a .u8BitDecomposition (.array .field 8) | .u8Xor a b => inferBinop a b .u8Xor .field + | .u8And a b => inferBinop a b .u8And .field + | .u8Or a b => inferBinop a b .u8Or .field | .u8Add a b => inferBinop a b .u8Add (.tuple #[.field, .field]) | .ioGetInfo key => inferIoGetInfo key | .ioSetInfo key idx len ret => inferIoSetInfo key idx len ret diff --git a/Ix/Aiur/Compile.lean b/Ix/Aiur/Compile.lean index 1f8a3054..ffe6a1c4 100644 --- a/Ix/Aiur/Compile.lean +++ b/Ix/Aiur/Compile.lean @@ -120,7 +120,7 @@ def opLayout : Bytecode.Op → LayoutM Unit pushDegrees $ .replicate 8 1 bumpAuxiliaries 8 bumpLookups - | .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. => do + | .u8ShiftLeft _ | .u8ShiftRight _ | .u8Xor .. | .u8And .. | .u8Or .. => do pushDegree 1 bumpAuxiliaries 1 bumpLookups @@ -421,6 +421,14 @@ partial def toIndex let i ← expectIdx i let j ← expectIdx j pushOp (.u8Add i j) 2 + | .u8And i j => do + let i ← expectIdx i + let j ← expectIdx j + pushOp (.u8And i j) + | .u8Or i j => do + let i ← expectIdx i + let j ← expectIdx j + pushOp (.u8Or i j) | .debug label term ret => do let term ← term.mapM (toIndex layoutMap bindings) modify fun stt => { stt with ops := stt.ops.push (.debug label term) } diff --git a/Ix/Aiur/Meta.lean b/Ix/Aiur/Meta.lean index 84c77a1c..48ad07c5 100644 --- a/Ix/Aiur/Meta.lean +++ b/Ix/Aiur/Meta.lean @@ -123,6 +123,8 @@ syntax "u8_shift_left" "(" trm ")" : trm syntax "u8_shift_right" "(" trm ")" : trm syntax "u8_xor" "(" trm ", " trm ")" : trm syntax "u8_add" "(" trm ", " trm ")" : trm +syntax "u8_and" "(" trm ", " trm ")" : trm +syntax "u8_or" "(" trm ", " trm ")" : trm syntax "dbg!" "(" str (", " trm)? ")" ";" (trm)? : trm syntax trm "[" "@" noWs ident "]" : trm @@ -219,6 +221,10 @@ partial def elabTrm : ElabStxCat `trm mkAppM ``Term.u8Xor #[← elabTrm i, ← elabTrm j] | `(trm| u8_add($i:trm, $j:trm)) => do mkAppM ``Term.u8Add #[← elabTrm i, ← elabTrm j] + | `(trm| u8_and($i:trm, $j:trm)) => do + mkAppM ``Term.u8And #[← elabTrm i, ← elabTrm j] + | `(trm| u8_or($i:trm, $j:trm)) => do + mkAppM ``Term.u8Or #[← elabTrm i, ← elabTrm j] | `(trm| dbg!($label:str $[, $t:trm]?); $[$ret:trm]?) => do let t ← match t with | none => mkAppOptM ``Option.none #[some (mkConst ``Term)] @@ -360,6 +366,14 @@ where let i ← replaceToken old new i let j ← replaceToken old new j `(trm| u8_add($i, $j)) + | `(trm| u8_and($i:trm, $j:trm)) => do + let i ← replaceToken old new i + let j ← replaceToken old new j + `(trm| u8_and($i, $j)) + | `(trm| u8_or($i:trm, $j:trm)) => do + let i ← replaceToken old new i + let j ← replaceToken old new j + `(trm| u8_or($i, $j)) | `(trm| dbg!($label:str $[, $t:trm]?); $[$ret:trm]?) => do let t' ← t.mapM $ replaceToken old new let ret' ← ret.mapM $ replaceToken old new diff --git a/Ix/Aiur/Term.lean b/Ix/Aiur/Term.lean index ac5670b1..83302c79 100644 --- a/Ix/Aiur/Term.lean +++ b/Ix/Aiur/Term.lean @@ -93,6 +93,8 @@ inductive Term | u8ShiftRight : Term → Term | u8Xor : Term → Term → Term | u8Add : Term → Term → Term + | u8And : Term → Term → Term + | u8Or : Term → Term → Term | debug : String → Option Term → Term → Term deriving Repr, BEq, Hashable, Inhabited @@ -136,6 +138,8 @@ inductive TypedTermInner | u8ShiftRight : TypedTerm → TypedTermInner | u8Xor : TypedTerm → TypedTerm → TypedTermInner | u8Add : TypedTerm → TypedTerm → TypedTermInner + | u8And : TypedTerm → TypedTerm → TypedTermInner + | u8Or : TypedTerm → TypedTerm → TypedTermInner | debug : String → Option TypedTerm → TypedTerm → TypedTermInner deriving Repr, Inhabited diff --git a/Tests/Aiur.lean b/Tests/Aiur.lean index a8b647ba..7a1aa741 100644 --- a/Tests/Aiur.lean +++ b/Tests/Aiur.lean @@ -199,6 +199,14 @@ def toplevel := ⟦ (u8_add(i_xor_j, i), u8_add(i_xor_j, j)) } + fn u8_add_and(i: G, j: G) -> G { + u8_and(i, j) + } + + fn u8_add_or(i: G, j: G) -> G { + u8_or(i, j) + } + fn fold_matrix_sum(m: [[G; 2]; 2]) -> G { fold(0 .. 2, 0, |acc_outer, @i| fold(0 .. 2, acc_outer, |acc_inner, @j| @@ -245,6 +253,8 @@ def aiurTestCases : List AiurTestCase := [ ⟨#[1, 2, 3, 4, 1, 2, 3, 4], .ofList [(#[0], ⟨0, 4⟩), (#[1], ⟨0, 8⟩)]⟩⟩, .noIO `shr_shr_shl_decompose #[87] #[0, 1, 0, 1, 0, 1, 0, 0], .noIO `u8_add_xor #[45, 131] #[219, 0, 49, 1], + .noIO `u8_add_and #[45, 131] #[1], + .noIO `u8_add_or #[45, 131] #[175], .noIO `fold_matrix_sum #[1, 2, 3, 4] #[10], ] diff --git a/src/aiur/bytecode.rs b/src/aiur/bytecode.rs index 51c3f8cb..43331947 100644 --- a/src/aiur/bytecode.rs +++ b/src/aiur/bytecode.rs @@ -53,6 +53,8 @@ pub enum Op { U8ShiftRight(ValIdx), U8Xor(ValIdx, ValIdx), U8Add(ValIdx, ValIdx), + U8And(ValIdx, ValIdx), + U8Or(ValIdx, ValIdx), Debug(String, Option>), } diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index 48993716..d15ae933 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -16,8 +16,9 @@ use crate::aiur::{ bytes1::{Bytes1, Bytes1Op}, bytes2::{Bytes2, Bytes2Op}, }, - memory_channel, u8_add_channel, u8_bit_decomposition_channel, - u8_shift_left_channel, u8_shift_right_channel, u8_xor_channel, + memory_channel, u8_add_channel, u8_and_channel, + u8_bit_decomposition_channel, u8_or_channel, u8_shift_left_channel, + u8_shift_right_channel, u8_xor_channel, }; type Expr = SymbolicExpression; @@ -415,6 +416,22 @@ impl Op { sel.clone(), state, ), + Op::U8And(i, j) => bytes2_constraints( + *i, + *j, + &Bytes2Op::And, + u8_and_channel(), + sel.clone(), + state, + ), + Op::U8Or(i, j) => bytes2_constraints( + *i, + *j, + &Bytes2Op::Or, + u8_or_channel(), + sel.clone(), + state, + ), Op::IOSetInfo(..) | Op::IOWrite(_) | Op::Debug(..) => (), } } diff --git a/src/aiur/execute.rs b/src/aiur/execute.rs index 6d25c99d..8030f5b0 100644 --- a/src/aiur/execute.rs +++ b/src/aiur/execute.rs @@ -271,6 +271,20 @@ impl Function { bytes2_execute(*i, *j, &Bytes2Op::Add, &mut map, record); } }, + ExecEntry::Op(Op::U8And(i, j)) => { + if unconstrained { + map.push(Bytes2::and(&map[*i], &map[*j])); + } else { + bytes2_execute(*i, *j, &Bytes2Op::And, &mut map, record); + } + }, + ExecEntry::Op(Op::U8Or(i, j)) => { + if unconstrained { + map.push(Bytes2::or(&map[*i], &map[*j])); + } else { + bytes2_execute(*i, *j, &Bytes2Op::Or, &mut map, record); + } + }, ExecEntry::Op(Op::Debug(label, idxs)) => match idxs { None => println!("{label}"), Some(idxs) => { diff --git a/src/aiur/gadgets/bytes2.rs b/src/aiur/gadgets/bytes2.rs index fc7e8044..43dbca1c 100644 --- a/src/aiur/gadgets/bytes2.rs +++ b/src/aiur/gadgets/bytes2.rs @@ -7,13 +7,16 @@ use multi_stark::{ }; use crate::aiur::{ - G, execute::QueryRecord, gadgets::AiurGadget, u8_add_channel, u8_xor_channel, + G, execute::QueryRecord, gadgets::AiurGadget, u8_add_channel, u8_and_channel, + u8_or_channel, u8_xor_channel, }; /// Number of columns in the trace with multiplicities for /// - xor /// - overflowing add -const TRACE_WIDTH: usize = 2; +/// - and +/// - or +const TRACE_WIDTH: usize = 4; /// Number of columns in the preprocessed trace: /// - first raw byte value @@ -21,7 +24,9 @@ const TRACE_WIDTH: usize = 2; /// - xor result /// - add result /// - add overflow -const PREPROCESSED_TRACE_WIDTH: usize = 5; +/// - and result +/// - or result +const PREPROCESSED_TRACE_WIDTH: usize = 7; /// AIR implementer for arity 2 byte-related lookups. pub(crate) struct Bytes2; @@ -29,6 +34,8 @@ pub(crate) struct Bytes2; pub(crate) enum Bytes2Op { Xor, Add, + And, + Or, } impl BaseAir for Bytes2 { @@ -53,6 +60,12 @@ impl BaseAir for Bytes2 { let (r, o) = i.overflowing_add(j); trace_values.push(G::from_u8(r)); trace_values.push(G::from_bool(o)); + + // And + trace_values.push(G::from_u8(i & j)); + + // Or + trace_values.push(G::from_u8(i | j)); } } Some(RowMajorMatrix::new(trace_values, PREPROCESSED_TRACE_WIDTH)) @@ -71,6 +84,8 @@ impl AiurGadget for Bytes2 { match op { Bytes2Op::Xor => 1, Bytes2Op::Add => 2, + Bytes2Op::And => 1, + Bytes2Op::Or => 1, } } @@ -92,6 +107,14 @@ impl AiurGadget for Bytes2 { let (r, o) = Self::add(i, j); vec![r, o] }, + Bytes2Op::And => { + record.bytes2_queries.bump_and(i, j); + vec![Self::and(i, j)] + }, + Bytes2Op::Or => { + record.bytes2_queries.bump_or(i, j); + vec![Self::or(i, j)] + }, } } @@ -99,10 +122,14 @@ impl AiurGadget for Bytes2 { // Channels let xor_channel = u8_xor_channel().into(); let add_channel = u8_add_channel().into(); + let and_channel = u8_and_channel().into(); + let or_channel = u8_or_channel().into(); // Multiplicity columns let xor_multiplicity = var(0); let add_multiplicity = var(1); + let and_multiplicity = var(2); + let or_multiplicity = var(3); // Preprocessed columns let i = preprocessed_var(0); @@ -110,16 +137,28 @@ impl AiurGadget for Bytes2 { let xor = preprocessed_var(2); let add_r = preprocessed_var(3); let add_o = preprocessed_var(4); + let and = preprocessed_var(5); + let or = preprocessed_var(6); let pull_xor = Lookup::pull( xor_multiplicity, vec![xor_channel, i.clone(), j.clone(), xor], ); - let pull_add = - Lookup::pull(add_multiplicity, vec![add_channel, i, j, add_r, add_o]); + let pull_add = Lookup::pull( + add_multiplicity, + vec![add_channel, i.clone(), j.clone(), add_r, add_o], + ); + + let pull_and = Lookup::pull( + and_multiplicity, + vec![and_channel, i.clone(), j.clone(), and], + ); + + let pull_or = + Lookup::pull(or_multiplicity, vec![or_channel, i, j, or]); - vec![pull_xor, pull_add] + vec![pull_xor, pull_add, pull_and, pull_or] } fn witness_data( @@ -133,18 +172,22 @@ impl AiurGadget for Bytes2 { let xor_channel = u8_xor_channel(); let add_channel = u8_add_channel(); + let and_channel = u8_and_channel(); + let or_channel = u8_or_channel(); rows .chunks_exact_mut(TRACE_WIDTH) .enumerate() .zip(&record.bytes2_queries.0) .zip(&mut lookups) - .for_each(|(((row_idx, row), &[xor, add]), row_lookups)| { + .for_each(|(((row_idx, row), &[xor, add, and, or]), row_lookups)| { let i = G::from_usize(row_idx / 256); let j = G::from_usize(row_idx % 256); row[0] = xor; row[1] = add; + row[2] = and; + row[3] = or; // Pull xor. row_lookups[0] = @@ -153,6 +196,14 @@ impl AiurGadget for Bytes2 { // Pull add. let (r, o) = Self::add(&i, &j); row_lookups[1] = Lookup::pull(add, vec![add_channel, i, j, r, o]); + + // Pull and. + row_lookups[2] = + Lookup::pull(and, vec![and_channel, i, j, Self::and(&i, &j)]); + + // Pull or. + row_lookups[3] = + Lookup::pull(or, vec![or_channel, i, j, Self::or(&i, &j)]); }); (RowMajorMatrix::new(rows, TRACE_WIDTH), lookups) } @@ -175,6 +226,14 @@ impl Bytes2Queries { self.bump_multiplicity_for(i, j, 1) } + fn bump_and(&mut self, i: &G, j: &G) { + self.bump_multiplicity_for(i, j, 2) + } + + fn bump_or(&mut self, i: &G, j: &G) { + self.bump_multiplicity_for(i, j, 3) + } + fn bump_multiplicity_for(&mut self, i: &G, j: &G, col: usize) { let i = usize::try_from(i.as_canonical_u64()).unwrap(); let j = usize::try_from(j.as_canonical_u64()).unwrap(); @@ -198,4 +257,18 @@ impl Bytes2 { let (r, o) = i.overflowing_add(j); (G::from_u8(r), G::from_bool(o)) } + + #[inline] + pub(crate) fn and(i: &G, j: &G) -> G { + let i: u8 = i.as_canonical_u64().try_into().unwrap(); + let j: u8 = j.as_canonical_u64().try_into().unwrap(); + G::from_u8(i & j) + } + + #[inline] + pub(crate) fn or(i: &G, j: &G) -> G { + let i: u8 = i.as_canonical_u64().try_into().unwrap(); + let j: u8 = j.as_canonical_u64().try_into().unwrap(); + G::from_u8(i | j) + } } diff --git a/src/aiur/mod.rs b/src/aiur/mod.rs index df09e145..26e49bb9 100644 --- a/src/aiur/mod.rs +++ b/src/aiur/mod.rs @@ -44,3 +44,13 @@ pub fn u8_xor_channel() -> G { pub fn u8_add_channel() -> G { G::from_u8(6) } + +#[inline] +pub fn u8_and_channel() -> G { + G::from_u8(7) +} + +#[inline] +pub fn u8_or_channel() -> G { + G::from_u8(8) +} diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 4b0e17c7..48bb4c40 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -17,8 +17,9 @@ use crate::aiur::{ function_channel, gadgets::{bytes1::Bytes1, bytes2::Bytes2}, memory::Memory, - u8_add_channel, u8_bit_decomposition_channel, u8_shift_left_channel, - u8_shift_right_channel, u8_xor_channel, + u8_add_channel, u8_and_channel, u8_bit_decomposition_channel, + u8_or_channel, u8_shift_left_channel, u8_shift_right_channel, + u8_xor_channel, }; struct ColumnIndex { @@ -374,6 +375,24 @@ impl Op { let lookup_args = vec![u8_add_channel(), i, j, r, o]; slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); }, + Op::U8And(i, j) => { + let (i, _) = map[*i]; + let (j, _) = map[*j]; + let and = Bytes2::and(&i, &j); + map.push((and, 1)); + slice.push_auxiliary(index, and); + let lookup_args = vec![u8_and_channel(), i, j, and]; + slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); + }, + Op::U8Or(i, j) => { + let (i, _) = map[*i]; + let (j, _) = map[*j]; + let or = Bytes2::or(&i, &j); + map.push((or, 1)); + slice.push_auxiliary(index, or); + let lookup_args = vec![u8_or_channel(), i, j, or]; + slice.push_lookup(index, Lookup::push(G::ONE, lookup_args)); + }, Op::AssertEq(..) | Op::IOSetInfo(..) | Op::IOWrite(_) | Op::Debug(..) => { }, } diff --git a/src/lean/ffi/aiur/toplevel.rs b/src/lean/ffi/aiur/toplevel.rs index 1b4ebd84..8789665e 100644 --- a/src/lean/ffi/aiur/toplevel.rs +++ b/src/lean/ffi/aiur/toplevel.rs @@ -117,6 +117,14 @@ fn lean_ptr_to_op(ptr: *const c_void) -> Op { Op::U8Add(i, j) }, 18 => { + let [i, j] = ctor.objs().map(lean_unbox_nat_as_usize); + Op::U8And(i, j) + }, + 19 => { + let [i, j] = ctor.objs().map(lean_unbox_nat_as_usize); + Op::U8Or(i, j) + }, + 20 => { let [label_ptr, idxs_ptr] = ctor.objs(); let label_str: &LeanStringObject = as_ref_unsafe(label_ptr.cast()); let label = label_str.as_string(); From 7042fae35e6356a5c1fffbdaaf520a2612197f4e Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Thu, 15 Jan 2026 10:17:30 -0300 Subject: [PATCH 4/5] fmt and clippy --- src/aiur/constraints.rs | 5 ++--- src/aiur/gadgets/bytes2.rs | 7 ++----- src/aiur/trace.rs | 5 ++--- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/aiur/constraints.rs b/src/aiur/constraints.rs index d15ae933..9c1252dd 100644 --- a/src/aiur/constraints.rs +++ b/src/aiur/constraints.rs @@ -16,9 +16,8 @@ use crate::aiur::{ bytes1::{Bytes1, Bytes1Op}, bytes2::{Bytes2, Bytes2Op}, }, - memory_channel, u8_add_channel, u8_and_channel, - u8_bit_decomposition_channel, u8_or_channel, u8_shift_left_channel, - u8_shift_right_channel, u8_xor_channel, + memory_channel, u8_add_channel, u8_and_channel, u8_bit_decomposition_channel, + u8_or_channel, u8_shift_left_channel, u8_shift_right_channel, u8_xor_channel, }; type Expr = SymbolicExpression; diff --git a/src/aiur/gadgets/bytes2.rs b/src/aiur/gadgets/bytes2.rs index 43dbca1c..f928af62 100644 --- a/src/aiur/gadgets/bytes2.rs +++ b/src/aiur/gadgets/bytes2.rs @@ -82,10 +82,8 @@ impl AiurGadget for Bytes2 { fn output_size(&self, op: &Bytes2Op) -> usize { match op { - Bytes2Op::Xor => 1, + Bytes2Op::Xor | Bytes2Op::And | Bytes2Op::Or => 1, Bytes2Op::Add => 2, - Bytes2Op::And => 1, - Bytes2Op::Or => 1, } } @@ -155,8 +153,7 @@ impl AiurGadget for Bytes2 { vec![and_channel, i.clone(), j.clone(), and], ); - let pull_or = - Lookup::pull(or_multiplicity, vec![or_channel, i, j, or]); + let pull_or = Lookup::pull(or_multiplicity, vec![or_channel, i, j, or]); vec![pull_xor, pull_add, pull_and, pull_or] } diff --git a/src/aiur/trace.rs b/src/aiur/trace.rs index 48bb4c40..a87f4c5f 100644 --- a/src/aiur/trace.rs +++ b/src/aiur/trace.rs @@ -17,9 +17,8 @@ use crate::aiur::{ function_channel, gadgets::{bytes1::Bytes1, bytes2::Bytes2}, memory::Memory, - u8_add_channel, u8_and_channel, u8_bit_decomposition_channel, - u8_or_channel, u8_shift_left_channel, u8_shift_right_channel, - u8_xor_channel, + u8_add_channel, u8_and_channel, u8_bit_decomposition_channel, u8_or_channel, + u8_shift_left_channel, u8_shift_right_channel, u8_xor_channel, }; struct ColumnIndex { From 98dc1238b20fd5d27ce9f36302261e32dac9913a Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Thu, 15 Jan 2026 12:58:50 -0300 Subject: [PATCH 5/5] bincode unmaintained --- deny.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deny.toml b/deny.toml index 1cb67022..6519dae3 100644 --- a/deny.toml +++ b/deny.toml @@ -74,6 +74,8 @@ ignore = [ "RUSTSEC-2024-0384", # `instant` crate is unmaintained "RUSTSEC-2024-0370", # `proc-macro-error` crate is unmaintained "RUSTSEC-2023-0089", # `atomic-polyfill` crate is unmaintained + "RUSTSEC-2025-0141", # `bincode` crate is unmaintained + #{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" }, #"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish #{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" },