Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 18 additions & 40 deletions crates/lean_compiler/src/a_simplify_lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::{
Counter, F,
ir::HighLevelOperation,
lang::{
AssignmentTarget, AssumeBoolean, Condition, ConstExpression, ConstMallocLabel, ConstantValue, Context,
Expression, Function, Line, Program, Scope, SimpleExpr, Var,
AssignmentTarget, Condition, ConstExpression, ConstMallocLabel, Context, Expression, Function, Line, MathExpr,
Program, Scope, SimpleExpr, Var,
},
parser::ConstArrayValue,
};
Expand Down Expand Up @@ -529,7 +529,7 @@ fn check_boolean_scoping(boolean: &BooleanExpr<Expression>, ctx: &Context) {

fn check_condition_scoping(condition: &Condition, ctx: &Context) {
match condition {
Condition::Expression(expr, _) => {
Condition::AssumeBoolean(expr) => {
check_expr_scoping(expr, ctx);
}
Condition::Comparison(boolean) => {
Expand Down Expand Up @@ -723,11 +723,10 @@ fn simplify_lines(
if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) =
(&left, &right)
{
let result = ConstExpression::Binary {
left: Box::new(left_cst.clone()),
operation: *operation,
right: Box::new(right_cst.clone()),
}
let result = ConstExpression::MathExpr(
MathExpr::Binary(*operation),
vec![left_cst.clone(), right_cst.clone()],
)
.try_naive_simplification();
res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result)));
} else {
Expand Down Expand Up @@ -855,28 +854,8 @@ fn simplify_lines(
});
(diff_var.into(), then_branch, else_branch)
}
Condition::Expression(condition, assume_boolean) => {
Condition::AssumeBoolean(condition) => {
let condition_simplified = simplify_expr(ctx, state, const_malloc, condition, &mut res);

match assume_boolean {
AssumeBoolean::AssumeBoolean => {}
AssumeBoolean::DoNotAssumeBoolean => {
// Check condition_simplified is boolean
let one_minus_condition_var = state.counters.aux_var();
res.push(SimpleLine::Assignment {
var: one_minus_condition_var.clone().into(),
operation: HighLevelOperation::Sub,
arg0: SimpleExpr::Constant(ConstExpression::Value(ConstantValue::Scalar(1))),
arg1: condition_simplified.clone(),
});
res.push(SimpleLine::AssertZero {
operation: HighLevelOperation::Mul,
arg0: condition_simplified.clone(),
arg1: one_minus_condition_var.into(),
});
}
}

(condition_simplified, then_branch, else_branch)
}
};
Expand Down Expand Up @@ -1194,11 +1173,10 @@ fn simplify_expr(
let right_var = simplify_expr(ctx, state, const_malloc, right, lines);

if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = (&left_var, &right_var) {
return SimpleExpr::Constant(ConstExpression::Binary {
left: Box::new(left_cst.clone()),
operation: *operation,
right: Box::new(right_cst.clone()),
});
return SimpleExpr::Constant(ConstExpression::MathExpr(
MathExpr::Binary(*operation),
vec![left_cst.clone(), right_cst.clone()],
));
}

let aux_var = state.counters.aux_var();
Expand Down Expand Up @@ -1275,7 +1253,7 @@ pub fn find_variable_usage(
on_new_expr(&comp.left, internal_vars, external_vars);
on_new_expr(&comp.right, internal_vars, external_vars);
}
Condition::Expression(expr, _assume_boolean) => {
Condition::AssumeBoolean(expr) => {
on_new_expr(expr, internal_vars, external_vars);
}
};
Expand Down Expand Up @@ -1444,7 +1422,7 @@ fn inline_lines(

let inline_condition = |condition: &mut Condition| match condition {
Condition::Comparison(comparison) => inline_comparison(comparison),
Condition::Expression(expr, _assume_boolean) => inline_expr(expr, args, inlining_count),
Condition::AssumeBoolean(expr) => inline_expr(expr, args, inlining_count),
};

let inline_internal_var = |var: &mut Var| {
Expand Down Expand Up @@ -1883,7 +1861,7 @@ fn replace_vars_for_unroll(
internal_vars,
);
}
Condition::Expression(expr, _assume_bool) => {
Condition::AssumeBoolean(expr) => {
replace_vars_for_unroll_in_expr(expr, iterator, unroll_index, iterator_value, internal_vars);
}
}
Expand Down Expand Up @@ -2142,9 +2120,9 @@ fn extract_inlined_calls_from_condition(
inlined_var_counter: &mut Counter,
) -> (Condition, Vec<Line>) {
match condition {
Condition::Expression(expr, assume_boolean) => {
Condition::AssumeBoolean(expr) => {
let (expr, expr_lines) = extract_inlined_calls_from_expr(expr, inlined_functions, inlined_var_counter);
(Condition::Expression(expr, *assume_boolean), expr_lines)
(Condition::AssumeBoolean(expr), expr_lines)
}
Condition::Comparison(boolean) => {
let (boolean, boolean_lines) =
Expand Down Expand Up @@ -2719,7 +2697,7 @@ fn replace_vars_by_const_in_lines(lines: &mut [Line], map: &BTreeMap<Var, F>) {
replace_vars_by_const_in_expr(&mut cond.left, map);
replace_vars_by_const_in_expr(&mut cond.right, map);
}
Condition::Expression(expr, _assume_boolean) => {
Condition::AssumeBoolean(expr) => {
replace_vars_by_const_in_expr(expr, map);
}
}
Expand Down
11 changes: 5 additions & 6 deletions crates/lean_compiler/src/b_compile_intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ impl Compiler {
VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => {
for scope in self.stack_frame_layout.scopes.iter().rev() {
if let Some(base) = scope.const_mallocs.get(malloc_label) {
return ConstExpression::Binary {
left: Box::new((*base).into()),
operation: HighLevelOperation::Add,
right: Box::new((*offset).clone()),
};
return ConstExpression::MathExpr(
MathExpr::Binary(HighLevelOperation::Add),
vec![(*base).into(), offset.clone()],
);
}
}
panic!("Const malloc {malloc_label} not in scope");
Expand Down Expand Up @@ -518,7 +517,7 @@ fn compile_lines(
});
}
SimpleLine::ConstMalloc { var, size, label } => {
let size = size.naive_eval().unwrap().to_usize(); // TODO not very good;
let size = size.naive_eval().unwrap().to_usize();
if !compiler.is_in_scope(var) {
let current_scope_layout = compiler.stack_frame_layout.scopes.last_mut().unwrap();
current_scope_layout
Expand Down
25 changes: 10 additions & 15 deletions crates/lean_compiler/src/grammar.pest
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ statement = {
return_statement |
break_statement |
continue_statement |
assert_eq_statement |
assert_not_eq_statement |
debug_assert_eq_statement |
debug_assert_not_eq_statement |
debug_assert_lt_statement |
assert_statement |
debug_assert_statement |
assignment
}

Expand All @@ -51,10 +48,14 @@ assignment_target = { array_access_expr | identifier }

if_statement = { "if" ~ condition ~ "{" ~ statement* ~ "}" ~ else_if_clause* ~ else_clause? }

condition = { expression | assumed_bool_expr }
condition = { assumed_bool_expr | comparison }

assumed_bool_expr = { "!!assume_bool" ~ "(" ~ expression ~ ")" }

// Comparisons (shared between conditions and assertions)
comparison = { add_expr ~ comparison_op ~ add_expr }
comparison_op = { "==" | "!=" | "<" }

else_if_clause = { "else" ~ "if" ~ condition ~ "{" ~ statement* ~ "}" }

else_clause = { "else" ~ "{" ~ statement* ~ "}" }
Expand All @@ -67,18 +68,12 @@ match_statement = { "match" ~ expression ~ "{" ~ match_arm* ~ "}" }
match_arm = { pattern ~ "=>" ~ "{" ~ statement* ~ "}" }
pattern = { constant_value }

assert_eq_statement = { "assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
assert_not_eq_statement = { "assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }

debug_assert_eq_statement = { "debug_assert" ~ add_expr ~ "==" ~ add_expr ~ ";" }
debug_assert_not_eq_statement = { "debug_assert" ~ add_expr ~ "!=" ~ add_expr ~ ";" }
debug_assert_lt_statement = { "debug_assert" ~ add_expr ~ "<" ~ add_expr ~ ";" }
assert_statement = { "assert" ~ comparison ~ ";" }
debug_assert_statement = { "debug_assert" ~ comparison ~ ";" }

// Expressions
tuple_expression = { expression ~ ("," ~ expression)* }
expression = { neq_expr }
neq_expr = { eq_expr ~ ("!=" ~ eq_expr)* }
eq_expr = { add_expr ~ ("==" ~ add_expr)* }
expression = { add_expr }
add_expr = { sub_expr ~ ("+" ~ sub_expr)* }
sub_expr = { mul_expr ~ ("-" ~ mul_expr)* }
mul_expr = { mod_expr ~ ("*" ~ mod_expr)* }
Expand Down
5 changes: 1 addition & 4 deletions crates/lean_compiler/src/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ impl IntermediateInstruction {
arg_c,
res: arg_a,
},
HighLevelOperation::Exp
| HighLevelOperation::Mod
| HighLevelOperation::Equal
| HighLevelOperation::NotEqual => unreachable!(),
HighLevelOperation::Exp | HighLevelOperation::Mod => unreachable!(),
}
}

Expand Down
24 changes: 0 additions & 24 deletions crates/lean_compiler/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ use multilinear_toolkit::prelude::*;
use std::fmt::{Display, Formatter};
use utils::ToUsize;

/// High-level operations that can be performed in the IR.
///
/// These operations represent the semantic intent of computations
/// and may be lowered to different VM operations depending on context.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum HighLevelOperation {
/// Addition operation.
Expand All @@ -22,29 +18,11 @@ pub enum HighLevelOperation {
Exp,
/// Modulo operation (only for constant expressions).
Mod,
/// Equality comparison
Equal,
/// Non-equality comparison
NotEqual,
}

impl HighLevelOperation {
pub fn eval(&self, a: F, b: F) -> F {
match self {
Self::Equal => {
if a == b {
F::ONE
} else {
F::ZERO
}
}
Self::NotEqual => {
if a != b {
F::ONE
} else {
F::ZERO
}
}
Self::Add => a + b,
Self::Mul => a * b,
Self::Sub => a - b,
Expand All @@ -58,8 +36,6 @@ impl HighLevelOperation {
impl Display for HighLevelOperation {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Equal => write!(f, "=="),
Self::NotEqual => write!(f, "!="),
Self::Add => write!(f, "+"),
Self::Mul => write!(f, "*"),
Self::Sub => write!(f, "-"),
Expand Down
37 changes: 10 additions & 27 deletions crates/lean_compiler/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,6 @@ pub enum ConstantValue {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ConstExpression {
Value(ConstantValue),
Binary {
left: Box<Self>,
operation: HighLevelOperation,
right: Box<Self>,
},
MathExpr(MathExpr, Vec<Self>),
}

Expand All @@ -140,11 +135,7 @@ impl TryFrom<Expression> for ConstExpression {
Expression::Binary { left, operation, right } => {
let left_expr = Self::try_from(*left)?;
let right_expr = Self::try_from(*right)?;
Ok(Self::Binary {
left: Box::new(left_expr),
operation,
right: Box::new(right_expr),
})
Ok(Self::MathExpr(MathExpr::Binary(operation), vec![left_expr, right_expr]))
}
Expression::MathExpr(math_expr, args) => {
let mut const_args = Vec::new();
Expand Down Expand Up @@ -185,9 +176,6 @@ impl ConstExpression {
{
match self {
Self::Value(value) => func(value),
Self::Binary { left, operation, right } => {
Some(operation.eval(left.eval_with(func)?, right.eval_with(func)?))
}
Self::MathExpr(math_expr, args) => {
let mut eval_args = Vec::new();
for arg in args {
Expand Down Expand Up @@ -220,25 +208,16 @@ impl From<ConstantValue> for ConstExpression {
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AssumeBoolean {
AssumeBoolean,
DoNotAssumeBoolean,
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Condition {
Expression(Expression, AssumeBoolean),
AssumeBoolean(Expression),
Comparison(BooleanExpr<Expression>),
}

impl Display for Condition {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Expression(expr, AssumeBoolean::AssumeBoolean) => {
write!(f, "!assume_bool({expr})")
}
Self::Expression(expr, AssumeBoolean::DoNotAssumeBoolean) => write!(f, "{expr}"),
Self::AssumeBoolean(expr) => write!(f, "{expr}"),
Self::Comparison(cmp) => write!(f, "{cmp}"),
}
}
Expand Down Expand Up @@ -270,6 +249,7 @@ pub enum Expression {
/// For arbitrary compile-time computations
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MathExpr {
Binary(HighLevelOperation),
Log2Ceil,
NextMultipleOf,
SaturatingSub,
Expand All @@ -278,6 +258,7 @@ pub enum MathExpr {
impl Display for MathExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Binary(op) => write!(f, "{op}"),
Self::Log2Ceil => write!(f, "log2_ceil"),
Self::NextMultipleOf => write!(f, "next_multiple_of"),
Self::SaturatingSub => write!(f, "saturating_sub"),
Expand All @@ -288,13 +269,18 @@ impl Display for MathExpr {
impl MathExpr {
pub fn num_args(&self) -> usize {
match self {
Self::Binary(_) => 2,
Self::Log2Ceil => 1,
Self::NextMultipleOf => 2,
Self::SaturatingSub => 2,
}
}
pub fn eval(&self, args: &[F]) -> F {
match self {
Self::Binary(op) => {
assert_eq!(args.len(), 2);
op.eval(args[0], args[1])
}
Self::Log2Ceil => {
assert_eq!(args.len(), 1);
let value = args[0];
Expand Down Expand Up @@ -722,9 +708,6 @@ impl Display for ConstExpression {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Value(value) => write!(f, "{value}"),
Self::Binary { left, operation, right } => {
write!(f, "({left} {operation} {right})")
}
Self::MathExpr(math_expr, args) => {
let args_str = args.iter().map(|arg| format!("{arg}")).collect::<Vec<_>>().join(", ");
write!(f, "{math_expr}({args_str})")
Expand Down
Loading