Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional values and big typecheck refactor #118

Merged
merged 3 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub enum Declaration {
}

#[derive(Clone, Debug)]
pub struct Params(pub Vec<(Meta<Identifier>, Meta<Path>)>);
pub struct Params(pub Vec<(Meta<Identifier>, Meta<TypeExpr>)>);

/// The value of a typed record
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -53,7 +53,7 @@ pub struct FilterMap {
pub struct FunctionDeclaration {
pub ident: Meta<Identifier>,
pub params: Meta<Params>,
pub ret: Option<Meta<Path>>,
pub ret: Option<Meta<TypeExpr>>,
pub body: Meta<Block>,
}

Expand Down Expand Up @@ -83,6 +83,15 @@ pub struct Path {
pub idents: Vec<Meta<Identifier>>,
}

#[derive(Clone, Debug)]
pub enum TypeExpr {
Optional(Box<TypeExpr>),
Path(Meta<Path>, Vec<Meta<TypeExpr>>),
Never,
Unit,
Record(RecordType),
}

/// A Roto expression
#[derive(Clone, Debug)]
pub enum Expr {
Expand Down Expand Up @@ -173,7 +182,7 @@ pub enum Pattern {
Underscore,
EnumVariant {
variant: Meta<Identifier>,
data_field: Option<Meta<Identifier>>,
fields: Option<Meta<Vec<Meta<Identifier>>>>,
},
}

Expand Down Expand Up @@ -216,14 +225,7 @@ impl From<String> for Identifier {

#[derive(Clone, Debug)]
pub struct RecordType {
pub key_values: Meta<Vec<(Meta<Identifier>, RecordFieldType)>>,
}

#[derive(Clone, Debug)]
pub enum RecordFieldType {
Path(Meta<Path>),
Record(Meta<RecordType>),
List(Meta<Box<RecordFieldType>>),
pub fields: Meta<Vec<(Meta<Identifier>, Meta<TypeExpr>)>>,
}

#[derive(Clone, Debug)]
Expand All @@ -236,7 +238,7 @@ pub enum Literal {
Bool(bool),
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BinOp {
/// Logical and (`&&`)
And,
Expand Down
119 changes: 83 additions & 36 deletions src/codegen/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use crate::{
},
typechecker::{
info::TypeInfo,
types::{Primitive, Type},
scope::{ResolvedName, ScopeRef},
types::{Type, TypeDefinition},
},
};
use std::{
Expand Down Expand Up @@ -105,36 +106,43 @@ fn check_roto_type(
let mut roto_ty = type_info.resolve(roto_ty);

if let Type::IntVar(_) = roto_ty {
roto_ty = Type::Primitive(Primitive::I32);
roto_ty = Type::named("i32", Vec::new());
}

match rust_ty.description {
TypeDescription::Leaf => {
let expected_roto = match rust_ty.type_id {
x if x == BOOL => Type::Primitive(Primitive::Bool),
x if x == U8 => Type::Primitive(Primitive::U8),
x if x == U16 => Type::Primitive(Primitive::U16),
x if x == U32 => Type::Primitive(Primitive::U32),
x if x == U64 => Type::Primitive(Primitive::U64),
x if x == I8 => Type::Primitive(Primitive::I8),
x if x == I16 => Type::Primitive(Primitive::I16),
x if x == I32 => Type::Primitive(Primitive::I32),
x if x == I64 => Type::Primitive(Primitive::I64),
x if x == UNIT => Type::Primitive(Primitive::Unit),
x if x == ASN => Type::Primitive(Primitive::Asn),
x if x == IPADDR => Type::Primitive(Primitive::IpAddr),
x if x == PREFIX => Type::Primitive(Primitive::Prefix),
x if x == STRING => Type::Primitive(Primitive::String),
let expected_name = match rust_ty.type_id {
x if x == BOOL => "bool",
x if x == U8 => "u8",
x if x == U16 => "u16",
x if x == U32 => "u32",
x if x == U64 => "u64",
x if x == I8 => "i8",
x if x == I16 => "i16",
x if x == I32 => "i32",
x if x == I64 => "i64",
x if x == UNIT => "Unit",
x if x == ASN => "Asn",
x if x == IPADDR => "IpAddr",
x if x == PREFIX => "Prefix",
x if x == STRING => "String",
_ => panic!(),
};
let expected_roto = Type::named(expected_name, Vec::new());
if expected_roto == roto_ty {
Ok(())
} else {
Err(error_message)
}
}
TypeDescription::Val(ty) => {
let Type::BuiltIn(_, id) = roto_ty else {
let Type::Name(type_name) = roto_ty else {
return Err(error_message);
};

let TypeDefinition::Runtime(_, id) =
type_info.resolve_type_name(&type_name)
else {
return Err(error_message);
};

Expand All @@ -147,18 +155,48 @@ fn check_roto_type(
TypeDescription::ConstPtr(_) => Err(error_message),
TypeDescription::MutPtr(_) => Err(error_message), // TODO: actually check this
TypeDescription::Verdict(rust_accept, rust_reject) => {
let Type::Verdict(roto_accept, roto_reject) = &roto_ty else {
let Type::Name(type_name) = &roto_ty else {
return Err(error_message);
};

if type_name.name
!= (ResolvedName {
scope: ScopeRef::GLOBAL,
ident: "Verdict".into(),
})
{
return Err(error_message);
}

let [roto_accept, roto_reject] = &type_name.arguments[..] else {
return Err(error_message);
};

check_roto_type(registry, type_info, rust_accept, roto_accept)?;
check_roto_type(registry, type_info, rust_reject, roto_reject)?;
Ok(())
}
// We don't do options and results, we should hint towards verdict
// when using them.
TypeDescription::Option(_) | TypeDescription::Result(_, _) => {
Err(error_message)
TypeDescription::Option(rust_ty) => {
let Type::Name(type_name) = &roto_ty else {
return Err(error_message);
};

if type_name.name
!= (ResolvedName {
scope: ScopeRef::GLOBAL,
ident: "Optional".into(),
})
{
return Err(error_message);
}

let [roto_ty] = &type_name.arguments[..] else {
return Err(error_message);
};
check_roto_type(registry, type_info, rust_ty, roto_ty)
}
// We don't do results, we should hint towards verdict when using them.
TypeDescription::Result(_, _) => Err(error_message),
}
}

Expand All @@ -172,18 +210,20 @@ fn check_roto_type(
///
/// This trait is implemented on tuples of various sizes.
pub trait RotoParams {
type Transformed;
/// This type but with [`Reflect::AsParam`] applied to each element.
type AsParams;

fn transform(self) -> Self::Transformed;

/// Convert to `Self::AsParams`.
fn as_params(&mut self) -> Self::AsParams;
fn as_params(transformed: &mut Self::Transformed) -> Self::AsParams;

/// Check whether these parameters match a parameter list from Roto.
fn check(
type_info: &mut TypeInfo,
ty: &[Type],
) -> Result<(), FunctionRetrievalError>;

/// Call a function pointer as if it were a function with these parameters.
///
/// This is _extremely_ unsafe, do not pass this arbitrary pointers and
Expand Down Expand Up @@ -214,14 +254,18 @@ macro_rules! params {
#[allow(unused_variables)]
#[allow(unused_mut)]
impl<$($t,)*> RotoParams for ($($t,)*)
where
$($t: Reflect,)*
{
where $($t: Reflect,)* {
type Transformed = ($($t::Transformed,)*);
type AsParams = ($($t::AsParam,)*);

fn as_params(&mut self) -> Self::AsParams {
fn transform(self) -> Self::Transformed {
let ($($t,)*) = self;
return ($($t.as_param(),)*);
return ($($t.transform(),)*);
}

fn as_params(transformed: &mut Self::Transformed) -> Self::AsParams {
let ($($t,)*) = transformed;
return ($($t::as_param($t),)*);
}

fn check(
Expand All @@ -246,21 +290,24 @@ macro_rules! params {
Ok(())
}

unsafe fn invoke<Ctx: 'static, Return: Reflect>(mut self, ctx: &mut Ctx, func_ptr: *const u8, return_by_ref: bool) -> Return {
let ($($t,)*) = self.as_params();
unsafe fn invoke<Ctx: 'static, Return: Reflect>(self, ctx: &mut Ctx, func_ptr: *const u8, return_by_ref: bool) -> Return {
let mut transformed = <Self as RotoParams>::transform(self);
let ($($t,)*) = <Self as RotoParams>::as_params(&mut transformed);

// We forget values that we pass into Roto. The script is responsible
// for cleaning them op. Forgetting copy types does nothing, but that's
// fine.
#[allow(forgetting_copy_types)]
std::mem::forget(self);
std::mem::forget(transformed);
if return_by_ref {
let func_ptr = unsafe {
std::mem::transmute::<*const u8, fn(*mut Return, *mut Ctx, $($t::AsParam),*) -> ()>(func_ptr)
std::mem::transmute::<*const u8, fn(*mut Return::Transformed, *mut Ctx, $($t::AsParam),*) -> ()>(func_ptr)
};
let mut ret = MaybeUninit::<Return>::uninit();
let mut ret = MaybeUninit::<Return::Transformed>::uninit();
func_ptr(ret.as_mut_ptr(), ctx as *mut Ctx, $($t),*);
unsafe { ret.assume_init() }
let transformed_ret = unsafe { ret.assume_init() };
let ret: Return = Return::untransform(transformed_ret);
ret
} else {
let func_ptr = unsafe {
std::mem::transmute::<*const u8, fn(*mut Ctx, $($t::AsParam),*) -> Return>(func_ptr)
Expand Down
33 changes: 10 additions & 23 deletions src/codegen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ impl<'c> FuncGen<'c> {
}
ir::Instruction::Div {
to,
ty,
signed,
left,
right,
} => {
Expand All @@ -777,14 +777,9 @@ impl<'c> FuncGen<'c> {

let var = self.variable(to, left_ty);

let val = match ty {
IrType::I8 | IrType::I16 | IrType::I32 | IrType::I64 => {
self.ins().sdiv(l, r)
}
IrType::U8 | IrType::U16 | IrType::U32 | IrType::U64 => {
self.ins().udiv(l, r)
}
_ => panic!(),
let val = match signed {
true => self.ins().sdiv(l, r),
false => self.ins().udiv(l, r),
};
self.def(var, val)
}
Expand All @@ -797,34 +792,26 @@ impl<'c> FuncGen<'c> {
self.def(var, val)
}
ir::Instruction::Eq { .. } => todo!(),
ir::Instruction::Alloc {
to,
size,
align_shift,
} => {
ir::Instruction::Alloc { to, layout } => {
let slot =
self.builder.create_sized_stack_slot(StackSlotData::new(
StackSlotKind::ExplicitSlot,
*size,
*align_shift,
layout.size() as u32,
layout.align_shift() as u8,
));

let pointer_ty = self.module.isa.pointer_type();
let var = self.variable(to, pointer_ty);
let p = self.ins().stack_addr(pointer_ty, slot, 0);
self.def(var, p);
}
ir::Instruction::Initialize {
to,
bytes,
align_shift,
} => {
ir::Instruction::Initialize { to, bytes, layout } => {
let pointer_ty = self.module.isa.pointer_type();
let slot =
self.builder.create_sized_stack_slot(StackSlotData::new(
StackSlotKind::ExplicitSlot,
bytes.len() as u32,
*align_shift,
layout.size() as u32,
layout.align_shift() as u8,
));

let data_id = self
Expand Down
Loading