Skip to content

Commit 0b3ccda

Browse files
authored
Merge pull request #285 from Pivot-Studio/Chronostasys/issue283
feat: add fntype
2 parents acc8782 + 659ba29 commit 0b3ccda

File tree

16 files changed

+1164
-755
lines changed

16 files changed

+1164
-755
lines changed

src/ast/builder/llvmbuilder.rs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use inkwell::{
3131
};
3232
use rustc_hash::FxHashMap;
3333

34-
use crate::ast::{diag::PLDiag, pass::run_immix_pass};
34+
use crate::ast::{diag::PLDiag, pass::run_immix_pass, pltype::ClosureType};
3535

3636
use super::{
3737
super::{
@@ -587,6 +587,22 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> {
587587
fn_type
588588
})
589589
}
590+
591+
fn get_closure_fn_type(&self, closure: &ClosureType, ctx: &mut Ctx<'a>) -> FunctionType<'ctx> {
592+
let params = closure
593+
.arg_types
594+
.iter()
595+
.map(|pltype| {
596+
let tp = self.get_basic_type_op(&pltype.borrow(), ctx).unwrap();
597+
let tp: BasicMetadataTypeEnum = tp.into();
598+
tp
599+
})
600+
.collect::<Vec<_>>();
601+
let fn_type = self
602+
.get_ret_type(&closure.ret_type.borrow(), ctx)
603+
.fn_type(&params, false);
604+
fn_type
605+
}
590606
/// # get_basic_type_op
591607
/// get the basic type of the type
592608
/// used in code generation
@@ -648,6 +664,19 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> {
648664
];
649665
Some(self.context.struct_type(&fields, false).into())
650666
}
667+
PLType::Closure(c) => {
668+
// all closures are represented as a struct with a function pointer and an i8ptr(point to closure data)
669+
let fields = vec![
670+
self.get_closure_fn_type(c, ctx)
671+
.ptr_type(AddressSpace::default())
672+
.into(),
673+
self.context
674+
.i8_type()
675+
.ptr_type(AddressSpace::default())
676+
.into(),
677+
];
678+
Some(self.context.struct_type(&fields, false).into())
679+
}
651680
}
652681
}
653682
/// # get_ret_type
@@ -936,6 +965,7 @@ impl<'a, 'ctx> LLVMBuilder<'a, 'ctx> {
936965
);
937966
Some(tp.as_type())
938967
}
968+
PLType::Closure(_) => self.get_ditype(&PLType::Primitive(PriType::I64), ctx), // TODO
939969
}
940970
}
941971

@@ -1122,6 +1152,9 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> {
11221152
| AnyValueEnum::PointerValue(_)
11231153
| AnyValueEnum::StructValue(_)
11241154
| AnyValueEnum::VectorValue(_) => handle,
1155+
AnyValueEnum::FunctionValue(f) => {
1156+
return Ok(self.get_llvm_value_handle(&f.as_global_value().as_any_value_enum()));
1157+
}
11251158
_ => return Err(ctx.add_diag(range.new_err(ErrorCode::EXPECT_VALUE))),
11261159
})
11271160
} else {
@@ -1251,8 +1284,15 @@ impl<'a, 'ctx> IRBuilder<'a, 'ctx> for LLVMBuilder<'a, 'ctx> {
12511284
let value = self.get_llvm_value(value).unwrap();
12521285
let ptr = self.get_llvm_value(ptr).unwrap();
12531286
let ptr = ptr.into_pointer_value();
1254-
self.builder
1255-
.build_store::<BasicValueEnum>(ptr, value.try_into().unwrap());
1287+
let value = if value.is_function_value() {
1288+
value
1289+
.into_function_value()
1290+
.as_global_value()
1291+
.as_basic_value_enum()
1292+
} else {
1293+
value.try_into().unwrap()
1294+
};
1295+
self.builder.build_store(ptr, value);
12561296
}
12571297
fn build_const_in_bounds_gep(
12581298
&self,

src/ast/ctx.rs

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use super::traits::CustomType;
2323

2424
use crate::ast::builder::BuilderEnum;
2525
use crate::ast::builder::IRBuilder;
26+
use crate::format_label;
2627
use crate::lsp::semantic_tokens::type_index;
2728

2829
use crate::mismatch_err;
@@ -214,14 +215,46 @@ impl<'a, 'ctx> Ctx<'a> {
214215
}
215216
pub fn up_cast<'b>(
216217
&mut self,
217-
trait_pltype: Arc<RefCell<PLType>>,
218-
st_pltype: Arc<RefCell<PLType>>,
219-
trait_range: Range,
220-
st_range: Range,
221-
st_value: usize,
218+
target_pltype: Arc<RefCell<PLType>>,
219+
ori_pltype: Arc<RefCell<PLType>>,
220+
target_range: Range,
221+
ori_range: Range,
222+
ori_value: usize,
222223
builder: &'b BuilderEnum<'a, 'ctx>,
223224
) -> Result<usize, PLDiag> {
224-
if let PLType::Union(u) = &*trait_pltype.borrow() {
225+
if let (PLType::Closure(c), PLType::Fn(f)) =
226+
(&*target_pltype.borrow(), &*ori_pltype.borrow())
227+
{
228+
if f.to_closure_ty(self, builder) != *c {
229+
return Err(ori_range
230+
.new_err(ErrorCode::FUNCTION_TYPE_NOT_MATCH)
231+
.add_label(
232+
target_range,
233+
self.get_file(),
234+
format_label!("expected type `{}`", c.get_name()),
235+
)
236+
.add_label(
237+
ori_range,
238+
self.get_file(),
239+
format_label!("found type `{}`", f.to_closure_ty(self, builder).get_name()),
240+
)
241+
.add_to_ctx(self));
242+
}
243+
if ori_value == usize::MAX {
244+
return Err(ori_range
245+
.new_err(ErrorCode::CANNOT_ASSIGN_INCOMPLETE_GENERICS)
246+
.add_help("try add generic type explicitly to fix this error.")
247+
.add_to_ctx(self));
248+
}
249+
let closure_v = builder.alloc("tmp", &target_pltype.borrow(), self, None);
250+
let closure_f = builder.build_struct_gep(closure_v, 0, "closure_f").unwrap();
251+
let ori_value = builder.try_load2var(ori_range, ori_value, self)?;
252+
// TODO now, we only handle the case that the closure is a pure function.
253+
// TODO the real closure case is leave to the future.
254+
builder.build_store(closure_f, ori_value);
255+
return Ok(closure_v);
256+
}
257+
if let PLType::Union(u) = &*target_pltype.borrow() {
225258
let union_members = self.run_in_type_mod(u, |ctx, u| {
226259
let mut union_members = vec![];
227260
for tp in &u.sum_types {
@@ -231,9 +264,9 @@ impl<'a, 'ctx> Ctx<'a> {
231264
Ok(union_members)
232265
})?;
233266
for (i, tp) in union_members.iter().enumerate() {
234-
if *tp.borrow() == *st_pltype.borrow() {
267+
if *tp.borrow() == *ori_pltype.borrow() {
235268
let union_handle =
236-
builder.alloc("tmp_unionv", &trait_pltype.borrow(), self, None);
269+
builder.alloc("tmp_unionv", &target_pltype.borrow(), self, None);
237270
let union_value = builder
238271
.build_struct_gep(union_handle, 1, "union_value")
239272
.unwrap();
@@ -242,11 +275,11 @@ impl<'a, 'ctx> Ctx<'a> {
242275
.unwrap();
243276
let union_type = builder.int_value(&PriType::U64, i as u64, false);
244277
builder.build_store(union_type_field, union_type);
245-
let mut ptr = st_value;
246-
if !builder.is_ptr(st_value) {
278+
let mut ptr = ori_value;
279+
if !builder.is_ptr(ori_value) {
247280
// mv to heap
248-
ptr = builder.alloc("tmp", &st_pltype.borrow(), self, None);
249-
builder.build_store(ptr, st_value);
281+
ptr = builder.alloc("tmp", &ori_pltype.borrow(), self, None);
282+
builder.build_store(ptr, ori_value);
250283
}
251284
let st_value = builder.bitcast(
252285
self,
@@ -260,20 +293,20 @@ impl<'a, 'ctx> Ctx<'a> {
260293
}
261294
}
262295
}
263-
let (st_pltype, st_value) = self.auto_deref(st_pltype, st_value, builder);
296+
let (st_pltype, st_value) = self.auto_deref(ori_pltype, ori_value, builder);
264297
if let (PLType::Trait(t), PLType::Struct(st)) =
265-
(&*trait_pltype.borrow(), &*st_pltype.borrow())
298+
(&*target_pltype.borrow(), &*st_pltype.borrow())
266299
{
267300
if !st.implements_trait(t, &self.plmod) {
268301
return Err(mismatch_err!(
269302
self,
270-
st_range,
271-
trait_range,
272-
trait_pltype.borrow(),
303+
ori_range,
304+
target_range,
305+
target_pltype.borrow(),
273306
st_pltype.borrow()
274307
));
275308
}
276-
let trait_handle = builder.alloc("tmp_traitv", &trait_pltype.borrow(), self, None);
309+
let trait_handle = builder.alloc("tmp_traitv", &target_pltype.borrow(), self, None);
277310
for f in t.list_trait_fields().iter() {
278311
let mthd = st.find_method(self, &f.name).unwrap();
279312
let fnhandle = builder.get_or_insert_fn_handle(&mthd, self);
@@ -303,9 +336,9 @@ impl<'a, 'ctx> Ctx<'a> {
303336
#[allow(clippy::needless_return)]
304337
return Err(mismatch_err!(
305338
self,
306-
st_range,
307-
trait_range,
308-
trait_pltype.borrow(),
339+
ori_range,
340+
target_range,
341+
target_pltype.borrow(),
309342
st_pltype.borrow()
310343
));
311344
}
@@ -807,6 +840,7 @@ impl<'a, 'ctx> Ctx<'a> {
807840
PLType::Pointer(_) => unreachable!(),
808841
PLType::PlaceHolder(_) => CompletionItemKind::STRUCT,
809842
PLType::Union(_) => CompletionItemKind::ENUM,
843+
PLType::Closure(_) => unreachable!(),
810844
};
811845
if k.starts_with('|') {
812846
// skip method

src/ast/diag.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ define_error!(
120120
INVALID_IS_EXPR = "invalid `is` expression",
121121
INVALID_CAST = "invalid cast",
122122
METHOD_NOT_FOUND = "method not found",
123-
DERIVE_TRAIT_NOT_IMPL = "derive trait not impl"
123+
DERIVE_TRAIT_NOT_IMPL = "derive trait not impl",
124+
CANNOT_ASSIGN_INCOMPLETE_GENERICS = "cannot assign incomplete generic function to variable",
125+
FUNCTION_TYPE_NOT_MATCH = "function type not match",
124126
);
125127
macro_rules! define_warn {
126128
($(
@@ -397,6 +399,9 @@ impl PLDiag {
397399
file: String,
398400
txt: Option<(String, Vec<String>)>,
399401
) -> &mut Self {
402+
if range == Default::default() {
403+
return self;
404+
}
400405
self.raw.labels.push(PLLabel { file, txt, range });
401406
self
402407
}

src/ast/fmt.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ use super::{
2424
string_literal::StringNode,
2525
tuple::{TupleInitNode, TupleTypeNode},
2626
types::{
27-
ArrayInitNode, ArrayTypeNameNode, GenericDefNode, GenericParamNode, PointerTypeNode,
28-
StructDefNode, StructInitFieldNode, StructInitNode, TypeNameNode, TypedIdentifierNode,
27+
ArrayInitNode, ArrayTypeNameNode, ClosureTypeNode, GenericDefNode, GenericParamNode,
28+
PointerTypeNode, StructDefNode, StructInitFieldNode, StructInitNode, TypeNameNode,
29+
TypedIdentifierNode,
2930
},
3031
union::UnionDefNode,
3132
FmtTrait, NodeEnum, TypeNodeEnum,
@@ -768,4 +769,19 @@ impl FmtBuilder {
768769
}
769770
self.r_paren();
770771
}
772+
pub fn parse_closure_type_node(&mut self, node: &ClosureTypeNode) {
773+
self.l_paren();
774+
for (i, ty) in node.arg_types.iter().enumerate() {
775+
if i > 0 {
776+
self.comma();
777+
self.space();
778+
}
779+
ty.format(self);
780+
}
781+
self.r_paren();
782+
self.space();
783+
self.token("=>");
784+
self.space();
785+
node.ret_type.format(self);
786+
}
771787
}

0 commit comments

Comments
 (0)