@@ -15,35 +15,96 @@ AsyncInferRequest::AsyncInferRequest(const std::shared_ptr<SyncInferRequest>& re
     : ov::IAsyncInferRequest(request, nullptr, callback_executor),
       m_sync_request(request),
       m_request_without_batch(request_without_batch) {
-    // this executor starts the inference while the task (checking the result) is passed to the next stage
-    struct ThisRequestExecutor : public ov::threading::ITaskExecutor {
-        explicit ThisRequestExecutor(AsyncInferRequest* _this_) : _this{_this_} {}
-        void run(ov::threading::Task task) override {
-            auto workerInferRequest = _this->m_sync_request->m_batched_request_wrapper;
-            std::pair<AsyncInferRequest*, ov::threading::Task> t;
-            t.first = _this;
-            t.second = std::move(task);
-            workerInferRequest->_tasks.push(t);
-            // it is ok to call size() here as the queue only grows (and the bulk removal happens under the mutex)
-            const int sz = static_cast<int>(workerInferRequest->_tasks.size());
-            if (sz == workerInferRequest->_batch_size) {
-                workerInferRequest->_cond.notify_one();
+    if (m_sync_request && m_sync_request->get_batch_size() == 0) {
+        // batch not applicable, just a wrapper to hardware infer request
+        // share the tensors with hardware infer request
+        for (const auto& input : get_inputs()) {
+            auto tensor = m_request_without_batch->get_tensor(input);
+            if (!tensor._so) {
+                tensor._so = m_request_without_batch._so;
             }
+            set_tensor(input, tensor);
+        }
+        for (const auto& output : get_outputs()) {
+            auto tensor = m_request_without_batch->get_tensor(output);
+            if (!tensor._so) {
+                tensor._so = m_request_without_batch._so;
+            }
+            set_tensor(output, tensor);
+        }
+        struct RequestExecutor : ov::threading::ITaskExecutor {
+            explicit RequestExecutor(const ov::SoPtr<ov::IAsyncInferRequest>& infer_request)
+                : m_inferrequest(infer_request) {
+                m_inferrequest->set_callback([this](std::exception_ptr exceptionPtr) mutable {
+                    m_exceptionptr = std::move(exceptionPtr);
+                    auto capturedTask = std::move(m_task);
+                    capturedTask();
+                });
+            }
+            void run(ov::threading::Task task) override {
+                m_task = std::move(task);
+                m_inferrequest->start_async();
+            };
+            const ov::SoPtr<ov::IAsyncInferRequest>& m_inferrequest;
+            std::exception_ptr m_exceptionptr;
+            ov::threading::Task m_task;
         };
-        AsyncInferRequest* _this = nullptr;
-    };
-    m_pipeline = {{/*TaskExecutor*/ std::make_shared<ThisRequestExecutor>(this), /*task*/ [this] {
-                       if (this->m_sync_request->m_exception_ptr)  // if the exception happened in the batch1 fallback
-                           std::rethrow_exception(this->m_sync_request->m_exception_ptr);
-                       auto batchReq = this->m_sync_request->m_batched_request_wrapper;
-                       if (batchReq->_exception_ptr)  // when the batchN execution failed
-                           std::rethrow_exception(batchReq->_exception_ptr);
-                       // in the case of non-batched execution the tensors were set explicitly
-                       if (SyncInferRequest::eExecutionFlavor::BATCH_EXECUTED ==
-                           this->m_sync_request->m_batched_request_status) {
-                           this->m_sync_request->copy_outputs_if_needed();
-                       }
-                   }}};
+        auto requestExecutor = std::make_shared<RequestExecutor>(m_request_without_batch);
+        m_pipeline.emplace_back(requestExecutor, [requestExecutor] {
+            if (nullptr != requestExecutor->m_exceptionptr) {
+                std::rethrow_exception(requestExecutor->m_exceptionptr);
+            }
+        });
+    } else {
+        // batch size > 1, try infer with batched request
+        // this executor starts the inference while the task (checking the result) is passed to the next stage
+        struct ThisRequestExecutor : public ov::threading::ITaskExecutor {
+            explicit ThisRequestExecutor(AsyncInferRequest* _this_) : _this{_this_} {}
+            void run(ov::threading::Task task) override {
+                auto workerInferRequest = _this->m_sync_request->m_batched_request_wrapper;
+                std::pair<AsyncInferRequest*, ov::threading::Task> t;
+                t.first = _this;
+                t.second = std::move(task);
+                workerInferRequest->_tasks.push(t);
+                // it is ok to call size() here as the queue only grows (and the bulk removal happens under the mutex)
+                const int sz = static_cast<int>(workerInferRequest->_tasks.size());
+                if (sz == workerInferRequest->_batch_size) {
+                    workerInferRequest->_cond.notify_one();
+                }
+            };
+            AsyncInferRequest* _this = nullptr;
+        };
+        m_pipeline = {
+            {/*TaskExecutor*/ std::make_shared<ThisRequestExecutor>(this), /*task*/ [this] {
+                 if (this->m_sync_request->m_exception_ptr)  // if the exception happened in the batch1 fallback
+                     std::rethrow_exception(this->m_sync_request->m_exception_ptr);
+                 auto batchReq = this->m_sync_request->m_batched_request_wrapper;
+                 if (batchReq->_exception_ptr)  // when the batchN execution failed
+                     std::rethrow_exception(batchReq->_exception_ptr);
+                 // in the case of non-batched execution the tensors were set explicitly
+                 if (SyncInferRequest::eExecutionFlavor::BATCH_EXECUTED ==
+                     this->m_sync_request->m_batched_request_status) {
+                     this->m_sync_request->copy_outputs_if_needed();
+                 }
+             }}};
+    }
+}
+
+void AsyncInferRequest::set_tensor(const ov::Output<const ov::Node>& port, const ov::SoPtr<ov::ITensor>& tensor) {
+    check_state();
+    if (m_sync_request && m_sync_request->get_batch_size() == 0) {
+        m_request_without_batch->set_tensor(port, tensor);
+    }
+    ov::IAsyncInferRequest::set_tensor(port, tensor);
+}
+
+void AsyncInferRequest::set_tensors(const ov::Output<const ov::Node>& port,
+                                    const std::vector<ov::SoPtr<ov::ITensor>>& tensors) {
+    check_state();
+    if (m_sync_request && m_sync_request->get_batch_size() == 0) {
+        m_request_without_batch->set_tensors(port, tensors);
+    }
+    ov::IAsyncInferRequest::set_tensors(port, tensors);
 }
 
 std::vector<ov::ProfilingInfo> AsyncInferRequest::get_profiling_info() const {
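
The no-batch branch turns the wrapped hardware request itself into the pipeline's task executor: run() stores the next stage as a task and calls start_async(), and the completion callback records any exception before invoking that stored task. A minimal standalone sketch of this chaining pattern is shown below; it uses simplified stand-in types (MockAsyncRequest, a plain std::function task) rather than the OpenVINO interfaces.

#include <exception>
#include <functional>
#include <iostream>
#include <thread>
#include <utility>

using Task = std::function<void()>;

// Stand-in for the hardware infer request: finishes on a worker thread and
// then fires the registered completion callback.
struct MockAsyncRequest {
    std::function<void(std::exception_ptr)> callback;
    std::thread worker;

    void set_callback(std::function<void(std::exception_ptr)> cb) {
        callback = std::move(cb);
    }
    void start_async() {
        worker = std::thread([this] {
            std::exception_ptr eptr;
            try {
                // ... the actual inference would run here ...
            } catch (...) {
                eptr = std::current_exception();
            }
            callback(eptr);
        });
    }
    void wait() {
        if (worker.joinable())
            worker.join();
    }
    ~MockAsyncRequest() { wait(); }
};

// Executor in the style of RequestExecutor: running a pipeline stage means
// starting the async request; the stored task (the next stage) is resumed
// from the completion callback.
struct RequestExecutorSketch {
    explicit RequestExecutorSketch(MockAsyncRequest& request) : m_request(request) {
        m_request.set_callback([this](std::exception_ptr eptr) {
            m_exceptionptr = eptr;
            auto captured = std::move(m_task);  // mirrors capturedTask in the diff
            captured();
        });
    }
    void run(Task task) {
        m_task = std::move(task);
        m_request.start_async();
    }
    MockAsyncRequest& m_request;
    std::exception_ptr m_exceptionptr;
    Task m_task;
};

int main() {
    MockAsyncRequest request;
    RequestExecutorSketch executor(request);
    // The "check the result" stage: rethrow whatever the callback recorded.
    executor.run([&executor] {
        if (executor.m_exceptionptr)
            std::rethrow_exception(executor.m_exceptionptr);
        std::cout << "inference finished without errors\n";
    });
    request.wait();  // make sure the callback (and the next stage) has run
}

Moving the stored task into a local before invoking it mirrors the capturedTask step in the diff, so the executor holds no stale task once the stage has been resumed.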
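The batched branch instead defers work to a shared worker: each request's executor only enqueues its result-checking task on the batched request wrapper and wakes the worker once the queue holds a full batch. A rough self-contained sketch of that producer/worker handshake, with an invented BatchedWorker in place of the wrapper and none of its real bookkeeping, could look like this:

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

using Task = std::function<void()>;

struct BatchedWorker {
    explicit BatchedWorker(int batch_size) : m_batch_size(batch_size), m_thread([this] { loop(); }) {}

    // What ThisRequestExecutor::run does in the diff: enqueue the per-request
    // task and notify only when a full batch has accumulated.
    void enqueue(Task task) {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_tasks.push(std::move(task));
            if (static_cast<int>(m_tasks.size()) < m_batch_size)
                return;  // keep collecting; the worker stays asleep
        }
        m_cond.notify_one();  // a full batch is ready
    }

    void stop() {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_stop = true;
        }
        m_cond.notify_one();
        m_thread.join();
    }

private:
    void loop() {
        std::unique_lock<std::mutex> lock(m_mutex);
        while (true) {
            m_cond.wait(lock, [this] {
                return m_stop || static_cast<int>(m_tasks.size()) >= m_batch_size;
            });
            if (static_cast<int>(m_tasks.size()) >= m_batch_size) {
                // ... the real worker would run one batched inference here ...
                // then complete every request in the batch:
                while (!m_tasks.empty()) {
                    Task t = std::move(m_tasks.front());
                    m_tasks.pop();
                    lock.unlock();
                    t();  // per-request completion runs outside the lock
                    lock.lock();
                }
            }
            if (m_stop)
                return;
        }
    }

    const int m_batch_size;
    std::mutex m_mutex;
    std::condition_variable m_cond;
    std::queue<Task> m_tasks;
    bool m_stop = false;
    std::thread m_thread;
};

int main() {
    BatchedWorker worker(4);
    for (int i = 0; i < 4; ++i)
        worker.enqueue([i] { std::cout << "request " << i << " completed\n"; });
    worker.stop();
}

The real wrapper additionally tracks per-request exceptions and the batch-1 fallback path; the sketch only models the queue-and-notify handshake that _tasks, _batch_size and _cond implement in the diff.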