Skip to content

Commit

Permalink
[ESI][BSP] Adding byte enables to cosim hostmem
Browse files Browse the repository at this point in the history
Writes can be narrower than the data width of the upstream port, so we
must support writing only part of the data width (each write request now
carries a `valid_bytes` count).
  • Loading branch information
teqdruid committed Jan 28, 2025
1 parent 17c7fc8 commit 880bdcd
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 15 deletions.
33 changes: 27 additions & 6 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from math import ceil

from ..common import Clock, Input, InputChannel, Output, OutputChannel, Reset
from ..constructs import (AssignableSignal, ControlReg, Counter, NamedWire, Reg,
Wire)
from ..constructs import (AssignableSignal, ControlReg, Counter, Mux, NamedWire,
Reg, Wire)
from .. import esi
from ..module import Module, generator, modparams
from ..signals import BitsSignal, BundleSignal, ChannelSignal
Expand Down Expand Up @@ -458,6 +458,8 @@ def TaggedWriteGearbox(input_bitwidth: int,

if output_bitwidth % 8 != 0:
raise ValueError("Output bitwidth must be a multiple of 8.")
if input_bitwidth % 8 != 0:
raise ValueError("Input bitwidth must be a multiple of 8.")

class TaggedWriteGearboxImpl(Module):
clk = Clock()
Expand All @@ -473,15 +475,20 @@ class TaggedWriteGearboxImpl(Module):
("address", UInt(64)),
("tag", esi.HostMem.TagType),
("data", Bits(output_bitwidth)),
("valid_bytes", Bits(8)),
]))

num_chunks = ceil(input_bitwidth / output_bitwidth)

@generator
def build(ports):
upstream_ready = Wire(Bits(1))
ready_for_client = Wire(Bits(1))
client_tag_and_data, client_valid = ports.in_.unwrap(ready_for_client)
client_data = client_tag_and_data.data
client_xact = ready_for_client & client_valid
input_bitwidth_bytes = input_bitwidth // 8
output_bitwidth_bytes = output_bitwidth // 8

# Determine if gearboxing is necessary and whether it needs to be
# gearboxed up or just sliced down.
Expand All @@ -491,19 +498,23 @@ def build(ports):
ready_for_client.assign(upstream_ready)
tag = client_tag_and_data.tag
address = client_tag_and_data.address
valid_bytes = Bits(8)(input_bitwidth_bytes)
elif output_bitwidth > input_bitwidth:
upstream_data_bits = client_data.as_bits(output_bitwidth)
upstream_valid = client_valid
ready_for_client.assign(upstream_ready)
tag = client_tag_and_data.tag
address = client_tag_and_data.address
valid_bytes = Bits(8)(input_bitwidth_bytes)
else:
# Create registers equal to the number of upstream transactions needed
# to complete the transmission.
num_chunks = ceil(input_bitwidth / output_bitwidth)
num_chunks = TaggedWriteGearboxImpl.num_chunks
num_chunks_idx_bitwidth = clog2(num_chunks)
padding = Bits(output_bitwidth - (input_bitwidth % output_bitwidth))(0)
client_data_padded = BitsSignal.concat([padding, client_data])
padding_numbits = output_bitwidth - (input_bitwidth % output_bitwidth)
assert padding_numbits % 8 == 0, "Padding must be a multiple of 8."
client_data_padded = BitsSignal.concat(
[Bits(padding_numbits)(0), client_data])
chunks = [
client_data_padded[i * output_bitwidth:(i + 1) * output_bitwidth]
for i in range(num_chunks)
Expand Down Expand Up @@ -537,12 +548,16 @@ def build(ports):
name="address_reg")
address = (addr_reg + counter_bytes).as_uint(64)
tag = tag_reg
valid_bytes = Mux(counter.out == (num_chunks - 1),
Bits(8)(output_bitwidth_bytes),
Bits(8)(padding_numbits // 8))

upstream_channel, upstrm_ready_sig = TaggedWriteGearboxImpl.out.type.wrap(
{
"address": address,
"tag": tag,
"data": upstream_data_bits,
"valid_bytes": valid_bytes
}, upstream_valid)
upstream_ready.assign(upstrm_ready_sig)
ports.out = upstream_channel
Expand Down Expand Up @@ -584,7 +599,8 @@ def build(ports):
{
"address": 0,
"tag": 0,
"data": 0
"data": 0,
"valid_bytes": 0,
}, 0)
write_bundle, _ = hostmem_module.write.type.pack(req=req)
ports.upstream = write_bundle
Expand Down Expand Up @@ -633,6 +649,7 @@ def build(ports):
"address": m.address,
"tag": idx,
"data": m.data,
"valid_bytes": m.valid_bytes
})))
# Set the port for the client request.
setattr(ports, HostMemWriteProcessorImpl.reqPortMap[req], bundle_sig)
Expand Down Expand Up @@ -670,10 +687,14 @@ class ChannelHostMemImpl(esi.ServiceImplementation):
("data", Bits(read_width)),
])),
]))

if write_width % 8 != 0:
raise ValueError("Write width must be a multiple of 8.")
UpstreamWriteReq = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", Bits(write_width)),
("valid_bytes", Bits(8)),
])
write = Output(
Bundle([
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/ESI/runtime/cpp/include/esi/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ std::ostream &operator<<(std::ostream &, const esi::AppID &);
//===----------------------------------------------------------------------===//

namespace esi {
std::string toHex(uint32_t val);
std::string toHex(uint64_t val);
} // namespace esi

#endif // ESI_COMMON_H
2 changes: 1 addition & 1 deletion lib/Dialect/ESI/runtime/cpp/lib/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::string MessageData::toHex() const {
return ss.str();
}

std::string esi::toHex(uint32_t val) {
std::string esi::toHex(uint64_t val) {
std::ostringstream ss;
ss << std::hex << val;
return ss.str();
Expand Down
10 changes: 7 additions & 3 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ struct HostMemReadResp {
};

struct HostMemWriteReq {
uint8_t valid_bytes;
uint64_t data;
uint8_t tag;
uint64_t address;
Expand Down Expand Up @@ -540,10 +541,13 @@ class CosimHostMem : public HostMem {
std::unique_ptr<std::map<std::string, std::any>> &details) {
subsystem = "HostMem";
msg = "Write request: addr=0x" + toHex(req->address) + " data=0x" +
toHex(req->data) + " tag=" + std::to_string(req->tag);
toHex(req->data) +
" valid_bytes=" + std::to_string(req->valid_bytes) +
" tag=" + std::to_string(req->tag);
});
uint64_t *dataPtr = reinterpret_cast<uint64_t *>(req->address);
*dataPtr = req->data;
uint8_t *dataPtr = reinterpret_cast<uint8_t *>(req->address);
for (uint8_t i = 0; i < req->valid_bytes; ++i)
dataPtr[i] = (req->data >> (i * 8)) & 0xFF;
HostMemWriteResp resp = req->tag;
return MessageData::from(resp);
}
Expand Down
20 changes: 16 additions & 4 deletions lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,16 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion &region,
// Initiate a test write.
// TODO: remove the width == 96 once multiplexing support is added.
if (write) {
auto check = [&]() {
assert(width % 8 == 0);
auto check = [&](bool print) {
region.flush();
for (size_t i = 0, e = (width + 63) / 64; i < e; ++i)
for (size_t i = 0, e = (width + 63) / 64; i < e; ++i) {
if (print)
std::cout << "dataPtr[" << i << "] = 0x" << esi::toHex(dataPtr[i])
<< std::endl;
if (dataPtr[i] == 0xFFFFFFFFFFFFFFFFull)
return false;
}
return true;
};

Expand All @@ -179,12 +184,19 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion &region,
writeMem->write(0, reinterpret_cast<uint64_t>(devicePtr));
// Wait for the accelerator to write. Timeout and fail after 10ms.
for (int i = 0; i < 100; ++i) {
if (check())
if (check(false))
break;
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
if (!check())
if (!check(true))
throw std::runtime_error("DMA write test failed");

// Check that the accelerator didn't write too far.
size_t widthInBytes = width / 8;
uint8_t *dataPtr8 = reinterpret_cast<uint8_t *>(region.getPtr());
for (size_t i = widthInBytes, e = (widthInBytes + 7) / 8; i < e; ++i)
if (dataPtr8[i] != 0xFF)
throw std::runtime_error("DMA write test failed -- write went too far");
}
}

Expand Down

0 comments on commit 880bdcd

Please sign in to comment.