Skip to content

Commit 99bc1e2

Browse files
Merge pull request #3126 from xlsynth:meheff/2025-09-26-residual-data-pipeline
PiperOrigin-RevId: 813416766
2 parents ad2979f + 68fa805 commit 99bc1e2

File tree

2 files changed

+85
-22
lines changed

2 files changed

+85
-22
lines changed

xls/codegen/block_generator.cc

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,33 @@ class BlockGenerator {
395395
// possible.
396396
XLS_RETURN_IF_ERROR(DeclareInstantiationOutputs());
397397
if (options_.emit_as_pipeline()) {
398+
// If residual data is provided use it to determine a global order of all
399+
// nodes in the block. This is then used to determine the emission order
400+
// of nodes within any particular stage.
401+
bool has_reference_order = false;
402+
std::vector<Node*> global_order;
398403
if (options_.residual_data().has_value()) {
399-
return absl::UnimplementedError(
400-
"Reference residual data is not supported when generating "
401-
"pipelines");
404+
std::vector<int64_t> ref_ids = NodeIdOrderFromResidualData(
405+
block_->name(), *options_.residual_data());
406+
if (!ref_ids.empty()) {
407+
has_reference_order = true;
408+
global_order = StableTopoSort(block_, ref_ids);
409+
}
402410
}
411+
auto get_stage_node_order = [&](absl::Span<Node* const> stage_nodes) {
412+
CHECK(has_reference_order);
413+
absl::flat_hash_set<Node*> stage_nodes_set(stage_nodes.begin(),
414+
stage_nodes.end());
415+
std::vector<Node*> stage_order;
416+
for (Node* node : global_order) {
417+
if (stage_nodes_set.contains(node)) {
418+
stage_order.push_back(node);
419+
}
420+
}
421+
CHECK_EQ(stage_order.size(), stage_nodes.size());
422+
return stage_order;
423+
};
424+
403425
// Emits the block as a sequence of pipeline stages. First reconstruct the
404426
// stages and emit the stages one-by-one. Emitting as a pipeline is purely
405427
// cosmetic relative to the emit_as_pipeline=false option as the Verilog
@@ -415,8 +437,14 @@ class BlockGenerator {
415437
mb_.declaration_section()->Add<BlankLine>(SourceInfo());
416438
mb_.declaration_section()->Add<Comment>(
417439
SourceInfo(), absl::StrFormat("===== Pipe stage %d:", stage_num));
418-
XLS_RETURN_IF_ERROR(EmitLogic(stage.combinational_nodes, stage_num));
419440

441+
if (has_reference_order) {
442+
XLS_RETURN_IF_ERROR(EmitLogic(
443+
get_stage_node_order(stage.combinational_nodes), stage_num));
444+
} else {
445+
XLS_RETURN_IF_ERROR(
446+
EmitLogic(stage.combinational_nodes, stage_num));
447+
}
420448
if (!stage.registers.empty()) {
421449
mb_.NewDeclarationAndAssignmentSections();
422450
mb_.declaration_section()->Add<BlankLine>(SourceInfo());

xls/tools/block_to_verilog_main_test.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,34 @@ class NodeMetadata:
5353
}
5454
'''
5555

56+
TWO_STAGE_BLOCK_IR = '''package add
57+
58+
#[signature("""module_name: "my_function" data_ports { direction: PORT_DIRECTION_INPUT name: "a" width: 32 type { type_enum: BITS bit_count: 32 } }
59+
data_ports { direction: PORT_DIRECTION_INPUT name: "b" width: 32 type { type_enum: BITS bit_count: 32 } }
60+
data_ports { direction: PORT_DIRECTION_OUTPUT name: "out" width: 32 type { type_enum: BITS bit_count: 32 } }
61+
fixed_latency { latency: 1 } """)]
62+
63+
top block my_function(clk: clock, a: bits[32], b: bits[32], out: bits[32]) {
64+
reg not_a_reg(bits[32])
65+
reg not_b_reg(bits[32])
66+
67+
a: bits[32] = input_port(name=a, id=6)
68+
b: bits[32] = input_port(name=b, id=7)
69+
not_a: bits[32] = not(a, id=8)
70+
not_b: bits[32] = not(b, id=9)
71+
72+
not_a_write: () = register_write(not_a, register=not_a_reg, id=10)
73+
not_b_write: () = register_write(not_b, register=not_b_reg, id=11)
74+
not_a_reg: bits[32] = register_read(register=not_a_reg, id=12)
75+
not_b_reg: bits[32] = register_read(register=not_b_reg, id=13)
76+
77+
id_not_a: bits[32] = identity(not_a_reg, id=14)
78+
id_not_b: bits[32] = identity(not_b_reg, id=15)
79+
sum: bits[32] = add(id_not_a, id_not_b, id=16)
80+
out: () = output_port(sum, name=out, id=17)
81+
}
82+
'''
83+
5684
INLINE_OR_IR = '''package inline_or
5785
5886
#[signature("""module_name: "inline_or" data_ports { direction: PORT_DIRECTION_INPUT name: "a" width: 32 type { type_enum: BITS bit_count: 32 } }
@@ -303,34 +331,41 @@ def test_block_ir_residual_roundtrip(self):
303331
with open(verilog1_path, 'r') as f1, open(verilog2_path, 'r') as f2:
304332
self.assertEqual(f1.read(), f2.read())
305333

306-
def test_pipeline_generator_rejects_reference_residual(self):
334+
def test_two_stage_pipeline(self):
307335
# Using reference residual data with pipeline generator should fail.
308-
block_ir_file = self.create_tempfile(content=BLOCK_IR)
336+
block_ir_file = self.create_tempfile(content=TWO_STAGE_BLOCK_IR)
309337

310338
ref_path = self.write_residual_data(
311339
'my_function',
312340
[
313-
NodeMetadata('neg_a', 8),
341+
NodeMetadata('not_a', 8),
314342
NodeMetadata('not_b', 9),
315-
NodeMetadata('id_not_b', 10),
343+
NodeMetadata('id_not_b', 15),
344+
NodeMetadata('id_not_a', 14),
316345
],
317346
'ref_pipeline_order.textproto',
318347
)
348+
verilog_path = test_base.create_named_output_text_file('output.v')
349+
subprocess.check_call([
350+
BLOCK_TO_VERILOG_MAIN_PATH,
351+
'--alsologtostderr',
352+
'--generator=pipeline',
353+
'--reference_residual_data_path=' + ref_path,
354+
'--output_verilog_path=' + verilog_path,
355+
block_ir_file.full_path,
356+
])
319357

320-
proc = subprocess.run(
321-
[
322-
BLOCK_TO_VERILOG_MAIN_PATH,
323-
'--generator=pipeline',
324-
'--reference_residual_data_path=' + ref_path,
325-
block_ir_file.full_path,
326-
],
327-
capture_output=True,
328-
text=True,
329-
check=False,
330-
)
331-
332-
self.assertNotEqual(proc.returncode, 0)
333-
self.assertIn('not supported when generating pipelines', proc.stderr)
358+
with open(verilog_path, 'r') as f:
359+
v = f.read()
360+
print(v)
361+
self.assertLess(
362+
self.find_line_number(v, 'not_a_comb ='),
363+
self.find_line_number(v, 'not_b_comb ='),
364+
)
365+
self.assertLess(
366+
self.find_line_number(v, 'id_not_b_comb ='),
367+
self.find_line_number(v, 'id_not_a_comb ='),
368+
)
334369

335370
def test_inline_all_expressions(self):
336371
ir_file = self.create_tempfile(content=INLINE_OR_IR)

0 commit comments

Comments
 (0)