Skip to content

Commit

Permalink
[ESI][BSP] Multiplexing hostmem write clients (#8137)
Browse files Browse the repository at this point in the history
Take multiple hostmem write clients and multiplex them post-gearbox. This
approach won't work long-term, since it precludes bursting multiple
transactions, which is necessary to get good PCIe performance.
  • Loading branch information
teqdruid authored Jan 28, 2025
1 parent ebfb49b commit 17c7fc8
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 27 deletions.
5 changes: 4 additions & 1 deletion frontends/PyCDE/integration_test/esitester.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def construct(ports):
return ReadMem


@modparams
def WriteMem(width: int) -> typing.Type['WriteMem']:

class WriteMem(Module):
Expand Down Expand Up @@ -190,7 +191,9 @@ def construct(ports):
ReadMem(32)(appid=esi.AppID("readmem", 32), clk=ports.clk, rst=ports.rst)
ReadMem(64)(appid=esi.AppID("readmem", 64), clk=ports.clk, rst=ports.rst)
ReadMem(96)(appid=esi.AppID("readmem", 96), clk=ports.clk, rst=ports.rst)
WriteMem(96)(clk=ports.clk, rst=ports.rst)
WriteMem(32)(appid=esi.AppID("writemem", 32), clk=ports.clk, rst=ports.rst)
WriteMem(64)(appid=esi.AppID("writemem", 64), clk=ports.clk, rst=ports.rst)
WriteMem(96)(appid=esi.AppID("writemem", 96), clk=ports.clk, rst=ports.rst)


if __name__ == "__main__":
Expand Down
53 changes: 33 additions & 20 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,14 @@ def build(ports):
upstream_data_bits = client_data
upstream_valid = client_valid
ready_for_client.assign(upstream_ready)
tag = client_tag_and_data.tag
address = client_tag_and_data.address
elif output_bitwidth > input_bitwidth:
upstream_data_bits = client_data.as_bits(output_bitwidth)
upstream_valid = client_valid
ready_for_client.assign(upstream_ready)
tag = client_tag_and_data.tag
address = client_tag_and_data.address
else:
# Create registers equal to the number of upstream transactions needed
# to complete the transmission.
Expand Down Expand Up @@ -521,21 +525,23 @@ def build(ports):
clear.assign(upstream_xact & (counter.out == (num_chunks - 1)))
increment.assign(upstream_xact)
ready_for_client.assign(~upstream_valid)
counter_bytes = BitsSignal.concat([counter.out.as_bits(),
Bits(3)(0)]).as_uint()

# Construct the output channel. Shared logic across all three cases.
tag_reg = client_tag_and_data.tag.reg(ports.clk,
ce=client_xact,
name="tag_reg")
addr_reg = client_tag_and_data.address.reg(ports.clk,
ce=client_xact,
name="address_reg")
address = (addr_reg + counter_bytes).as_uint(64)
tag = tag_reg

# Construct the output channel. Shared logic across all three cases.
tag_reg = client_tag_and_data.tag.reg(ports.clk,
ce=client_xact,
name="tag_reg")
addr_reg = client_tag_and_data.address.reg(ports.clk,
ce=client_xact,
name="address_reg")
counter_bytes = BitsSignal.concat([counter.out.as_bits(),
Bits(3)(0)]).as_uint()
addr_incremented = (addr_reg + counter_bytes).as_uint(64)
upstream_channel, upstrm_ready_sig = TaggedWriteGearboxImpl.out.type.wrap(
{
"address": addr_incremented,
"tag": tag_reg,
"address": address,
"tag": tag,
"data": upstream_data_bits,
}, upstream_valid)
upstream_ready.assign(upstrm_ready_sig)
Expand Down Expand Up @@ -584,29 +590,31 @@ def build(ports):
ports.upstream = write_bundle
return

# TODO: mux together multiple write clients.
assert len(reqs) == 1, "Only one write client supported for now."
assert len(reqs) <= 256, "More than 256 write clients not supported."

upstream_req_channel = Wire(Channel(hostmem_module.UpstreamWriteReq))
upstream_write_bundle, froms = hostmem_module.write.type.pack(
req=upstream_req_channel)
ports.upstream = upstream_write_bundle
upstream_ack_tag = froms["ackTag"]

demuxed_acks = esi.TaggedDemux(len(reqs), upstream_ack_tag.type)(
clk=ports.clk, rst=ports.rst, in_=upstream_ack_tag)

# TODO: re-write the tags and store the client and client tag.

# Build the write request channels and ack wires.
write_channels: List[ChannelSignal] = []
for req in reqs:
for idx, req in enumerate(reqs):
# Get the request channel and its data type.
reqch = [c.channel for c in req.type.channels if c.name == 'req'][0]
client_type = reqch.inner_type

# Write acks to be filled in later.
write_ack = upstream_ack_tag

# Pack up the bundle and assign the request channel.
write_req_bundle_type = esi.HostMem.write_req_bundle_type(
client_type.data)
bundle_sig, froms = write_req_bundle_type.pack(ackTag=write_ack)
bundle_sig, froms = write_req_bundle_type.pack(
ackTag=demuxed_acks.get_out(idx))

tagged_client_req = froms["req"]
bitcast_client_req = tagged_client_req.transform(lambda m: client_type({
Expand All @@ -620,7 +628,12 @@ def build(ports):
write_width)(clk=ports.clk,
rst=ports.rst,
in_=bitcast_client_req)
write_channels.append(gearbox.out)
write_channels.append(
gearbox.out.transform(lambda m: m.type({
"address": m.address,
"tag": idx,
"data": m.data,
})))
# Set the port for the client request.
setattr(ports, HostMemWriteProcessorImpl.reqPortMap[req], bundle_sig)

Expand Down
12 changes: 10 additions & 2 deletions frontends/PyCDE/src/pycde/esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,8 @@ def TaggedDemux(num_clients: int,
channel_type: Channel) -> typing.Type["TaggedDemuxImpl"]:
"""Construct a tagged demultiplexer for a given tagged data type.
'channel_type' is assumed to carry a struct with a 'tag' field and a 'data'
field. Demux the data to the appropriate output channel based on the tag."""
field OR a UInt representing the tag itself. Demux the data to the appropriate
output channel based on the tag."""

class TaggedDemuxImpl(Module):
clk = Clock()
Expand All @@ -812,10 +813,17 @@ def get_out(self, idx: int) -> ChannelSignal:
def build(ports) -> None:
upstream_ready_wire = Wire(Bits(1))
upstream_data, upstream_valid = ports.in_.unwrap(upstream_ready_wire)
upstream_data_type = upstream_data.type

upstream_ready = Bits(1)(1)
for idx in range(num_clients):
output_valid = upstream_valid & (upstream_data.tag == UInt(8)(idx))
if isinstance(upstream_data_type, StructType):
tag = upstream_data.tag
elif isinstance(upstream_data_type, UInt):
tag = upstream_data
else:
raise TypeError("TaggedDemux input must be a struct or UInt.")
output_valid = upstream_valid & (tag == UInt(8)(idx))
output_ch, output_ready = channel_type.wrap(upstream_data, output_valid)
setattr(ports, TaggedDemuxImpl.output_names[idx], output_ch)
upstream_ready = upstream_ready & output_ready
Expand Down
17 changes: 13 additions & 4 deletions lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion &region,

// Initiate a test write.
// The `width == 96` restriction is no longer needed now that write clients are multiplexed.
if (write && width == 96) {
if (write) {
auto check = [&]() {
region.flush();
for (size_t i = 0, e = (width + 63) / 64; i < e; ++i)
Expand All @@ -159,9 +159,18 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion &region,
return true;
};

auto *writeMem = acc->getPorts()
.at(AppID("WriteMem"))
.getAs<services::MMIO::MMIORegion>();
auto writeMemChildIter = acc->getChildren().find(AppID("writemem", width));
if (writeMemChildIter == acc->getChildren().end())
throw std::runtime_error("DMA test failed. No writemem child found");
auto &writeMemPorts = writeMemChildIter->second->getPorts();
auto writeMemPortIter = writeMemPorts.find(AppID("WriteMem"));
if (writeMemPortIter == writeMemPorts.end())
throw std::runtime_error("DMA test failed. No WriteMem port found");
auto *writeMem =
writeMemPortIter->second.getAs<services::MMIO::MMIORegion>();
if (!writeMem)
throw std::runtime_error("DMA test failed. WriteMem port is not MMIO");

for (size_t i = 0, e = (width + 63) / 64; i < e; ++i)
dataPtr[i] = 0xFFFFFFFFFFFFFFFFull;
region.flush();
Expand Down

0 comments on commit 17c7fc8

Please sign in to comment.