diff --git a/xls/dslx/frontend/parser_test.cc b/xls/dslx/frontend/parser_test.cc index 7c17c2e425..4776be7841 100644 --- a/xls/dslx/frontend/parser_test.cc +++ b/xls/dslx/frontend/parser_test.cc @@ -2677,6 +2677,27 @@ TEST_F(ParserTest, ModuleWithParametricProcAlias) { proc Bar = Foo<3, 4>;)"); } +TEST_F(ParserTest, ModuleWithParametricProcAliasCallingParametricFn) { + RoundTrip(R"(fn bar(i: uN[Y]) -> uN[Y] { + i + i +} +proc Foo { + c: chan out; + config(output_c: chan out) { + (output_c,) + } + init { + uN[N]:1 + } + next(i: uN[N]) { + let result = bar(i); + let tok = send(join(), c, result); + result + uN[N]:1 + } +} +proc Bar = Foo<16>;)"); +} + TEST_F(ParserTest, ModuleWithPublicParametricProcAlias) { RoundTrip(R"(pub proc Foo { config() {} diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index b498b73626..67ed2479f7 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -3598,14 +3598,18 @@ absl::Status FunctionConverter::HandleProcNextFunction( parametric_type->GetTotalBitCount()); XLS_ASSIGN_OR_RETURN(int64_t bit_count, parametric_width_ctd.GetAsInt64()); Value param_value; - if (parametric_value->IsSigned()) { - XLS_ASSIGN_OR_RETURN(int64_t bit_value, - parametric_value->GetBitValueViaSign()); - param_value = Value(SBits(bit_value, bit_count)); + if (parametric_value->IsBits()) { + if (parametric_value->IsSigned()) { + XLS_ASSIGN_OR_RETURN(int64_t bit_value, + parametric_value->GetBitValueViaSign()); + param_value = Value(SBits(bit_value, bit_count)); + } else { + XLS_ASSIGN_OR_RETURN(uint64_t bit_value, + parametric_value->GetBitValueViaSign()); + param_value = Value(UBits(bit_value, bit_count)); + } } else { - XLS_ASSIGN_OR_RETURN(uint64_t bit_value, - parametric_value->GetBitValueViaSign()); - param_value = Value(UBits(bit_value, bit_count)); + XLS_ASSIGN_OR_RETURN(param_value, InterpValueToValue(*parametric_value)); } DefConst(parametric_binding, param_value); XLS_RETURN_IF_ERROR( diff --git a/xls/dslx/ir_convert/ir_converter_test.cc b/xls/dslx/ir_convert/ir_converter_test.cc index a3e3e169c1..ec271a89c5 100644 --- a/xls/dslx/ir_convert/ir_converter_test.cc +++ b/xls/dslx/ir_convert/ir_converter_test.cc @@ -2694,6 +2694,42 @@ pub proc FooAlias = Foo<16>; ExpectIr(converted); } +TEST_P(IrConverterWithBothTypecheckVersionsTest, + HandlesParametricProcAliasCallingParametricFn) { + if (GetParam() == TypeInferenceVersion::kVersion1) { + // Proc aliases are not supported in TIv1. + return; + } + + constexpr std::string_view program = R"( +fn bar(i: uN[Y]) -> uN[Y] { + i+i +} + +proc Foo { + c: chan out; + init { uN[N]:1 } + config(output_c: chan out) { + (output_c,) + } + next(i: uN[N]) { + let result = bar(i); + let tok = send(join(), c, result); + result + uN[N]:1 + } +} + +pub proc FooAlias = Foo<16>; +)"; + + auto import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(program, "FooAlias", import_data, + kNoPosOptions)); + ExpectIr(converted); +} + TEST_P(IrConverterWithBothTypecheckVersionsTest, HandlesProcAliasToImportedProc) { if (GetParam() == TypeInferenceVersion::kVersion1) { @@ -2730,6 +2766,83 @@ pub proc FooAlias = imported::Foo<16>; ExpectIr(converted); } +TEST_P(IrConverterWithBothTypecheckVersionsTest, + HandlesProcAliasToNonBitsParametricProc) { + if (GetParam() == TypeInferenceVersion::kVersion1) { + // Proc aliases are not supported in TIv1. + return; + } + + constexpr std::string_view program = R"( +struct MyStruct { + a: u32, + b: u32, +} + +pub proc Foo { + c: chan out; + init { uN[CONFIG.a]:1 } + config(output_c: chan out) { + (output_c,) + } + next(i: uN[CONFIG.a]) { + let tok = send(join(), c, i); + i + CONFIG.b as uN[CONFIG.a] + } +} + +const CONFIG = MyStruct{a: u32:32, b: u32:2}; +proc FooAlias = Foo; +)"; + + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(program, "FooAlias", import_data, + kNoPosOptions)); + ExpectIr(converted); +} + +// TODO: google/xls#3353 - enable this test when all the issues have been fixed. +TEST_P(IrConverterWithBothTypecheckVersionsTest, + DISABLED_HandlesSpawnOfNonBitsParametricProc) { + constexpr std::string_view program = R"( +struct MyStruct { + a: u32, + b: u32, +} + +proc Foo { + c: chan out; + init { uN[CONFIG.a]:1 } + config(output_c: chan out) { + (output_c,) + } + next(i: uN[CONFIG.a]) { + let tok = send(join(), c, i); + i + CONFIG.b as uN[CONFIG.a] + } +} + +const CONFIG = MyStruct{a: u32:32, b: u32:2}; +proc Top { + init { () } + config() { + let (p, c) = chan("my_chan"); + spawn Foo(p); + () + } + next(state: ()) { () } +} +)"; + + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(program, "Top", import_data, kNoPosOptions)); + ExpectIr(converted); +} + TEST_P(IrConverterWithBothTypecheckVersionsTest, HandlesProcWithTypeAlias) { constexpr std::string_view program = R"( proc P { @@ -6368,6 +6481,37 @@ pub proc FooAlias = Foo<16>; ExpectIr(converted); } +TEST_P(ProcScopedChannelsIrConverterTest, + ProcScopedParametricProcAliasCallingParametricFn) { + constexpr std::string_view program = R"( +fn bar(i: uN[Y]) -> uN[Y] { + i+i +} + +proc Foo { + c: chan out; + init { uN[N]:1 } + config(output_c: chan out) { + (output_c,) + } + next(i: uN[N]) { + let result = bar(i); + let tok = send(join(), c, result); + result + uN[N]:1 + } +} + +pub proc FooAlias = Foo<16>; +)"; + + auto import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + std::string converted, + ConvertOneFunctionForTest(program, "FooAlias", import_data, + kProcScopedChannelOptions)); + ExpectIr(converted); +} + TEST_P(ProcScopedChannelsIrConverterTest, ProcScopedProcAliasToImportedProc) { ImportData import_data = CreateImportDataForTest(); diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_HandlesParametricProcAliasCallingParametricFn.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_HandlesParametricProcAliasCallingParametricFn.ir new file mode 100644 index 0000000000..c453cb2674 --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_HandlesParametricProcAliasCallingParametricFn.ir @@ -0,0 +1,23 @@ +package test_module + +file_number 0 "test_module.x" + +chan test_module__output_c(bits[16], id=0, kind=streaming, ops=send_only, flow_control=ready_valid, strictness=proven_mutually_exclusive) + +fn __test_module__bar__16(i: bits[16] id=1) -> bits[16] { + Y: bits[5] = literal(value=16, id=2) + ret add.3: bits[16] = add(i, i, id=3) +} + +top proc __test_module__FooAlias_next(__state: bits[16], init={1}) { + __state: bits[16] = state_read(state_element=__state, id=5) + result: bits[16] = invoke(__state, to_apply=__test_module__bar__16, id=8) + literal.11: bits[16] = literal(value=1, id=11) + after_all.9: token = after_all(id=9) + literal.6: bits[1] = literal(value=1, id=6) + add.12: bits[16] = add(result, literal.11, id=12) + __token: token = literal(value=token, id=4) + N: bits[32] = literal(value=16, id=7) + tok: token = send(after_all.9, result, predicate=literal.6, channel=test_module__output_c, id=10) + next_value.13: () = next_value(param=__state, value=add.12, id=13) +} diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_HandlesProcAliasToNonBitsParametricProc.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_HandlesProcAliasToNonBitsParametricProc.ir new file mode 100644 index 0000000000..67da67e2de --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_HandlesProcAliasToNonBitsParametricProc.ir @@ -0,0 +1,18 @@ +package test_module + +file_number 0 "test_module.x" + +chan test_module__output_c(bits[32], id=0, kind=streaming, ops=send_only, flow_control=ready_valid, strictness=proven_mutually_exclusive) + +top proc __test_module__FooAlias_next(__state: bits[32], init={1}) { + CONFIG: (bits[32], bits[32]) = literal(value=(32, 2), id=4) + CONFIG_b: bits[32] = tuple_index(CONFIG, index=1, id=7) + __state: bits[32] = state_read(state_element=__state, id=2) + zero_ext.8: bits[32] = zero_ext(CONFIG_b, new_bit_count=32, id=8) + after_all.5: token = after_all(id=5) + literal.3: bits[1] = literal(value=1, id=3) + add.9: bits[32] = add(__state, zero_ext.8, id=9) + __token: token = literal(value=token, id=1) + tok: token = send(after_all.5, __state, predicate=literal.3, channel=test_module__output_c, id=6) + next_value.10: () = next_value(param=__state, value=add.9, id=10) +} diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_HandlesSpawnOfNonBitsParametricProc.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_HandlesSpawnOfNonBitsParametricProc.ir new file mode 100644 index 0000000000..e69de29bb2 diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_ProcScopedParametricProcAliasCallingParametricFn.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_ProcScopedParametricProcAliasCallingParametricFn.ir new file mode 100644 index 0000000000..5847efd010 --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_ProcScopedParametricProcAliasCallingParametricFn.ir @@ -0,0 +1,23 @@ +package test_module + +file_number 0 "test_module.x" + +fn __test_module__bar__16(i: bits[16] id=1) -> bits[16] { + Y: bits[5] = literal(value=16, id=2) + ret add.3: bits[16] = add(i, i, id=3) +} + +top proc __test_module__FooAlias_next(__state: bits[16], init={1}) { + chan_interface output_c(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + __state: bits[16] = state_read(state_element=__state, id=5) + result: bits[16] = invoke(__state, to_apply=__test_module__bar__16, id=9) + literal.12: bits[16] = literal(value=1, id=12) + after_all.10: token = after_all(id=10) + literal.6: bits[1] = literal(value=1, id=6) + add.13: bits[16] = add(result, literal.12, id=13) + __token: token = literal(value=token, id=4) + N: bits[32] = literal(value=16, id=7) + tuple.8: () = tuple(id=8) + tok: token = send(after_all.10, result, predicate=literal.6, channel=output_c, id=11) + next_value.14: () = next_value(param=__state, value=add.13, id=14) +} diff --git a/xls/dslx/type_system_v2/inference_table.cc b/xls/dslx/type_system_v2/inference_table.cc index 3336e49c71..8c64d1b9eb 100644 --- a/xls/dslx/type_system_v2/inference_table.cc +++ b/xls/dslx/type_system_v2/inference_table.cc @@ -395,6 +395,7 @@ class InferenceTableImpl : public InferenceTable { ParametricContext* result = context.get(); parametric_contexts_.push_back(std::move(context)); mutable_parametric_context_data_.emplace(result, std::move(mutable_data)); + SetParametricEnv(result, env); return result; }