Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 82 additions & 86 deletions crates/lean_compiler/src/a_simplify_lang.rs

Large diffs are not rendered by default.

34 changes: 19 additions & 15 deletions crates/lean_compiler/src/b_compile_intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,33 +63,37 @@ impl Compiler {
impl SimpleExpr {
fn to_mem_after_fp_or_constant(&self, compiler: &Compiler) -> IntermediateValue {
match self {
Self::Var(var) => IntermediateValue::MemoryAfterFp {
Self::Memory(VarOrConstMallocAccess::Var(var)) => IntermediateValue::MemoryAfterFp {
offset: compiler.get_offset(&var.clone().into()),
},
Self::Memory(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) => {
IntermediateValue::MemoryAfterFp {
offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *malloc_label,
offset: offset.clone(),
}),
}
}
Self::Constant(c) => IntermediateValue::Constant(c.clone()),
Self::ConstMallocAccess { malloc_label, offset } => IntermediateValue::MemoryAfterFp {
offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *malloc_label,
offset: offset.clone(),
}),
},
}
}
}

impl IntermediateValue {
fn from_simple_expr(expr: &SimpleExpr, compiler: &Compiler) -> Self {
match expr {
SimpleExpr::Var(var) => Self::MemoryAfterFp {
SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => Self::MemoryAfterFp {
offset: compiler.get_offset(&var.clone().into()),
},
SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset }) => {
Self::MemoryAfterFp {
offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *malloc_label,
offset: offset.clone(),
}),
}
}
SimpleExpr::Constant(c) => Self::Constant(c.clone()),
SimpleExpr::ConstMallocAccess { malloc_label, offset } => Self::MemoryAfterFp {
offset: compiler.get_offset(&VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *malloc_label,
offset: offset.clone(),
}),
},
}
}

Expand Down Expand Up @@ -373,7 +377,7 @@ fn compile_lines(
}

SimpleLine::RawAccess { res, index, shift } => {
if let SimpleExpr::Var(var) = res
if let SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) = res
&& !compiler.is_in_scope(var)
{
let current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap();
Expand Down
35 changes: 9 additions & 26 deletions crates/lean_compiler/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Display, Formatter};
use utils::ToUsize;

use crate::a_simplify_lang::VarOrConstMallocAccess;
use crate::{F, parser::ConstArrayValue};
pub use lean_vm::{FileId, FunctionName, SourceLocation};

Expand Down Expand Up @@ -39,12 +40,8 @@ pub type ConstMallocLabel = usize;

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SimpleExpr {
Var(Var),
Memory(VarOrConstMallocAccess),
Constant(ConstExpression),
ConstMallocAccess {
malloc_label: ConstMallocLabel,
offset: ConstExpression,
},
}

impl SimpleExpr {
Expand All @@ -63,13 +60,6 @@ impl SimpleExpr {
pub const fn is_constant(&self) -> bool {
matches!(self, Self::Constant(_))
}

pub fn simplify_if_const(&self) -> Self {
if let Self::Constant(constant) = self {
return constant.clone().into();
}
self.clone()
}
}

impl From<ConstantValue> for SimpleExpr {
Expand All @@ -86,16 +76,15 @@ impl From<ConstExpression> for SimpleExpr {

impl From<Var> for SimpleExpr {
fn from(var: Var) -> Self {
Self::Var(var)
VarOrConstMallocAccess::Var(var).into()
}
}

impl SimpleExpr {
pub fn as_constant(&self) -> Option<ConstExpression> {
match self {
Self::Var(_) => None,
Self::Constant(constant) => Some(constant.clone()),
Self::ConstMallocAccess { .. } => None,
Self::Memory(_) => None,
}
}

Expand Down Expand Up @@ -226,7 +215,7 @@ impl Display for Condition {
pub enum Expression {
Value(SimpleExpr),
ArrayAccess {
array: SimpleExpr,
array: Var,
index: Vec<Self>, // multi-dimensional array access
},
MathExpr(MathOperation, Vec<Self>),
Expand Down Expand Up @@ -356,10 +345,7 @@ impl Expression {
self.eval_with(
&|value: &SimpleExpr| value.as_constant()?.naive_eval(),
&|arr, indexes| {
let SimpleExpr::Var(name) = arr else {
return None;
};
let array = const_arrays.get(name)?;
let array = const_arrays.get(arr)?;
assert_eq!(indexes.len(), array.depth());
let idx = indexes.iter().map(|e| e.to_usize()).collect::<Vec<_>>();
array.navigate(&idx)?.as_scalar().map(F::from_usize)
Expand All @@ -370,7 +356,7 @@ impl Expression {
pub fn eval_with<ValueFn, ArrayFn>(&self, value_fn: &ValueFn, array_fn: &ArrayFn) -> Option<F>
where
ValueFn: Fn(&SimpleExpr) -> Option<F>,
ArrayFn: Fn(&SimpleExpr, Vec<F>) -> Option<F>,
ArrayFn: Fn(&Var, Vec<F>) -> Option<F>,
{
match self {
Self::Value(value) => value_fn(value),
Expand Down Expand Up @@ -405,7 +391,7 @@ impl Expression {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AssignmentTarget {
Var(Var),
ArrayAccess { array: SimpleExpr, index: Box<Expression> },
ArrayAccess { array: Var, index: Box<Expression> },
}

impl Display for AssignmentTarget {
Expand Down Expand Up @@ -713,11 +699,8 @@ impl Display for ConstantValue {
impl Display for SimpleExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Var(var) => write!(f, "{var}"),
Self::Constant(constant) => write!(f, "{constant}"),
Self::ConstMallocAccess { malloc_label, offset } => {
write!(f, "malloc_access({malloc_label}, {offset})")
}
Self::Memory(var_or_const_malloc_access) => write!(f, "{var_or_const_malloc_access}"),
}
}
}
Expand Down
7 changes: 2 additions & 5 deletions crates/lean_compiler/src/parser/parsers/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,13 @@ pub struct ArrayAccessParser;
impl Parse<Expression> for ArrayAccessParser {
fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult<Expression> {
let mut inner = pair.into_inner();
let array_name = next_inner_pair(&mut inner, "array name")?.as_str().to_string();
let array = next_inner_pair(&mut inner, "array name")?.as_str().to_string();

let index: Vec<Expression> = inner
.map(|idx_pair| ExpressionParser.parse(idx_pair, ctx))
.collect::<Result<Vec<_>, _>>()?;

Ok(Expression::ArrayAccess {
array: SimpleExpr::Var(array_name),
index,
})
Ok(Expression::ArrayAccess { array, index })
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/lean_compiler/src/parser/parsers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::statement::StatementParser;
use super::{Parse, ParseContext, next_inner_pair};
use crate::{
SourceLineNumber,
lang::{AssignmentTarget, Expression, Function, Line, SimpleExpr, SourceLocation},
lang::{AssignmentTarget, Expression, Function, Line, SourceLocation},
parser::{
error::{ParseResult, SemanticError},
grammar::{ParsePair, Rule},
Expand Down Expand Up @@ -129,7 +129,7 @@ impl Parse<AssignmentTarget> for AssignmentTargetParser {
let array = next_inner_pair(&mut inner_pairs, "array name")?.as_str().to_string();
let index = ExpressionParser.parse(next_inner_pair(&mut inner_pairs, "array index")?, ctx)?;
Ok(AssignmentTarget::ArrayAccess {
array: SimpleExpr::Var(array),
array,
index: Box::new(index),
})
}
Expand Down
12 changes: 5 additions & 7 deletions crates/lean_compiler/src/parser/parsers/literal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::expression::ExpressionParser;
use super::{ConstArrayValue, Parse, ParseContext, ParsedConstant, next_inner_pair};
use crate::a_simplify_lang::VarOrConstMallocAccess;
use crate::{
F,
lang::{ConstExpression, ConstantValue, SimpleExpr},
Expand Down Expand Up @@ -90,16 +91,13 @@ pub fn evaluate_const_expr(expr: &crate::lang::Expression, ctx: &ParseContext) -
expr.eval_with(
&|simple_expr| match simple_expr {
SimpleExpr::Constant(cst) => cst.naive_eval(),
SimpleExpr::Var(var) => ctx.get_constant(var).map(F::from_usize),
SimpleExpr::ConstMallocAccess { .. } => None,
SimpleExpr::Memory(VarOrConstMallocAccess::Var(var)) => ctx.get_constant(var).map(F::from_usize),
SimpleExpr::Memory(VarOrConstMallocAccess::ConstMallocAccess { .. }) => None,
},
&|arr, index| {
// Support const array access in expressions
let SimpleExpr::Var(name) = arr else {
return None;
};
let idx = index.iter().map(|e| e.to_usize()).collect::<Vec<_>>();
let array = ctx.get_const_array(name)?;
let array = ctx.get_const_array(arr)?;
array.navigate(&idx)?.as_scalar().map(F::from_usize)
},
)
Expand Down Expand Up @@ -161,7 +159,7 @@ impl VarOrConstantParser {
}
// Otherwise treat as variable reference
else {
Ok(SimpleExpr::Var(text.to_string()))
Ok(VarOrConstMallocAccess::Var(text.to_string()).into())
}
}
}
Expand Down
30 changes: 30 additions & 0 deletions crates/lean_compiler/tests/test_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,36 @@ fn test_mini_program_4() {
dbg!(&poseidon24_permute(public_input)[16..]);
}

#[test]
fn test_mini_program_5() {
let program = r#"
fn main() {
arr = malloc(10);
arr[6] = 42;
arr[8] = 11;
sum_1 = func_1(arr[6], arr[8]);
assert sum_1 == 53;
return;
}

fn func_1(i, j) inline -> 1 {
for k in 0..i {
for u in 0..j {
assert k + u != 1000000;
}
}
return i + j;
}

"#;
compile_and_run(
&ProgramSource::Raw(program.to_string()),
(&[], &[]),
DEFAULT_NO_VEC_RUNTIME_MEMORY,
false,
);
}

#[test]
fn test_inlined() {
let program = r#"
Expand Down