diff --git a/src/interpreter.rs b/src/interpreter.rs index 15ebcef..209abd0 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -13,6 +13,7 @@ mod tests; pub struct Interpreter { pub environment: Rc>, + globals: Rc>, } fn clock_impl( @@ -40,6 +41,7 @@ impl Interpreter { ); Self { + globals: Rc::new(RefCell::new(Environment::new())), environment: Rc::new(RefCell::new(environment)), } } @@ -48,7 +50,10 @@ impl Interpreter { let environment = Rc::new(RefCell::new(Environment::new())); environment.borrow_mut().enclosing = Some(parent.clone()); - Self { environment } + Self { + globals: Rc::new(RefCell::new(Environment::new())), + environment, + } } pub fn interpret_statements(&mut self, statements: Vec) -> Result<(), String> { @@ -123,12 +128,8 @@ impl Interpreter { for i in 0..(body.len()) { closure_interpreter.interpret_statements(vec![body[i].clone()])?; - if let Statement::Return { - keyword: _, - value: _, - } = &body[i] - { - return closure_interpreter.environment.borrow_mut().get("return"); + if let Ok(value) = closure_interpreter.globals.borrow().get("return") { + return Ok(value); } } @@ -150,7 +151,7 @@ impl Interpreter { _ => LiteralValue::Nil, }; - self.environment + self.globals .borrow_mut() .define(String::from("return"), response); } diff --git a/src/tests/cases/fun_conditional_return.lox b/src/tests/cases/fun_conditional_return.lox new file mode 100644 index 0000000..a7d6c12 --- /dev/null +++ b/src/tests/cases/fun_conditional_return.lox @@ -0,0 +1,13 @@ +fun condreturn(a) { + if (a <= 0) + { + return 0; + } + + return a - 1; +} + +print condreturn(4); +print condreturn(3); +print condreturn(2); +print condreturn(-1); diff --git a/src/tests/cases/fun_very_nested.lox b/src/tests/cases/fun_very_nested.lox new file mode 100644 index 0000000..3ef5fa7 --- /dev/null +++ b/src/tests/cases/fun_very_nested.lox @@ -0,0 +1,15 @@ +fun nested(a) { + if (a < 3) { + if (a > 1) { + return a; + } + } + { + a = a + 2; + return a; + } + return -1; +} + +print nested(2); +print nested(1); diff --git a/src/tests/interpreter_tests.rs b/src/tests/interpreter_tests.rs index ed3f76d..92070bd 100644 --- a/src/tests/interpreter_tests.rs +++ b/src/tests/interpreter_tests.rs @@ -412,4 +412,90 @@ mod tests { Ok(LiteralValue::Nil) ); } + + #[test] + fn test_function_with_conditional_return() { + let source = " + fun condreturn(a) { + if (a <= 0) + { + return 0; + } + + return a - 1; + } + + var a = condreturn(4); + var b = condreturn(-1); + "; + let mut scanner = Scanner::new(source); + let tokens = scanner.scan_tokens().unwrap(); + + let mut parser = Parser::new(tokens); + let statements = parser.parse().unwrap(); + + let mut interpreter: Interpreter = Interpreter::new(); + let variable_count = interpreter.environment.borrow().values.len(); + let result = interpreter.interpret_statements(statements); + + assert!(result.is_ok()); + + assert_eq!( + interpreter.environment.borrow().values.len(), + variable_count + 3 + ); + assert_eq!( + interpreter.environment.borrow().get("a"), + Ok(LiteralValue::IntValue(3)) + ); + assert_eq!( + interpreter.environment.borrow().get("b"), + Ok(LiteralValue::IntValue(0)) + ); + } + + #[test] + fn test_function_with_nested_blocks() { + let source = " + fun nested(a) { + if (a < 3) { + if (a > 1) { + return a; + } + } + { + a = a + 2; + return a; + } + return -1; + } + + var a = nested(2); + var b = nested(1); + "; + let mut scanner = Scanner::new(source); + let tokens = scanner.scan_tokens().unwrap(); + + let mut parser = Parser::new(tokens); + let statements = parser.parse().unwrap(); + + let mut interpreter: Interpreter = Interpreter::new(); + let variable_count = interpreter.environment.borrow().values.len(); + let result = interpreter.interpret_statements(statements); + + assert!(result.is_ok()); + + assert_eq!( + interpreter.environment.borrow().values.len(), + variable_count + 3 + ); + assert_eq!( + interpreter.environment.borrow().get("a"), + Ok(LiteralValue::IntValue(2)) + ); + assert_eq!( + interpreter.environment.borrow().get("b"), + Ok(LiteralValue::IntValue(3)) + ); + } } diff --git a/src/tests/main_tests.rs b/src/tests/main_tests.rs index 2a751a0..345b8f1 100644 --- a/src/tests/main_tests.rs +++ b/src/tests/main_tests.rs @@ -105,6 +105,28 @@ mod tests { assert_eq!(lines[1], "nil"); } + #[test] + fn function_conditional_return() { + let lines = test_file("./src/tests/cases/fun_conditional_return.lox"); + + assert_eq!(lines.len(), 4); + + assert_eq!(lines[0], "3"); + assert_eq!(lines[1], "2"); + assert_eq!(lines[2], "1"); + assert_eq!(lines[3], "0"); + } + + #[test] + fn function_nested() { + let lines = test_file("./src/tests/cases/fun_very_nested.lox"); + + assert_eq!(lines.len(), 2); + + assert_eq!(lines[0], "2"); + assert_eq!(lines[1], "3"); + } + fn test_file(file_path: &str) -> Vec { let output = Command::new("cargo") .args(["run", file_path])