Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions compiler/frontend/pycircuit/hw.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,12 +746,31 @@ def scope(self, name: str) -> Iterator[None]:
def domain(self, name: str) -> ClockDomain:
return ClockDomain(clk=self.clock(f"{name}_clk"), rst=self.reset(f"{name}_rst"))

def create_domain(self, name: str, *, frequency_desc: str = "", reset_active_high: bool = False) -> Any:
def create_domain(
self,
name: str,
*,
frequency_desc: str = "",
reset_name: str | None = None,
reset_polarity: str = "active_high",
reset_active_high: bool | None = None,
) -> Any:
"""V5 cycle-aware domain (next/prev/push/pop); see `pycircuit.v5.CycleAwareDomain`."""
from .v5 import CycleAwareDomain

_ = (frequency_desc, reset_active_high)
return CycleAwareDomain(self, str(name))
_ = frequency_desc
return CycleAwareDomain(
self,
str(name),
reset_name=reset_name,
reset_polarity=(
"active_high"
if reset_active_high is True
else "active_low"
if reset_active_high is False
else reset_polarity
),
)

def input(self, name: str, *, width: int, signed: bool = False) -> Wire: # type: ignore[override]
"""Declare a module input port and return it as a `Wire`."""
Expand Down
123 changes: 111 additions & 12 deletions compiler/frontend/pycircuit/v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,21 @@ def emit_mlir(self) -> str:
return super().emit_mlir()

def create_domain(
self, name: str, *, frequency_desc: str = "", reset_active_high: bool = False
self,
name: str,
*,
frequency_desc: str = "",
reset_name: str | None = None,
reset_polarity: str = "active_high",
reset_active_high: bool | None = None,
) -> "CycleAwareDomain":
_ = (frequency_desc, reset_active_high)
return CycleAwareDomain(self, str(name))
_ = frequency_desc
polarity = _normalize_reset_polarity(
reset_polarity, reset_active_high=reset_active_high
)
return CycleAwareDomain(
self, str(name), reset_name=reset_name, reset_polarity=polarity
)

def const_signal(self, value: int, width: int, domain: "CycleAwareDomain") -> Wire:
return domain.create_const(int(value), width=int(width))
Expand All @@ -70,10 +81,24 @@ def input_signal(self, name: str, width: int, domain: "CycleAwareDomain") -> Wir
class CycleAwareDomain:
"""Clock domain with logical occurrence index (tutorial: next/prev/push/pop/cycle)."""

def __init__(self, circuit: Circuit, domain_name: str) -> None:
def __init__(
self,
circuit: Circuit,
domain_name: str,
*,
reset_name: str | None = None,
reset_polarity: str = "active_high",
) -> None:
self._m = circuit
self._name = str(domain_name)
self._cd = _clock_domain_ports(circuit, self._name)
self._reset_name = None if reset_name is None else str(reset_name)
self._reset_polarity = _normalize_reset_polarity(reset_polarity)
self._cd = _clock_domain_ports(
circuit,
self._name,
reset_name=self._reset_name,
reset_polarity=self._reset_polarity,
)
self._occurrence = 0
self._stack: list[int] = []
self._delay_serial = 0
Expand Down Expand Up @@ -461,10 +486,55 @@ def _make_compiled_module(fn: Any, circuit: CycleAwareCircuit, sym_name: str) ->
)


def _clock_domain_ports(m: Circuit, name: str) -> ClockDomain:
def _normalize_reset_polarity(
reset_polarity: str, *, reset_active_high: bool | None = None
) -> str:
if reset_active_high is not None:
return "active_high" if bool(reset_active_high) else "active_low"
p = str(reset_polarity).strip().lower().replace("-", "_")
if p in {"active_high", "high", "1", "true"}:
return "active_high"
if p in {"active_low", "low", "0", "false"}:
return "active_low"
raise ValueError(
"reset_polarity must be 'active_high' or 'active_low' "
f"(got {reset_polarity!r})"
)


def _record_reset_polarity(m: Circuit, reset_port: str, polarity: str) -> None:
by_port = getattr(m, "_v5_reset_polarities", None)
if not isinstance(by_port, dict):
by_port = {}
setattr(m, "_v5_reset_polarities", by_port)
by_port[str(reset_port)] = str(polarity)

values: list[str] = []
for arg_name, sig in getattr(m, "_args", []):
if getattr(sig, "ty", "") == "!pyc.reset":
values.append(str(by_port.get(arg_name, "active_high")))
else:
values.append("")
m.set_func_attr_json("pyc.reset_polarities", values)


def _clock_domain_ports(
m: Circuit,
name: str,
*,
reset_name: str | None = None,
reset_polarity: str = "active_high",
) -> ClockDomain:
if name == "clk":
return ClockDomain(clk=m.clock("clk"), rst=m.reset("rst"))
return m.domain(name)
clk = m.clock("clk")
rst_port = "rst" if reset_name is None else str(reset_name)
rst = m.reset(rst_port)
else:
clk = m.clock(f"{name}_clk")
rst_port = f"{name}_rst" if reset_name is None else str(reset_name)
rst = m.reset(rst_port)
Comment on lines 528 to +535
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for determining the reset port name is duplicated between the default 'clk' domain and other domains. This can be simplified to improve maintainability.

    is_default = (name == "clk")
    clk_port = "clk" if is_default else f"{name}_clk"
    rst_default = "rst" if is_default else f"{name}_rst"

    clk = m.clock(clk_port)
    rst_port = rst_default if reset_name is None else str(reset_name)
    rst = m.reset(rst_port)

_record_reset_polarity(m, rst_port, _normalize_reset_polarity(reset_polarity))
return ClockDomain(clk=clk, rst=rst)


def _as_wire(
Expand Down Expand Up @@ -1126,7 +1196,11 @@ def cas(


def _strip_domain_for_jit(
fn: Callable[..., Any], *, domain_name: str
fn: Callable[..., Any],
*,
domain_name: str,
reset_name: str | None,
reset_polarity: str,
) -> Callable[..., Any]:
"""Drop the ``domain`` parameter for JIT and prepend ``domain = m.create_domain(...)``."""
try:
Expand Down Expand Up @@ -1168,7 +1242,16 @@ def _strip_domain_for_jit(
ctx=ast.Load(),
),
args=[ast.Constant(value=str(domain_name))],
keywords=[],
keywords=[
ast.keyword(
arg="reset_name",
value=ast.Constant(value=None if reset_name is None else str(reset_name)),
),
ast.keyword(
arg="reset_polarity",
value=ast.Constant(value=str(reset_polarity)),
),
],
),
)
fdef.body.insert(0, prelude)
Expand All @@ -1193,6 +1276,9 @@ def compile_cycle_aware(
*,
name: str | None = None,
domain_name: str = "clk",
reset_name: str | None = None,
reset_polarity: str = "active_high",
reset_active_high: bool | None = None,
eager: bool = False,
hierarchical: bool = False,
structural: bool | None = None,
Expand All @@ -1212,14 +1298,22 @@ def compile_cycle_aware(
MLIR ops and instantiated via ``pyc.instance``. The returned circuit's
``emit_mlir()`` emits a multi-module ``Design``.
"""
normalized_reset_polarity = _normalize_reset_polarity(
reset_polarity, reset_active_high=reset_active_high
)

if eager:
circuit_name = (
name
if isinstance(name, str) and name.strip()
else getattr(fn, "__name__", "design") or "design"
)
m = CycleAwareCircuit(str(circuit_name), design_ctx=design_ctx)
dom = m.create_domain(str(domain_name))
dom = m.create_domain(
str(domain_name),
reset_name=reset_name,
reset_polarity=normalized_reset_polarity,
)

if hierarchical:
from .design import Design
Expand Down Expand Up @@ -1265,7 +1359,12 @@ def compile_cycle_aware(

domain_n = str(domain_name)

_jit_fn = _strip_domain_for_jit(fn, domain_name=domain_n)
_jit_fn = _strip_domain_for_jit(
fn,
domain_name=domain_n,
reset_name=reset_name,
reset_polarity=normalized_reset_polarity,
)
setattr(_jit_fn, "__pycircuit_module_name__", sym)
setattr(_jit_fn, "__pycircuit_kind__", "module")
setattr(_jit_fn, "__pycircuit_inline__", False)
Expand Down
37 changes: 36 additions & 1 deletion compiler/mlir/lib/Emit/CppEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,18 @@ static std::string getPortCanonicalFieldPath(func::FuncOp f, unsigned idx, bool
return "out" + std::to_string(idx);
}

static bool isActiveLowResetPort(func::FuncOp f, unsigned idx) {
if (idx >= f.getNumArguments())
return false;
if (!isa<pyc::ResetType>(f.getArgument(idx).getType()))
return false;
auto polarities = f->getAttrOfType<ArrayAttr>("pyc.reset_polarities");
if (!polarities || idx >= polarities.size())
return false;
auto polarity = dyn_cast<StringAttr>(polarities[idx]);
return polarity && polarity.getValue() == "active_low";
}

struct ProbeAliasEntry {
std::string canonicalPath;
std::string sourcePath;
Expand Down Expand Up @@ -634,12 +646,24 @@ static LogicalResult emitFunc(func::FuncOp f, llvm::raw_ostream &os, const CppEm
outNames.reserve(f.getNumResults());
std::vector<std::string> outCanon;
outCanon.reserve(f.getNumResults());
struct ResetAlias {
std::string externalName;
std::string internalName;
};
std::vector<ResetAlias> resetAliases;
for (auto [i, arg] : llvm::enumerate(f.getArguments())) {
inCanon.push_back(getPortCanonicalFieldPath(f, i, /*isResult=*/false));
std::string name = nt.unique(getPortName(f, i, /*isResult=*/false));
inNames.push_back(name);
nt.names.try_emplace(arg, name);
os << " " << cppType(arg.getType()) << " " << name << "{};\n";
if (isActiveLowResetPort(f, static_cast<unsigned>(i))) {
std::string internalName = nt.unique(name + "__active_high");
resetAliases.push_back(ResetAlias{name, internalName});
nt.names.try_emplace(arg, internalName);
os << " pyc::cpp::Wire<1> " << internalName << "{};\n";
} else {
nt.names.try_emplace(arg, name);
}
}
for (unsigned i = 0; i < f.getNumResults(); ++i) {
outCanon.push_back(getPortCanonicalFieldPath(f, i, /*isResult=*/true));
Expand Down Expand Up @@ -1341,6 +1365,13 @@ static LogicalResult emitFunc(func::FuncOp f, llvm::raw_ostream &os, const CppEm
os << " #endif\n";
os << " }\n\n";

if (!resetAliases.empty()) {
os << " inline void _pyc_sync_reset_polarity() {\n";
for (const ResetAlias &alias : resetAliases)
os << " " << alias.internalName << " = ~" << alias.externalName << ";\n";
os << " }\n\n";
}

// Emit fused comb helpers.
for (auto [i, comb] : llvm::enumerate(combs)) {
if (failed(emitCombMethod(comb, os, nt, static_cast<unsigned>(i), opts)))
Expand Down Expand Up @@ -2293,6 +2324,8 @@ static LogicalResult emitFunc(func::FuncOp f, llvm::raw_ostream &os, const CppEm
}

os << " void eval() {\n";
if (!resetAliases.empty())
os << " _pyc_sync_reset_polarity();\n";
if (hasFullTopo) {
if (!topoEvalMethods.empty()) {
for (const std::string &methodName : topoEvalMethods)
Expand Down Expand Up @@ -2368,6 +2401,8 @@ static LogicalResult emitFunc(func::FuncOp f, llvm::raw_ostream &os, const CppEm
}

os << " void tick_compute() {\n";
if (!resetAliases.empty())
os << " _pyc_sync_reset_polarity();\n";
if (!instInfos.empty()) {
os << " // Sub-modules.\n";
for (unsigned i = 0; i < subParts; ++i)
Expand Down
32 changes: 31 additions & 1 deletion compiler/mlir/lib/Emit/VerilogEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ static std::string getPortName(func::FuncOp f, unsigned idx, bool isResult) {
return sanitizeId(raw);
}

static bool isActiveLowResetPort(func::FuncOp f, unsigned idx) {
if (idx >= f.getNumArguments())
return false;
if (!isa<pyc::ResetType>(f.getArgument(idx).getType()))
return false;
auto polarities = f->getAttrOfType<ArrayAttr>("pyc.reset_polarities");
if (!polarities || idx >= polarities.size())
return false;
auto polarity = dyn_cast<StringAttr>(polarities[idx]);
return polarity && polarity.getValue() == "active_low";
}

static void computeUniquePortNames(func::FuncOp f, std::vector<std::string> &inNames, std::vector<std::string> &outNames) {
NameTable nt;
inNames.clear();
Expand Down Expand Up @@ -505,6 +517,11 @@ static LogicalResult emitFunc(func::FuncOp f, raw_ostream &os, const VerilogEmit
NameTable nt;
std::vector<std::string> outNames;
outNames.reserve(f.getNumResults());
struct ResetAlias {
std::string externalName;
std::string internalName;
};
std::vector<ResetAlias> resetAliases;

os << "// Generated by pycc (pyCircuit)\n";
os << "// Module: " << f.getSymName() << "\n\n";
Expand All @@ -519,7 +536,13 @@ static LogicalResult emitFunc(func::FuncOp f, raw_ostream &os, const VerilogEmit
os << range << " ";
os << portName;
os << ((i + 1 == f.getNumArguments() && f.getNumResults() == 0) ? "\n" : ",\n");
nt.names.try_emplace(arg, portName);
if (isActiveLowResetPort(f, static_cast<unsigned>(i))) {
std::string internalName = nt.unique(portName + "__active_high");
resetAliases.push_back(ResetAlias{portName, internalName});
nt.names.try_emplace(arg, internalName);
} else {
nt.names.try_emplace(arg, portName);
}
}
for (unsigned i = 0; i < f.getNumResults(); ++i) {
std::string portName = nt.unique(getPortName(f, i, /*isResult=*/true));
Expand All @@ -533,6 +556,13 @@ static LogicalResult emitFunc(func::FuncOp f, raw_ostream &os, const VerilogEmit
}
os << ");\n\n";

for (const ResetAlias &alias : resetAliases) {
os << "wire " << alias.internalName << ";\n";
os << "assign " << alias.internalName << " = ~" << alias.externalName << ";\n";
}
if (!resetAliases.empty())
os << "\n";

// Declare internal nets for op results (including results inside pyc.comb regions).
std::vector<NetDecl> decls;
decls.reserve(256);
Expand Down
Loading
Loading