@@ -316,32 +316,97 @@ local::ProgramInvocation::Future PyFunctionCall(
316
316
return local::ProgramInvocation::Invoke (std::move (inv));
317
317
}
318
318
319
- py::object PyRehydrateRef (local::ProgramInvocation *inv,
320
- iree::vm_opaque_ref ref) {
321
- auto type = ref.get ()->type ;
322
- // Note that these accessors are dangerous as they assert/abort if
323
- // process-wide registration is not done properly. We assume here that
324
- // since we got a ref out that the basics are set up soundly, but if actually
325
- // doing this on user/dynamic types, we would want to be more defensive.
326
- // TODO: Don't just do a linear scan if we have more than a couple.
327
- // TODO: Find a reliable way to statically cache the type id.
328
- if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type<
329
- array::device_array>() == type) {
330
- // device_array
331
- return py::cast (local::ProgramInvocationMarshalableFactory::
332
- CreateFromInvocationResultRef<array::device_array>(
333
- inv, std::move (ref)));
334
- } else if (local::ProgramInvocationMarshalableFactory::
335
- invocation_marshalable_type<array::storage>() == type) {
336
- // storage
337
- return py::cast (
338
- local::ProgramInvocationMarshalableFactory::
339
- CreateFromInvocationResultRef<array::storage>(inv, std::move (ref)));
319
+ // Wraps a ProgramInvocation::Ptr representing a completed (awaited) invocation.
320
+ // Holds some additional accounting for marshaling results back to Python.
321
+ class PyProgramInvocation {
322
+ public:
323
+ PyProgramInvocation (local::ProgramInvocation::Ptr inv)
324
+ : inv_(std::move(inv)) {}
325
+ PyProgramInvocation (const PyProgramInvocation &) = delete ;
326
+ PyProgramInvocation (PyProgramInvocation &&other)
327
+ : inv_(std::move(other.inv_)),
328
+ cached_results_ (std::move(other.cached_results_)),
329
+ results_failure_(other.results_failure_) {}
330
+
331
+ // Fields that can be bound.
332
+ bool assume_no_alias = true ;
333
+ static std::optional<bool > global_assume_no_alias;
334
+
335
+ void CheckValid () {
336
+ if (!inv_) throw std::invalid_argument (" Deallocated invocation" );
340
337
}
341
- throw std::invalid_argument (
342
- fmt::format (" Cannot marshal ref type {} to Python" ,
343
- to_string_view (iree_vm_ref_type_name (type))));
344
- }
338
+ local::ProgramInvocation::Ptr &inv () { return inv_; }
339
+
340
+ py::object results () {
341
+ if (results_failure_) {
342
+ throw std::logic_error (" Prior attempt to marshal IREE results failed" );
343
+ }
344
+ if (cached_results_) {
345
+ return cached_results_;
346
+ }
347
+
348
+ // Cache results.
349
+ CheckValid ();
350
+ results_failure_ = true ;
351
+
352
+ local::CoarseInvocationTimelineImporter::Options options;
353
+ options.assume_no_alias = assume_no_alias;
354
+ if (global_assume_no_alias) {
355
+ options.assume_no_alias = *global_assume_no_alias;
356
+ }
357
+ local::CoarseInvocationTimelineImporter timeline_importer (inv ().get (),
358
+ options);
359
+ size_t size = inv_->results_size ();
360
+ py::object tp = py::steal (PyTuple_New (size));
361
+ for (size_t i = 0 ; i < size; ++i) {
362
+ iree::vm_opaque_ref ref = inv_->result_ref (i);
363
+ if (!ref) {
364
+ throw new std::logic_error (" Program returned unsupported Python type" );
365
+ }
366
+ py::object item = RehydrateRef (std::move (ref), &timeline_importer);
367
+ PyTuple_SET_ITEM (tp.ptr (), i, item.release ().ptr ());
368
+ }
369
+
370
+ cached_results_ = std::move (tp);
371
+ results_failure_ = false ;
372
+ return cached_results_;
373
+ }
374
+
375
+ private:
376
+ py::object RehydrateRef (
377
+ iree::vm_opaque_ref ref,
378
+ local::CoarseInvocationTimelineImporter *timeline_importer) {
379
+ auto type = ref.get ()->type ;
380
+ // Note that these accessors are dangerous as they assert/abort if
381
+ // process-wide registration is not done properly. We assume here that
382
+ // since we got a ref out that the basics are set up soundly, but if
383
+ // actually doing this on user/dynamic types, we would want to be more
384
+ // defensive.
385
+ // TODO: Don't just do a linear scan if we have more than a couple.
386
+ // TODO: Find a reliable way to statically cache the type id.
387
+ if (local::ProgramInvocationMarshalableFactory::invocation_marshalable_type<
388
+ array::device_array>() == type) {
389
+ // device_array
390
+ return py::cast (local::ProgramInvocationMarshalableFactory::
391
+ CreateFromInvocationResultRef<array::device_array>(
392
+ inv ().get (), timeline_importer, std::move (ref)));
393
+ } else if (local::ProgramInvocationMarshalableFactory::
394
+ invocation_marshalable_type<array::storage>() == type) {
395
+ // storage
396
+ return py::cast (local::ProgramInvocationMarshalableFactory::
397
+ CreateFromInvocationResultRef<array::storage>(
398
+ inv ().get (), timeline_importer, std::move (ref)));
399
+ }
400
+ throw std::invalid_argument (
401
+ fmt::format (" Cannot marshal ref type {} to Python" ,
402
+ to_string_view (iree_vm_ref_type_name (type))));
403
+ }
404
+
405
+ local::ProgramInvocation::Ptr inv_;
406
+ py::object cached_results_;
407
+ bool results_failure_ = false ;
408
+ };
409
+ std::optional<bool > PyProgramInvocation::global_assume_no_alias;
345
410
346
411
py::object RunInForeground (std::shared_ptr<Refs> refs, local::System &self,
347
412
py::object coro) {
@@ -743,56 +808,45 @@ void BindLocal(py::module_ &m) {
743
808
return local::ProgramModule::ParameterProvider (system , c_params);
744
809
},
745
810
py::arg (" system" ), py::arg (" params" ));
746
- py::class_<local::ProgramInvocation::Ptr >(m, " ProgramInvocation" )
811
+ py::class_<PyProgramInvocation>(m, " ProgramInvocation" )
812
+ .def_rw (" assume_no_alias" , &PyProgramInvocation::assume_no_alias,
813
+ " Assumes that no results alias inputs or other buffers" )
814
+ .def_rw_static (
815
+ " global_assume_no_alias" ,
816
+ &PyProgramInvocation::global_assume_no_alias,
817
+ " Globally changes the assume_no_alias flag for all invocations" )
747
818
.def (" invoke" ,
748
- [](local::ProgramInvocation:: Ptr &self) {
749
- if (! self) throw std::invalid_argument ( " Deallocated invocation " );
750
- return local::ProgramInvocation::Invoke (std::move (self));
819
+ [](PyProgramInvocation &self) {
820
+ self. CheckValid ( );
821
+ return local::ProgramInvocation::Invoke (std::move (self. inv () ));
751
822
})
752
823
.def (" add_arg" ,
753
- [](local::ProgramInvocation:: Ptr &self, py::handle arg) {
754
- if (! self) throw std::invalid_argument ( " Deallocated invocation " );
755
- py::capsule inv_capsule (self.get ());
824
+ [](PyProgramInvocation &self, py::handle arg) {
825
+ self. CheckValid ( );
826
+ py::capsule inv_capsule (& self.inv ());
756
827
PyAddProgramInvocationArg (inv_capsule, arg);
757
828
})
758
829
.def (" __iter__" ,
759
- [](local::ProgramInvocation::Ptr &self) {
760
- if (!self) throw std::invalid_argument (" Deallocated invocation" );
761
- size_t size = self->results_size ();
762
- py::object tp = py::steal (PyTuple_New (size));
763
- for (size_t i = 0 ; i < size; ++i) {
764
- iree::vm_opaque_ref ref = self->result_ref (i);
765
- if (!ref) {
766
- throw new std::logic_error (
767
- " Program returned unsupported Python type" );
768
- }
769
- py::object item = PyRehydrateRef (self.get (), std::move (ref));
770
- PyTuple_SET_ITEM (tp.ptr (), i, item.release ().ptr ());
771
- }
772
- return tp.attr (" __iter__" )();
830
+ [](PyProgramInvocation &self) {
831
+ return self.results ().attr (" __iter__" )();
773
832
})
774
833
.def (
775
834
" __len__" ,
776
- [](local::ProgramInvocation:: Ptr &self) {
777
- if (! self) throw std::invalid_argument ( " Deallocated invocation " );
778
- return self->results_size ();
835
+ [](PyProgramInvocation &self) {
836
+ self. CheckValid ( );
837
+ return self. inv () ->results_size ();
779
838
},
780
839
" The number of results in this invocation" )
781
840
.def (
782
841
" __getitem__" ,
783
- [](local::ProgramInvocation::Ptr &self, iree_host_size_t i) {
784
- if (!self) throw std::invalid_argument (" Deallocated invocation" );
785
- iree::vm_opaque_ref ref = self->result_ref (i);
786
- if (!ref) {
787
- throw new std::logic_error (
788
- " Program returned unsupported Python type" );
789
- }
790
- return PyRehydrateRef (self.get (), std::move (ref));
842
+ [](PyProgramInvocation &self, iree_host_size_t i) {
843
+ self.CheckValid ();
844
+ return self.results ().attr (" __getitem__" )(i);
791
845
},
792
846
" Gets the i'th result" )
793
- .def (" __repr__" , [](local::ProgramInvocation:: Ptr &self) {
794
- if (!self) return std::string (" ProgramInvocation(INVALID)" );
795
- return self->to_s ();
847
+ .def (" __repr__" , [](PyProgramInvocation &self) {
848
+ if (!self. inv () ) return std::string (" ProgramInvocation(INVALID)" );
849
+ return self. inv () ->to_s ();
796
850
});
797
851
798
852
py::class_<local::BaseProgramParameters>(m, " BaseProgramParameters" );
@@ -1207,7 +1261,7 @@ void BindLocal(py::module_ &m) {
1207
1261
// expensive in the C++ API: essentially, ProgramInvocations flow
1208
1262
// through the system precisely one way. As a low level facility, this
1209
1263
// is deemed acceptable.
1210
- return py::cast (std::move (result));
1264
+ return py::cast (PyProgramInvocation ( std::move (result) ));
1211
1265
});
1212
1266
py::class_<local::MessageFuture, local::Future>(m, " MessageFuture" )
1213
1267
.def (" result" , [](local::MessageFuture &self) {
0 commit comments