Skip to content

Commit 34717c3

Browse files
Merge pull request #101 from NLnetLabs/testing-in-roto-scripts
Add some basic testing infrastructure for Roto scripts
2 parents 52257cf + f810919 commit 34717c3

17 files changed

+358
-217
lines changed

examples/simple.roto

+21-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
1+
function is_zero(x: IpAddr) -> bool {
2+
x == 0.0.0.0
3+
}
4+
15
filter-map main(x: IpAddr) {
2-
define {
3-
y = 0.0.0.0;
6+
if is_zero(x) {
7+
accept
8+
} else {
9+
reject
10+
}
11+
}
12+
13+
test is_zero_true {
14+
if is_zero(1.1.1.1) {
15+
reject
416
}
5-
apply {
6-
if x == y {
7-
accept
8-
} else {
9-
reject
10-
}
17+
accept
18+
}
19+
20+
test is_zero_false {
21+
if not is_zero(0.0.0.0) {
22+
reject
1123
}
24+
accept
1225
}

examples/simple.rs

+10
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,15 @@ fn main() -> Result<(), roto::RotoReport> {
4040
let res = func.call(&mut (), "1.1.1.1".parse().unwrap());
4141
println!("main(1.1.1.1) = {res:?}");
4242

43+
let is_zero = compiled
44+
.get_function::<(), (IpAddr,), bool>("is_zero")
45+
.unwrap();
46+
47+
let res = is_zero.call(&mut (), "0.0.0.0".parse().unwrap());
48+
println!("is_zero(0.0.0.0) = {res:?}");
49+
50+
println!();
51+
let _ = compiled.run_tests(());
52+
4353
Ok(())
4454
}

src/ast.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub enum Declaration {
2323
OutputStream(OutputStream),
2424
Record(RecordTypeDeclaration),
2525
Function(FunctionDeclaration),
26+
Test(Test),
2627
}
2728

2829
#[derive(Clone, Debug)]
@@ -46,7 +47,7 @@ pub struct FilterMap {
4647
pub filter_type: FilterType,
4748
pub ident: Meta<Identifier>,
4849
pub params: Meta<Params>,
49-
pub block: Meta<Block>,
50+
pub body: Meta<Block>,
5051
}
5152

5253
/// A function declaration, including the [`Block`] forming its definition
@@ -58,6 +59,12 @@ pub struct FunctionDeclaration {
5859
pub body: Meta<Block>,
5960
}
6061

62+
#[derive(Clone, Debug)]
63+
pub struct Test {
64+
pub ident: Meta<Identifier>,
65+
pub body: Meta<Block>,
66+
}
67+
6168
/// A block of multiple statements
6269
#[derive(Clone, Debug)]
6370
pub struct Block {

src/codegen/check.rs

-12
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,6 @@ fn check_roto_type(
162162
}
163163
}
164164

165-
pub fn return_type_by_ref(registry: &TypeRegistry, rust_ty: TypeId) -> bool {
166-
let Some(rust_ty) = registry.get(rust_ty) else {
167-
return false;
168-
};
169-
170-
#[allow(clippy::match_like_matches_macro)]
171-
match rust_ty.description {
172-
TypeDescription::Verdict(_, _) => true,
173-
_ => todo!(),
174-
}
175-
}
176-
177165
/// Parameters of a Roto function
178166
///
179167
/// This trait allows for checking the types against Roto types and converting

src/codegen/mod.rs

+53-13
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,12 @@ use crate::{
1616
value::IrType,
1717
IrFunction,
1818
},
19-
runtime::{
20-
context::ContextDescription,
21-
ty::{Reflect, GLOBAL_TYPE_REGISTRY},
22-
RuntimeConstant,
23-
},
19+
runtime::{context::ContextDescription, ty::Reflect, RuntimeConstant},
2420
typechecker::{info::TypeInfo, scope::ScopeRef, types},
25-
IrValue,
21+
IrValue, Verdict,
2622
};
2723
use check::{
28-
check_roto_type_reflect, return_type_by_ref, FunctionRetrievalError,
29-
RotoParams, TypeMismatch,
24+
check_roto_type_reflect, FunctionRetrievalError, RotoParams, TypeMismatch,
3025
};
3126
use cranelift::{
3227
codegen::{
@@ -177,6 +172,7 @@ call_impl!(A, B, C, D, E, F, G);
177172
pub struct FunctionInfo {
178173
id: FuncId,
179174
signature: types::Signature,
175+
return_by_ref: bool,
180176
}
181177

182178
struct ModuleBuilder {
@@ -377,6 +373,7 @@ impl ModuleBuilder {
377373

378374
let mut sig = self.inner.make_signature();
379375

376+
// This is the parameter for the context
380377
sig.params
381378
.push(AbiParam::new(self.cranelift_type(&IrType::Pointer)));
382379

@@ -406,6 +403,7 @@ impl ModuleBuilder {
406403
name.to_string(),
407404
FunctionInfo {
408405
id: func_id,
406+
return_by_ref: ir_signature.return_ptr,
409407
signature: signature.clone(),
410408
},
411409
);
@@ -1072,6 +1070,52 @@ impl<'c> FuncGen<'c> {
10721070
}
10731071

10741072
impl Module {
1073+
pub fn run_tests<Ctx: 'static>(
1074+
&mut self,
1075+
mut ctx: Ctx,
1076+
) -> Result<(), ()> {
1077+
let tests: Vec<_> = self
1078+
.functions
1079+
.keys()
1080+
.filter(|x| x.starts_with("test#"))
1081+
.map(Clone::clone)
1082+
.collect();
1083+
1084+
let total = tests.len();
1085+
let total_width = total.to_string().len();
1086+
let mut successes = 0;
1087+
let mut failures = 0;
1088+
1089+
for (n, test) in tests.into_iter().enumerate() {
1090+
let n = n + 1;
1091+
let test_display = test.strip_prefix("test#").unwrap();
1092+
print!("Test {n:>total_width$} / {total}: {test_display}... ");
1093+
let test_fn = self
1094+
.get_function::<Ctx, (), Verdict<(), ()>>(&test)
1095+
.unwrap();
1096+
1097+
match test_fn.call(&mut ctx) {
1098+
Verdict::Accept(()) => {
1099+
successes += 1;
1100+
println!("\x1B[92mok\x1B[m");
1101+
}
1102+
Verdict::Reject(()) => {
1103+
failures += 1;
1104+
println!("\x1B[91mfail\x1B[m");
1105+
}
1106+
}
1107+
}
1108+
println!(
1109+
"Ran {total} tests, {successes} succeeded, {failures} failed"
1110+
);
1111+
1112+
if failures == 0 {
1113+
Result::Ok(())
1114+
} else {
1115+
Result::Err(())
1116+
}
1117+
}
1118+
10751119
pub fn get_function<Ctx: 'static, Params: RotoParams, Return: Reflect>(
10761120
&mut self,
10771121
name: &str,
@@ -1109,14 +1153,10 @@ impl Module {
11091153
)
11101154
})?;
11111155

1112-
let registry = GLOBAL_TYPE_REGISTRY.lock().unwrap();
1113-
let return_by_ref =
1114-
return_type_by_ref(&registry, TypeId::of::<Return>());
1115-
11161156
let func_ptr = self.inner.0.get_finalized_function(id);
11171157
Ok(TypedFunc {
11181158
func: func_ptr,
1119-
return_by_ref,
1159+
return_by_ref: function_info.return_by_ref,
11201160
_module: self.inner.clone(),
11211161
_ty: PhantomData,
11221162
})

src/codegen/tests.rs

+63
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,69 @@ fn use_context() {
905905
assert_eq!(output, Verdict::Accept(11));
906906
}
907907

908+
#[test]
909+
fn use_a_roto_function() {
910+
let s = src!(
911+
"
912+
function double(x: i32) -> i32 {
913+
2 * x
914+
}"
915+
);
916+
917+
let mut p = compile(s);
918+
let f = p.get_function::<(), (i32,), i32>("double").unwrap();
919+
let output = f.call(&mut (), 2);
920+
assert_eq!(output, 4);
921+
922+
let output = f.call(&mut (), 16);
923+
assert_eq!(output, 32);
924+
}
925+
926+
#[test]
927+
fn use_a_test() {
928+
let s = src!(
929+
"
930+
function double(x: i32) -> i32 {
931+
x # oops! not correct
932+
}
933+
934+
test check_double {
935+
if double(4) != 8 {
936+
reject;
937+
}
938+
if double(16) != 32 {
939+
reject;
940+
}
941+
accept
942+
}
943+
"
944+
);
945+
946+
let mut p = compile(s);
947+
p.run_tests(()).unwrap_err();
948+
949+
let s = src!(
950+
"
951+
function double(x: i32) -> i32 {
952+
2 * x
953+
}
954+
955+
test check_double {
956+
if double(4) != 8 {
957+
reject;
958+
}
959+
if double(16) != 32 {
960+
reject;
961+
}
962+
accept
963+
}
964+
"
965+
);
966+
967+
let mut p = compile(s);
968+
p.run_tests(()).unwrap();
969+
}
970+
908971
#[test]
909972
fn string() {
910973
let s = src!(

src/lower/ir.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,8 @@ impl<'a> IrPrinter<'a> {
367367
} => format!(
368368
"{}: {ty} = {}({}, {})",
369369
self.var(to),
370-
self.operand(ctx),
371370
self.ident(func),
371+
self.operand(ctx),
372372
args.iter()
373373
.map(|a| format!(
374374
"{} = {}",

src/lower/match_expr.rs

+18-15
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,13 @@ impl Lowerer<'_> {
213213
let ty = self.type_info.type_of(&arm.body);
214214
let val = self.block(&arm.body);
215215
if let Some(val) = val {
216-
let ty = self.lower_type(&ty);
217-
self.add(Instruction::Assign {
218-
to: out.clone(),
219-
val,
220-
ty,
221-
});
216+
if let Some(ty) = self.lower_type(&ty) {
217+
self.add(Instruction::Assign {
218+
to: out.clone(),
219+
val,
220+
ty,
221+
});
222+
}
222223
any_assigned = true;
223224
}
224225
self.add(Instruction::Jump(continue_lbl));
@@ -267,15 +268,17 @@ impl Lowerer<'_> {
267268
1 + self.type_info.padding_of(&ty, 1, self.runtime);
268269
let val =
269270
self.read_field(examinee.clone().into(), offset, &ty);
270-
let ty = self.lower_type(&ty);
271-
self.add(Instruction::Assign {
272-
to: Var {
273-
scope,
274-
kind: VarKind::Explicit(ident),
275-
},
276-
val,
277-
ty,
278-
});
271+
if let Some(val) = val {
272+
let ty = self.lower_type(&ty).unwrap();
273+
self.add(Instruction::Assign {
274+
to: Var {
275+
scope,
276+
kind: VarKind::Explicit(ident),
277+
},
278+
val,
279+
ty,
280+
});
281+
}
279282
}
280283

281284
let ident = Identifier::from(format!("guard_{}", i + 1));

0 commit comments

Comments
 (0)