Skip to content

Commit 5f8a863

Browse files
authored
[PyCDE][Python] Fixing leaks reported by nanobind (#8029)
The ESI and MSFT Python modules had object memory leaks which upstream nanobind detected. Fixing them.
1 parent 7fd5ed0 commit 5f8a863

File tree

5 files changed

+52
-26
lines changed

5 files changed

+52
-26
lines changed

Diff for: frontends/PyCDE/src/pycde/esi.py

+8
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,16 @@
1919
from pathlib import Path
2020
from typing import Dict, List, Optional, Tuple, cast
2121

22+
import atexit
23+
2224
__dir__ = Path(__file__).parent
2325

26+
27+
@atexit.register
28+
def _cleanup():
29+
raw_esi.cleanup()
30+
31+
2432
FlattenStructPorts = "esi.portFlattenStructs"
2533
PortInSuffix = "esi.portInSuffix"
2634
PortOutSuffix = "esi.portOutSuffix"

Diff for: frontends/PyCDE/test/test_compreg.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# RUN: FileCheck %s --input-file %t/hw/CompReg.tcl --check-prefix TCL
55

66
import pycde
7-
from pycde import types, Module, AppID, Clock, Input, Output
7+
from pycde import Module, AppID, Clock, Input, Output
88
from pycde.devicedb import LocationVector
9+
from pycde.types import Bits
910

1011
from pycde.module import generator
1112

@@ -27,9 +28,9 @@
2728

2829
class CompReg(Module):
2930
clk = Clock()
30-
rst = Input(types.i1)
31-
input = Input(types.i8)
32-
output = Output(types.i8)
31+
rst = Input(Bits(1))
32+
input = Input(Bits(8))
33+
output = Output(Bits(8))
3334

3435
@generator
3536
def build(ports):
@@ -40,7 +41,6 @@ def build(ports):
4041

4142

4243
mod = pycde.System([CompReg], name="CompReg", output_directory=sys.argv[1])
43-
mod.print()
4444
mod.generate()
4545
top_inst = mod.get_instance(CompReg)
4646
mod.createdb()

Diff for: frontends/PyCDE/test/test_xrt.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,24 @@
1414
# RUN: FileCheck %s --input-file %t/hw/XrtTop.sv --check-prefix=TOP
1515

1616
import pycde
17-
from pycde import Clock, Input, Module, generator, types
17+
from pycde import Clock, Input, Module, generator
18+
from pycde.types import Bits
1819
from pycde.bsp import XrtBSP
1920

2021
import sys
2122

2223

2324
class Main(Module):
24-
clk = Clock(types.i1)
25-
rst = Input(types.i1)
25+
clk = Clock()
26+
rst = Input(Bits(1))
2627

2728
@generator
2829
def construct(ports):
2930
pass
3031

3132

3233
gendir = sys.argv[1]
33-
s = pycde.System(XrtBSP(Main),
34-
name="ESILoopback",
35-
output_directory=gendir,
36-
sw_api_langs=["python"])
37-
s.run_passes(debug=True)
34+
s = pycde.System(XrtBSP(Main), name="ESILoopback", output_directory=gendir)
3835
s.compile()
3936
s.package()
4037

Diff for: lib/Bindings/Python/ESIModule.cpp

+29-8
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,39 @@ using namespace circt::esi;
3131
// The main entry point into the ESI Assembly API.
3232
//===----------------------------------------------------------------------===//
3333

34+
/// Container for a Python function that will be called to generate a service.
35+
class ServiceGenFunc {
36+
public:
37+
ServiceGenFunc(py::object genFunc) : genFunc(std::move(genFunc)) {}
38+
39+
MlirLogicalResult run(MlirOperation reqOp, MlirOperation declOp,
40+
MlirOperation recOp) {
41+
py::gil_scoped_acquire acquire;
42+
py::object rc = genFunc(reqOp, declOp, recOp);
43+
return rc.cast<bool>() ? mlirLogicalResultSuccess()
44+
: mlirLogicalResultFailure();
45+
}
46+
47+
private:
48+
py::object genFunc;
49+
};
50+
3451
// Mapping from unique identifier to python callback. We use std::string
3552
// pointers since we also need to allocate memory for the string.
36-
llvm::DenseMap<std::string *, PyObject *> serviceGenFuncLookup;
53+
llvm::DenseMap<std::string *, ServiceGenFunc> serviceGenFuncLookup;
3754
static MlirLogicalResult serviceGenFunc(MlirOperation reqOp,
3855
MlirOperation declOp,
3956
MlirOperation recOp, void *userData) {
4057
std::string *name = static_cast<std::string *>(userData);
41-
py::handle genFunc(serviceGenFuncLookup[name]);
42-
py::gil_scoped_acquire();
43-
py::object rc = genFunc(reqOp, declOp, recOp);
44-
return rc.cast<bool>() ? mlirLogicalResultSuccess()
45-
: mlirLogicalResultFailure();
58+
auto iter = serviceGenFuncLookup.find(name);
59+
if (iter == serviceGenFuncLookup.end())
60+
return mlirLogicalResultFailure();
61+
return iter->getSecond().run(reqOp, declOp, recOp);
4662
}
4763

4864
void registerServiceGenerator(std::string name, py::object genFunc) {
4965
std::string *n = new std::string(name);
50-
genFunc.inc_ref();
51-
serviceGenFuncLookup[n] = genFunc.ptr();
66+
serviceGenFuncLookup.try_emplace(n, ServiceGenFunc(genFunc));
5267
circtESIRegisterGlobalServiceGenerator(wrap(*n), serviceGenFunc, n);
5368
}
5469

@@ -81,6 +96,12 @@ void circt::python::populateDialectESISubmodule(py::module &m) {
8196
m.doc() = "ESI Python Native Extension";
8297
::registerESIPasses();
8398

99+
// Clean up references when the module is unloaded.
100+
auto cleanup = []() { serviceGenFuncLookup.clear(); };
101+
m.def("cleanup", cleanup,
102+
"Cleanup various references. Must be called before the module is "
103+
"unloaded in order to not leak.");
104+
84105
m.def("registerServiceGenerator", registerServiceGenerator,
85106
"Register a service generator for a given service name.",
86107
py::arg("impl_type"), py::arg("generator"));

Diff for: lib/Bindings/Python/MSFTModule.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,18 @@ class PlacementDB {
117117
class PyLocationVecIterator {
118118
public:
119119
/// Get item at the specified position, translating a nullptr to None.
120-
static py::handle getItem(MlirAttribute locVec, intptr_t pos) {
120+
static std::optional<MlirAttribute> getItem(MlirAttribute locVec,
121+
intptr_t pos) {
121122
MlirAttribute loc = circtMSFTLocationVectorAttrGetElement(locVec, pos);
122123
if (loc.ptr == nullptr)
123-
return py::none();
124-
return py::detail::type_caster<MlirAttribute>().cast(
125-
loc, py::return_value_policy::automatic, py::handle());
124+
return std::nullopt;
125+
return loc;
126126
}
127127

128128
PyLocationVecIterator(MlirAttribute attr) : attr(attr) {}
129129
PyLocationVecIterator &dunderIter() { return *this; }
130130

131-
py::handle dunderNext() {
131+
std::optional<MlirAttribute> dunderNext() {
132132
if (nextIndex >= circtMSFTLocationVectorAttrGetNumElements(attr)) {
133133
throw py::stop_iteration();
134134
}

0 commit comments

Comments
 (0)