Skip to content

Commit

Permalink
[ESI][BSP] Gearboxing the hostmem write path (#8136)
Browse files Browse the repository at this point in the history
Take client data and gearbox it up or down to the underlying host
connection width. In the case where we must gearbox down, this creates
multiple hostmem write transactions per client request.
  • Loading branch information
teqdruid authored Jan 28, 2025
1 parent ce67b00 commit ebfb49b
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 75 deletions.
95 changes: 50 additions & 45 deletions frontends/PyCDE/integration_test/esitester.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pycde.esi as esi
from pycde.types import Bits, Channel, UInt

import typing
import sys

# CHECK: [INFO] [CONNECT] connecting to backend
Expand Down Expand Up @@ -125,54 +126,58 @@ def construct(ports):
return ReadMem


class WriteMem(Module):
"""Writes a cycle count to host memory at address 0 in MMIO upon each MMIO
transaction."""
clk = Clock()
rst = Reset()
def WriteMem(width: int) -> typing.Type['WriteMem']:
  """Build a WriteMem test module parameterized on the write data width.

  Args:
    width: bit width of the cycle counter whose value is written to host
      memory.

  Returns:
    A `Module` subclass which exposes an MMIO read/write region
    ("WriteMem"). On each MMIO transaction it latches the written value as a
    host-memory address (register at offset 0) and issues a hostmem write of
    a free-running `width`-bit cycle counter to that address, tagged with a
    rolling 8-bit tag.
  """

  class WriteMem(Module):
    """Writes a cycle count to host memory at address 0 in MMIO upon each MMIO
    transaction."""
    clk = Clock()
    rst = Reset()

    @generator
    def construct(ports):
      # The MMIO command channel is only available after the response
      # channel has been built, so break the cycle with wires.
      cmd_chan_wire = Wire(Channel(esi.MMIOReadWriteCmdType))
      resp_ready_wire = Wire(Bits(1))
      cmd, cmd_valid = cmd_chan_wire.unwrap(resp_ready_wire)
      mmio_xact = cmd_valid & resp_ready_wire

      # A write to MMIO offset 0 latches the host memory address to which
      # the cycle counter should be written.
      write_loc_ce = mmio_xact & cmd.write & (cmd.offset == UInt(32)(0))
      write_loc = Reg(UInt(64),
                      clk=ports.clk,
                      rst=ports.rst,
                      rst_value=0,
                      ce=write_loc_ce)
      write_loc.assign(cmd.data.as_uint())

      # MMIO reads return the currently-latched write location.
      response_data = write_loc.as_bits()
      response_chan, response_ready = Channel(response_data.type).wrap(
          response_data, cmd_valid)
      resp_ready_wire.assign(response_ready)

      mmio_rw = esi.MMIO.read_write(appid=AppID("WriteMem"))
      mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
      cmd_chan_wire.assign(mmio_rw_cmd_chan)

      # Rolling 8-bit tag, bumped on every MMIO transaction.
      tag = Counter(8)(clk=ports.clk,
                       rst=ports.rst,
                       clear=Bits(1)(0),
                       increment=mmio_xact)

      # Free-running cycle counter supplies the data to be written.
      cycle_counter = Counter(width)(clk=ports.clk,
                                     rst=ports.rst,
                                     clear=Bits(1)(0),
                                     increment=Bits(1)(1))

      # Issue the hostmem write one cycle after the MMIO transaction so the
      # latched write location is stable.
      hostmem_write_req, _ = esi.HostMem.wrap_write_req(
          write_loc,
          cycle_counter.out.as_bits(),
          tag.out,
          valid=mmio_xact.reg(ports.clk, ports.rst))

      hostmem_write_resp = esi.HostMem.write(appid=AppID("WriteMem_hostwrite"),
                                             req=hostmem_write_req)

  return WriteMem


class EsiTesterTop(Module):
Expand All @@ -185,7 +190,7 @@ def construct(ports):
ReadMem(32)(appid=esi.AppID("readmem", 32), clk=ports.clk, rst=ports.rst)
ReadMem(64)(appid=esi.AppID("readmem", 64), clk=ports.clk, rst=ports.rst)
ReadMem(96)(appid=esi.AppID("readmem", 96), clk=ports.clk, rst=ports.rst)
WriteMem(clk=ports.clk, rst=ports.rst)
WriteMem(96)(clk=ports.clk, rst=ports.rst)


if __name__ == "__main__":
Expand Down
151 changes: 127 additions & 24 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,13 @@ def build_addr_read(


@modparams
def TaggedGearbox(input_bitwidth: int,
output_bitwidth: int) -> type["TaggedGearboxImpl"]:
def TaggedReadGearbox(input_bitwidth: int,
output_bitwidth: int) -> type["TaggedReadGearboxImpl"]:
"""Build a gearbox to convert the upstream data to the client data
type. Assumes a struct {tag, data} and only gearboxes the data. Tag is stored
separately and the struct is re-assembled later on."""

class TaggedGearboxImpl(Module):
class TaggedReadGearboxImpl(Module):
clk = Clock()
rst = Reset()
in_ = InputChannel(
Expand Down Expand Up @@ -331,15 +331,16 @@ def build(ports):
ports.rst,
ce=upstream_xact,
name="tag_reg")
client_channel, client_ready = TaggedGearboxImpl.out.type.wrap(

client_channel, client_ready = TaggedReadGearboxImpl.out.type.wrap(
{
"tag": tag_reg,
"data": client_data_bits,
}, client_valid)
ready_for_upstream.assign(client_ready)
ports.out = client_channel

return TaggedGearboxImpl
return TaggedReadGearboxImpl


def HostmemReadProcessor(read_width: int, hostmem_module,
Expand Down Expand Up @@ -415,7 +416,7 @@ def build(ports):

# Gearbox the data to the client's data type.
client_type = resp_type.inner_type
gearbox = TaggedGearbox(read_width, client_type.data.bitwidth)(
gearbox = TaggedReadGearbox(read_width, client_type.data.bitwidth)(
clk=ports.clk, rst=ports.rst, in_=demuxed_upstream_channel)
client_resp_channel = gearbox.out.transform(lambda m: client_type({
"tag": m.tag,
Expand Down Expand Up @@ -448,6 +449,101 @@ def build(ports):
return HostmemReadProcessorImpl


@modparams
def TaggedWriteGearbox(input_bitwidth: int,
                       output_bitwidth: int) -> type["TaggedWriteGearboxImpl"]:
  """Build a gearbox to convert the client data to upstream write chunks.
  Assumes a struct {address, tag, data} and only gearboxes the data. Tag is
  stored separately and the struct is re-assembled later on.

  Raises:
    ValueError: if output_bitwidth is not a multiple of 8 (upstream writes
      are byte-addressed).
  """

  if output_bitwidth % 8 != 0:
    raise ValueError("Output bitwidth must be a multiple of 8.")

  class TaggedWriteGearboxImpl(Module):
    clk = Clock()
    rst = Reset()
    in_ = InputChannel(
        StructType([
            ("address", UInt(64)),
            ("tag", esi.HostMem.TagType),
            ("data", Bits(input_bitwidth)),
        ]))
    out = OutputChannel(
        StructType([
            ("address", UInt(64)),
            ("tag", esi.HostMem.TagType),
            ("data", Bits(output_bitwidth)),
        ]))

    @generator
    def build(ports):
      upstream_ready = Wire(Bits(1))
      ready_for_client = Wire(Bits(1))
      client_tag_and_data, client_valid = ports.in_.unwrap(ready_for_client)
      client_data = client_tag_and_data.data
      client_xact = ready_for_client & client_valid

      # Determine if gearboxing is necessary and whether it needs to be
      # gearboxed up or just sliced down. BUGFIX: tag/address selection is
      # now per-branch. The previous shared logic referenced `counter`,
      # which only exists in the gearbox-down branch (NameError for equal/up
      # widths), and used `ce=client_xact` registers whose contents lag the
      # combinational data/valid by one cycle in the pass-through paths.
      if output_bitwidth == input_bitwidth:
        # Pass-through: data, tag, and address travel together unmodified.
        upstream_data_bits = client_data
        upstream_valid = client_valid
        ready_for_client.assign(upstream_ready)
        upstream_tag = client_tag_and_data.tag
        upstream_addr = client_tag_and_data.address
      elif output_bitwidth > input_bitwidth:
        # Gearbox up: zero-extend the client data to the upstream width.
        upstream_data_bits = client_data.as_bits(output_bitwidth)
        upstream_valid = client_valid
        ready_for_client.assign(upstream_ready)
        upstream_tag = client_tag_and_data.tag
        upstream_addr = client_tag_and_data.address
      else:
        # Gearbox down: latch the client data into registers equal to the
        # number of upstream transactions needed to complete the
        # transmission, then emit one upstream chunk per cycle.
        num_chunks = ceil(input_bitwidth / output_bitwidth)
        num_chunks_idx_bitwidth = clog2(num_chunks)
        padding = Bits(output_bitwidth - (input_bitwidth % output_bitwidth))(0)
        client_data_padded = BitsSignal.concat([padding, client_data])
        chunks = [
            client_data_padded[i * output_bitwidth:(i + 1) * output_bitwidth]
            for i in range(num_chunks)
        ]
        chunk_regs = Array(Bits(output_bitwidth), num_chunks)([
            c.reg(ports.clk, ce=client_xact, name=f"chunk_{idx}")
            for idx, c in enumerate(chunks)
        ])
        increment = Wire(Bits(1))
        clear = Wire(Bits(1))
        counter = Counter(num_chunks_idx_bitwidth)(clk=ports.clk,
                                                   rst=ports.rst,
                                                   increment=increment,
                                                   clear=clear)
        upstream_data_bits = chunk_regs[counter.out]
        # Valid from the cycle after the client transaction until the last
        # chunk has been accepted upstream.
        upstream_valid = ControlReg(ports.clk, ports.rst, [client_xact],
                                    [clear])
        upstream_xact = upstream_valid & upstream_ready
        clear.assign(upstream_xact & (counter.out == (num_chunks - 1)))
        increment.assign(upstream_xact)
        # Refuse new client requests while chunks are still being emitted.
        ready_for_client.assign(~upstream_valid)

        # Tag and address are latched alongside the chunk registers since
        # the client transaction completes before the chunks go out.
        upstream_tag = client_tag_and_data.tag.reg(ports.clk,
                                                   ce=client_xact,
                                                   name="tag_reg")
        addr_reg = client_tag_and_data.address.reg(ports.clk,
                                                   ce=client_xact,
                                                   name="address_reg")
        # Each successive chunk lands 8 bytes past the previous one.
        # NOTE(review): the 3-bit shift multiplies the chunk index by 8
        # bytes, which assumes 64-bit upstream chunks — confirm for other
        # output widths.
        counter_bytes = BitsSignal.concat([counter.out.as_bits(),
                                           Bits(3)(0)]).as_uint()
        upstream_addr = (addr_reg + counter_bytes).as_uint(64)

      # Construct the output channel. Shared across all three cases.
      upstream_channel, upstrm_ready_sig = TaggedWriteGearboxImpl.out.type.wrap(
          {
              "address": upstream_addr,
              "tag": upstream_tag,
              "data": upstream_data_bits,
          }, upstream_valid)
      upstream_ready.assign(upstrm_ready_sig)
      ports.out = upstream_channel

  return TaggedWriteGearboxImpl


def HostMemWriteProcessor(
write_width: int, hostmem_module,
reqs: List[esi._OutputBundleSetter]) -> type["HostMemWriteProcessorImpl"]:
Expand Down Expand Up @@ -490,40 +586,47 @@ def build(ports):

# TODO: mux together multiple write clients.
assert len(reqs) == 1, "Only one write client supported for now."
upstream_req_channel = Wire(Channel(hostmem_module.UpstreamWriteReq))
upstream_write_bundle, froms = hostmem_module.write.type.pack(
req=upstream_req_channel)
ports.upstream = upstream_write_bundle
upstream_ack_tag = froms["ackTag"]
# TODO: re-write the tags and store the client and client tag.

# Build the write request channels and ack wires.
write_channels: List[ChannelSignal] = []
write_acks = []
for req in reqs:
# Get the request channel and its data type.
reqch = [c.channel for c in req.type.channels if c.name == 'req'][0]
data_type = reqch.inner_type.data
assert data_type == Bits(
write_width
), f"Gearboxing not yet supported. Client {req.client_name}"
client_type = reqch.inner_type

# Write acks to be filled in later.
write_ack = Wire(Channel(UInt(8)))
write_acks.append(write_ack)
write_ack = upstream_ack_tag

# Pack up the bundle and assign the request channel.
write_req_bundle_type = esi.HostMem.write_req_bundle_type(data_type)
write_req_bundle_type = esi.HostMem.write_req_bundle_type(
client_type.data)
bundle_sig, froms = write_req_bundle_type.pack(ackTag=write_ack)

tagged_client_req = froms["req"]
bitcast_client_req = tagged_client_req.transform(lambda m: client_type({
"tag": m.tag,
"address": m.address,
"data": m.data.bitcast(client_type.data)
}))

# Gearbox the data to the client's data type.
gearbox = TaggedWriteGearbox(client_type.data.bitwidth,
write_width)(clk=ports.clk,
rst=ports.rst,
in_=bitcast_client_req)
write_channels.append(gearbox.out)
# Set the port for the client request.
setattr(ports, HostMemWriteProcessorImpl.reqPortMap[req], bundle_sig)
write_channels.append(tagged_client_req)

# TODO: re-write the tags and store the client and client tag.

# Build a channel mux for the write requests.
tagged_write_channel = esi.ChannelMux(write_channels)
upstream_write_bundle, froms = hostmem_module.write.type.pack(
req=tagged_write_channel)
ack_tag = froms["ackTag"]
# TODO: decode the ack tag and assign it to the correct client.
write_acks[0].assign(ack_tag)
ports.upstream = upstream_write_bundle
muxed_write_channel = esi.ChannelMux(write_channels)
upstream_req_channel.assign(muxed_write_channel)

return HostMemWriteProcessorImpl

Expand Down
4 changes: 4 additions & 0 deletions frontends/PyCDE/src/pycde/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,8 @@ def __getitem__(self, idx: Union[int, BitVectorSignal]) -> Signal:
_validate_idx(self.type.size, idx)
from .dialects import hw
with get_user_loc():
if isinstance(idx, UIntSignal):
idx = idx.as_bits()
v = hw.ArrayGetOp(self.value, idx)
if self.name and isinstance(idx, int):
v.name = self.name + f"__{idx}"
Expand All @@ -560,6 +562,8 @@ def __get_item__slice(self, s: slice):
idxs = s.indices(len(self))
if idxs[2] != 1:
raise ValueError("Array slices do not support steps")
if not isinstance(idxs[0], int) or not isinstance(idxs[1], int):
raise ValueError("Array slices must be constant ints")

from .types import types
from .dialects import hw
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/ESI/runtime/cosim_dpi_server/esi-cosim.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ def compile_commands(self) -> List[List[str]]:
"-DTOP_MODULE=" + self.sources.top,
]
if self.debug:
cmd += ["--trace", "--trace-params", "--trace-structs"]
cmd += [
"--trace", "--trace-params", "--trace-structs", "--trace-underscore"
]
cflags.append("-DTRACE")
if len(cflags) > 0:
cmd += ["-CFLAGS", " ".join(cflags)]
Expand Down
23 changes: 18 additions & 5 deletions lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,32 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion &region,
}

// Initiate a test write.
if (write) {
// TODO: remove the width == 96 once multiplexing support is added.
if (write && width == 96) {
auto check = [&]() {
region.flush();
for (size_t i = 0, e = (width + 63) / 64; i < e; ++i)
if (dataPtr[i] == 0xFFFFFFFFFFFFFFFFull)
return false;
return true;
};

auto *writeMem = acc->getPorts()
.at(AppID("WriteMem"))
.getAs<services::MMIO::MMIORegion>();
*dataPtr = 0;
writeMem->write(0, (uint64_t)dataPtr);
for (size_t i = 0, e = (width + 63) / 64; i < e; ++i)
dataPtr[i] = 0xFFFFFFFFFFFFFFFFull;
region.flush();
// Command the accelerator to write to 'devicePtr', the pointer which the
// device should use for 'dataPtr'.
writeMem->write(0, reinterpret_cast<uint64_t>(devicePtr));
// Wait for the accelerator to write. Timeout and fail after 10ms.
for (int i = 0; i < 100; ++i) {
if (*dataPtr != 0)
if (check())
break;
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
if (*dataPtr == 0)
if (!check())
throw std::runtime_error("DMA write test failed");
}
}
Expand Down

0 comments on commit ebfb49b

Please sign in to comment.