diff --git a/src/frontend/ast.rs b/src/frontend/ast.rs index 040e966..e4facfe 100644 --- a/src/frontend/ast.rs +++ b/src/frontend/ast.rs @@ -152,6 +152,11 @@ pub enum Statement { generics: Vec,// TODO: add generics fields: Vec<(String, Type)>, }, + AliasDecl { + name: String, + generics: Vec, + ty: Type, + }, } impl Statement { @@ -196,6 +201,10 @@ impl Statement { pub fn new_record_decl(name: String, generics: Vec, fields: Vec<(String, Type)>, line: usize) -> Located { Located::new(Self::RecordDecl { name, generics, fields }, line) } + + pub fn new_alias_decl(name: String, generics: Vec, ty: Type, line: usize) -> Located { + Located::new(Self::AliasDecl { name, generics, ty }, line) + } } pub type Block = Vec>; diff --git a/src/frontend/parser.rs b/src/frontend/parser.rs index 410b20a..31c8269 100644 --- a/src/frontend/parser.rs +++ b/src/frontend/parser.rs @@ -490,6 +490,7 @@ fn statement<'a>(parser: &mut Parser<'a>) -> Option> { TokenKind::LeftBrace => block_statement(parser), TokenKind::Return => return_statement(parser), TokenKind::Record => record_declaration(parser), + TokenKind::Type => alias_declaration(parser), _ => expression_statement(parser), } } @@ -644,7 +645,24 @@ fn record_declaration<'a>(parser: &mut Parser<'a>) -> Option> let name = parser.previous.lexeme.to_string(); let fields = parse_record_fields(parser)?; // TODO: add generic handling - Some(Statement::new_record_decl(name, vec![], fields, line)) + let generics = vec![]; + Some(Statement::new_record_decl(name, generics, fields, line)) +} + +fn alias_declaration<'a>(parser: &mut Parser<'a>) -> Option> { + // advance over the alias keyword + advance(parser); + // advance over the identifier + advance(parser); + let line = parser.previous.line; + let name = parser.previous.lexeme.to_string(); + // TODO: handle generics + let generics = vec![]; + // expect an equals sign + consume(parser, TokenKind::Eq, "Expected an equals sign after alias declaration.")?; + // parse the type + let ty = type_expr(parser)?; + Some(Statement::new_alias_decl(name, generics, ty, line)) } // top level parsing ===== @@ -1003,5 +1021,22 @@ mod tests { ], }); } + + #[test] + fn test_alias_declaration() { + let scanner = Scanner::new("type Vec = [f32; 10]"); + let mut parser = Parser::new(scanner); + let stmt = statement(&mut parser); + assert!(stmt.is_some()); + let alias_decl = stmt.unwrap().node; + assert_eq!(alias_decl, Statement::AliasDecl { + name: "Vec".to_string(), + generics: vec![], + ty: Type::Array { + element_type: Box::new(Type::F32), + size: 10, + }, + }); + } } diff --git a/src/frontend/scanner.rs b/src/frontend/scanner.rs index 2e66994..5a9650e 100644 --- a/src/frontend/scanner.rs +++ b/src/frontend/scanner.rs @@ -46,6 +46,7 @@ pub enum TokenKind { Break, Continue, Record, + Type, // Literals Number, String, @@ -103,6 +104,7 @@ impl Display for TokenKind { TokenKind::Break => write!(f, "break"), TokenKind::Continue => write!(f, "continue"), TokenKind::Record => write!(f, "record"), + TokenKind::Type => write!(f, "type"), TokenKind::Number => write!(f, ""), TokenKind::String => write!(f, ""), TokenKind::Char => write!(f, ""), @@ -215,6 +217,7 @@ impl<'a> Scanner<'a> { "break" => TokenKind::Break, "continue" => TokenKind::Continue, "record" => TokenKind::Record, + "type" => TokenKind::Type, _ => TokenKind::Identifier, }; diff --git a/src/frontend/types.rs b/src/frontend/types.rs index a28f046..8769038 100644 --- a/src/frontend/types.rs +++ b/src/frontend/types.rs @@ -32,14 +32,6 @@ pub enum Type { } } -#[derive(Debug, PartialEq, Clone)] -pub struct Record { - pub name: String, - // type vars - pub generics: Vec, - pub fields: Vec<(String, Type)>, -} - impl Display for Type { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -83,6 +75,37 @@ impl Display for Type { } } +#[derive(Debug, PartialEq, Clone)] +pub enum Udt { + Record { + name: String, + generics: Vec, + fields: Vec<(String, Type)>, + }, + Alias { + name: String, + generics: Vec, + ty: Type, + } +} + +impl Udt { + pub fn new_record(name: String, generics: Vec, fields: Vec<(String, Type)>) -> Self { + Self::Record { name, generics, fields } + } + + pub fn new_alias(name: String, generics: Vec, ty: Type) -> Self { + Self::Alias { name, generics, ty } + } + + pub fn name(&self) -> &str { + match self { + Udt::Record { name, .. } => name, + Udt::Alias { name, .. } => name, + } + } +} + type FnSig = (Vec, Type); #[derive(Clone)] @@ -90,7 +113,7 @@ pub struct TypeEnv { pub fn_sig: Option, parent: Option>>, bindings: HashMap, - record_types: HashMap, + udts: HashMap, } impl TypeEnv { @@ -100,7 +123,7 @@ impl TypeEnv { fn_sig: None, parent: None, bindings: HashMap::new(), - record_types: HashMap::new(), + udts: HashMap::new(), } } @@ -110,7 +133,7 @@ impl TypeEnv { fn_sig: None, parent: Some(Rc::new(RefCell::new(self.clone()))), bindings: HashMap::new(), - record_types: HashMap::new(), + udts: HashMap::new(), } } @@ -120,7 +143,7 @@ impl TypeEnv { fn_sig: Some((params, return_type)), parent: Some(Rc::new(RefCell::new(self.clone()))), bindings: HashMap::new(), - record_types: HashMap::new(), + udts: HashMap::new(), } } @@ -131,8 +154,8 @@ impl TypeEnv { .or_else(|| self.parent.as_ref().and_then(|p| p.borrow().get(ident))) } - pub fn get_record(&self, ident: &str) -> Option { - self.record_types.get(ident).cloned() + pub fn get_udt(&self, ident: &str) -> Option { + self.udts.get(ident).cloned() } /// inserts a type into the top level type environment @@ -144,16 +167,16 @@ impl TypeEnv { } } - pub fn insert_record_top(&mut self, record: Record) -> Result<(), String> { + pub fn insert_udt_top(&mut self, udt: Udt) -> Result<(), String> { if self.parent.is_none() { // ensure that the record name is not already in the type env - if self.record_types.contains_key(&record.name) { - return Err(format!("Record {} already exists in the type environment", record.name)); + if self.udts.contains_key(udt.name()) { + return Err(format!("UDT {} already exists in the type environment", udt.name())); } - self.record_types.insert(record.name.clone(), record); + self.udts.insert(udt.name().to_string(), udt); Ok(()) } else { - self.parent.as_ref().unwrap().borrow_mut().insert_record_top(record) + self.parent.as_ref().unwrap().borrow_mut().insert_udt_top(udt) } } diff --git a/src/passes/type_pass.rs b/src/passes/type_pass.rs index e930298..8b5a877 100644 --- a/src/passes/type_pass.rs +++ b/src/passes/type_pass.rs @@ -1,4 +1,4 @@ -use crate::{Block, Expression, Program, Record, Statement, TokenKind, Type, TypeEnv}; +use crate::{Block, Expression, Program, Statement, TokenKind, Type, TypeEnv, Udt}; // TODO: add handling for type var identifiers in expressions such as let x: TypeVar = 1; then x is a type var and let y = x + 1; will need resolution // TODO: add handling for function calls in which the the callee has type vars in its signature @@ -13,7 +13,10 @@ fn add_udt_pass(program: &Program, type_env: &mut TypeEnv) -> Result { - type_env.insert_record_top(Record { name: name.clone(), generics: generics.clone(), fields: fields.clone() })?; + type_env.insert_udt_top(Udt::Record { name: name.clone(), generics: generics.clone(), fields: fields.clone() })?; + } + Statement::AliasDecl { name, generics, ty } => { + type_env.insert_udt_top(Udt::Alias { name: name.clone(), generics: generics.clone(), ty: ty.clone() })?; } _ => new_program.add_statement(stmt.clone()), } @@ -406,6 +409,22 @@ mod tests { env } + #[test] + fn test_alias_declaration() { + let mut program = Program::new(); + program.add_statement(Statement::new_alias_decl("Vec".to_string(), vec![], Type::Array { element_type: Box::new(Type::I64), size: 10 }, 1)); + let new_env = type_check_program(&program).unwrap(); + assert_eq!(new_env.get_udt("Vec"), Some(Udt::Alias { name: "Vec".to_string(), generics: vec![], ty: Type::Array { element_type: Box::new(Type::I64), size: 10 } })); + } + + #[test] + fn test_record_declaration() { + let mut program = Program::new(); + program.add_statement(Statement::new_record_decl("Point".to_string(), vec![], vec![("x".to_string(), Type::I64), ("y".to_string(), Type::I64)], 1)); + let new_env = type_check_program(&program).unwrap(); + assert_eq!(new_env.get_udt("Point"), Some(Udt::Record { name: "Point".to_string(), generics: vec![], fields: vec![("x".to_string(), Type::I64), ("y".to_string(), Type::I64)] })); + } + #[test] fn test_literal_types() { let mut env = create_test_env();