Skip to content

Commit b299af3

Browse files
authored Feb 19, 2025
[shortfin] Implement async alloc/dealloc of buffers. (#507)
* Device allocations are now async, queue-ordered alloc/dealloc.
* Program invocations asynchronously deallocate function call results when they can. Whenever a result cannot be deallocated asynchronously, a small tracy zone `SyncImportTimelineResource` is emitted for that result.
* Adds the `ProgramInvocation.assume_no_alias` instance boolean, which can be used to disable the assumption that allows async deallocation to work.
* Adds the global `ProgramInvocation.global_assume_no_alias` property to control this process-wide.

This is a very fiddly optimization which requires (esp. in multi-device cases) a number of things to line up. Tested on amdgpu and CPU with a number of sample workloads (with logging enabled and visually confirmed). See #980 for detailed analysis and further work required.
1 parent 888a98a commit b299af3

14 files changed

+569
-134
lines changed
 

‎shortfin/python/lib_ext.cc

+115-61
Original file line numberDiff line numberDiff line change
@@ -316,32 +316,97 @@ local::ProgramInvocation::Future PyFunctionCall(
316316
return local::ProgramInvocation::Invoke(std::move(inv));
317317
}
318318

319-
py::object PyRehydrateRef(local::ProgramInvocation *inv,
320-
iree::vm_opaque_ref ref) {
321-
auto type = ref.get()->type;
322-
// Note that these accessors are dangerous as they assert/abort if
323-
// process-wide registration is not done properly. We assume here that
324-
// since we got a ref out that the basics are set up soundly, but if actually
325-
// doing this on user/dynamic types, we would want to be more defensive.
326-
// TODO: Don't just do a linear scan if we have more than a couple.
327-
// TODO: Find a reliable way to statically cache the type id.
328-
if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type<
329-
array::device_array>() == type) {
330-
// device_array
331-
return py::cast(local::ProgramInvocationMarshalableFactory::
332-
CreateFromInvocationResultRef<array::device_array>(
333-
inv, std::move(ref)));
334-
} else if (local::ProgramInvocationMarshalableFactory::
335-
invocation_marshalable_type<array::storage>() == type) {
336-
// storage
337-
return py::cast(
338-
local::ProgramInvocationMarshalableFactory::
339-
CreateFromInvocationResultRef<array::storage>(inv, std::move(ref)));
319+
// Wraps a ProgramInvocation::Ptr representing a completed (awaited) invocation.
320+
// Holds some additional accounting for marshaling results back to Python.
321+
class PyProgramInvocation {
322+
public:
323+
PyProgramInvocation(local::ProgramInvocation::Ptr inv)
324+
: inv_(std::move(inv)) {}
325+
PyProgramInvocation(const PyProgramInvocation &) = delete;
326+
PyProgramInvocation(PyProgramInvocation &&other)
327+
: inv_(std::move(other.inv_)),
328+
cached_results_(std::move(other.cached_results_)),
329+
results_failure_(other.results_failure_) {}
330+
331+
// Fields that can be bound.
332+
bool assume_no_alias = true;
333+
static std::optional<bool> global_assume_no_alias;
334+
335+
void CheckValid() {
336+
if (!inv_) throw std::invalid_argument("Deallocated invocation");
340337
}
341-
throw std::invalid_argument(
342-
fmt::format("Cannot marshal ref type {} to Python",
343-
to_string_view(iree_vm_ref_type_name(type))));
344-
}
338+
local::ProgramInvocation::Ptr &inv() { return inv_; }
339+
340+
py::object results() {
341+
if (results_failure_) {
342+
throw std::logic_error("Prior attempt to marshal IREE results failed");
343+
}
344+
if (cached_results_) {
345+
return cached_results_;
346+
}
347+
348+
// Cache results.
349+
CheckValid();
350+
results_failure_ = true;
351+
352+
local::CoarseInvocationTimelineImporter::Options options;
353+
options.assume_no_alias = assume_no_alias;
354+
if (global_assume_no_alias) {
355+
options.assume_no_alias = *global_assume_no_alias;
356+
}
357+
local::CoarseInvocationTimelineImporter timeline_importer(inv().get(),
358+
options);
359+
size_t size = inv_->results_size();
360+
py::object tp = py::steal(PyTuple_New(size));
361+
for (size_t i = 0; i < size; ++i) {
362+
iree::vm_opaque_ref ref = inv_->result_ref(i);
363+
if (!ref) {
364+
throw new std::logic_error("Program returned unsupported Python type");
365+
}
366+
py::object item = RehydrateRef(std::move(ref), &timeline_importer);
367+
PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr());
368+
}
369+
370+
cached_results_ = std::move(tp);
371+
results_failure_ = false;
372+
return cached_results_;
373+
}
374+
375+
private:
376+
py::object RehydrateRef(
377+
iree::vm_opaque_ref ref,
378+
local::CoarseInvocationTimelineImporter *timeline_importer) {
379+
auto type = ref.get()->type;
380+
// Note that these accessors are dangerous as they assert/abort if
381+
// process-wide registration is not done properly. We assume here that
382+
// since we got a ref out that the basics are set up soundly, but if
383+
// actually doing this on user/dynamic types, we would want to be more
384+
// defensive.
385+
// TODO: Don't just do a linear scan if we have more than a couple.
386+
// TODO: Find a reliable way to statically cache the type id.
387+
if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type<
388+
array::device_array>() == type) {
389+
// device_array
390+
return py::cast(local::ProgramInvocationMarshalableFactory::
391+
CreateFromInvocationResultRef<array::device_array>(
392+
inv().get(), timeline_importer, std::move(ref)));
393+
} else if (local::ProgramInvocationMarshalableFactory::
394+
invocation_marshalable_type<array::storage>() == type) {
395+
// storage
396+
return py::cast(local::ProgramInvocationMarshalableFactory::
397+
CreateFromInvocationResultRef<array::storage>(
398+
inv().get(), timeline_importer, std::move(ref)));
399+
}
400+
throw std::invalid_argument(
401+
fmt::format("Cannot marshal ref type {} to Python",
402+
to_string_view(iree_vm_ref_type_name(type))));
403+
}
404+
405+
local::ProgramInvocation::Ptr inv_;
406+
py::object cached_results_;
407+
bool results_failure_ = false;
408+
};
409+
std::optional<bool> PyProgramInvocation::global_assume_no_alias;
345410

346411
py::object RunInForeground(std::shared_ptr<Refs> refs, local::System &self,
347412
py::object coro) {
@@ -743,56 +808,45 @@ void BindLocal(py::module_ &m) {
743808
return local::ProgramModule::ParameterProvider(system, c_params);
744809
},
745810
py::arg("system"), py::arg("params"));
746-
py::class_<local::ProgramInvocation::Ptr>(m, "ProgramInvocation")
811+
py::class_<PyProgramInvocation>(m, "ProgramInvocation")
812+
.def_rw("assume_no_alias", &PyProgramInvocation::assume_no_alias,
813+
"Assumes that no results alias inputs or other buffers")
814+
.def_rw_static(
815+
"global_assume_no_alias",
816+
&PyProgramInvocation::global_assume_no_alias,
817+
"Globally changes the assume_no_alias flag for all invocations")
747818
.def("invoke",
748-
[](local::ProgramInvocation::Ptr &self) {
749-
if (!self) throw std::invalid_argument("Deallocated invocation");
750-
return local::ProgramInvocation::Invoke(std::move(self));
819+
[](PyProgramInvocation &self) {
820+
self.CheckValid();
821+
return local::ProgramInvocation::Invoke(std::move(self.inv()));
751822
})
752823
.def("add_arg",
753-
[](local::ProgramInvocation::Ptr &self, py::handle arg) {
754-
if (!self) throw std::invalid_argument("Deallocated invocation");
755-
py::capsule inv_capsule(self.get());
824+
[](PyProgramInvocation &self, py::handle arg) {
825+
self.CheckValid();
826+
py::capsule inv_capsule(&self.inv());
756827
PyAddProgramInvocationArg(inv_capsule, arg);
757828
})
758829
.def("__iter__",
759-
[](local::ProgramInvocation::Ptr &self) {
760-
if (!self) throw std::invalid_argument("Deallocated invocation");
761-
size_t size = self->results_size();
762-
py::object tp = py::steal(PyTuple_New(size));
763-
for (size_t i = 0; i < size; ++i) {
764-
iree::vm_opaque_ref ref = self->result_ref(i);
765-
if (!ref) {
766-
throw new std::logic_error(
767-
"Program returned unsupported Python type");
768-
}
769-
py::object item = PyRehydrateRef(self.get(), std::move(ref));
770-
PyTuple_SET_ITEM(tp.ptr(), i, item.release().ptr());
771-
}
772-
return tp.attr("__iter__")();
830+
[](PyProgramInvocation &self) {
831+
return self.results().attr("__iter__")();
773832
})
774833
.def(
775834
"__len__",
776-
[](local::ProgramInvocation::Ptr &self) {
777-
if (!self) throw std::invalid_argument("Deallocated invocation");
778-
return self->results_size();
835+
[](PyProgramInvocation &self) {
836+
self.CheckValid();
837+
return self.inv()->results_size();
779838
},
780839
"The number of results in this invocation")
781840
.def(
782841
"__getitem__",
783-
[](local::ProgramInvocation::Ptr &self, iree_host_size_t i) {
784-
if (!self) throw std::invalid_argument("Deallocated invocation");
785-
iree::vm_opaque_ref ref = self->result_ref(i);
786-
if (!ref) {
787-
throw new std::logic_error(
788-
"Program returned unsupported Python type");
789-
}
790-
return PyRehydrateRef(self.get(), std::move(ref));
842+
[](PyProgramInvocation &self, iree_host_size_t i) {
843+
self.CheckValid();
844+
return self.results().attr("__getitem__")(i);
791845
},
792846
"Gets the i'th result")
793-
.def("__repr__", [](local::ProgramInvocation::Ptr &self) {
794-
if (!self) return std::string("ProgramInvocation(INVALID)");
795-
return self->to_s();
847+
.def("__repr__", [](PyProgramInvocation &self) {
848+
if (!self.inv()) return std::string("ProgramInvocation(INVALID)");
849+
return self.inv()->to_s();
796850
});
797851

798852
py::class_<local::BaseProgramParameters>(m, "BaseProgramParameters");
@@ -1207,7 +1261,7 @@ void BindLocal(py::module_ &m) {
12071261
// expensive in the C++ API: essentially, ProgramInvocations flow
12081262
// through the system precisely one way. As a low level facility, this
12091263
// is deemed acceptable.
1210-
return py::cast(std::move(result));
1264+
return py::cast(PyProgramInvocation(std::move(result)));
12111265
});
12121266
py::class_<local::MessageFuture, local::Future>(m, "MessageFuture")
12131267
.def("result", [](local::MessageFuture &self) {

‎shortfin/src/shortfin/array/array.cc

+6-4
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ void device_array::AddAsInvocationArgument(
109109

110110
iree::vm_opaque_ref ref;
111111
*(&ref) = iree_hal_buffer_view_move_ref(buffer_view);
112-
inv->AddArg(std::move(ref));
112+
inv->AddArg(std::move(ref), storage().timeline_resource_.get());
113113

114114
storage().AddInvocationArgBarrier(inv, barrier);
115115
}
@@ -119,16 +119,18 @@ iree_vm_ref_type_t device_array::invocation_marshalable_type() {
119119
}
120120

121121
device_array device_array::CreateFromInvocationResultRef(
122-
local::ProgramInvocation *inv, iree::vm_opaque_ref ref) {
122+
local::ProgramInvocation *inv,
123+
local::CoarseInvocationTimelineImporter *timeline_importer,
124+
iree::vm_opaque_ref ref) {
123125
SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::CreateFromInvocationResultRef");
124126
// We don't retain the buffer view in the device array, so just deref it
125127
// vs stealing the ref.
126128
iree_hal_buffer_view_t *bv = iree_hal_buffer_view_deref(*ref.get());
127129
iree::hal_buffer_ptr buffer =
128130
iree::hal_buffer_ptr::borrow_reference(iree_hal_buffer_view_buffer(bv));
129131

130-
auto imported_storage =
131-
storage::ImportInvocationResultStorage(inv, std::move(buffer));
132+
auto imported_storage = storage::ImportInvocationResultStorage(
133+
inv, timeline_importer, std::move(buffer));
132134
std::span<const iree_hal_dim_t> shape(iree_hal_buffer_view_shape_dims(bv),
133135
iree_hal_buffer_view_shape_rank(bv));
134136
return device_array(

‎shortfin/src/shortfin/array/array.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,9 @@ class SHORTFIN_API device_array
216216
void AddAsInvocationArgument(local::ProgramInvocation *inv,
217217
local::ProgramResourceBarrier barrier) override;
218218
static device_array CreateFromInvocationResultRef(
219-
local::ProgramInvocation *inv, iree::vm_opaque_ref ref);
219+
local::ProgramInvocation *inv,
220+
local::CoarseInvocationTimelineImporter *timeline_importer,
221+
iree::vm_opaque_ref ref);
220222
static iree_vm_ref_type_t invocation_marshalable_type();
221223
friend class shortfin::local::ProgramInvocationMarshalableFactory;
222224
};

0 commit comments

Comments
 (0)
Please sign in to comment.