Skip to content

Commit af162b0

Browse files
CR fix: instance ProxyMemManager for output edge which NeedAllocation
1 parent 84d33e6 commit af162b0

File tree

2 files changed

+31
-39
lines changed

2 files changed

+31
-39
lines changed

src/plugins/intel_cpu/src/graph.cpp

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,31 @@ void Graph::AllocateWithReuse() {
830830
}
831831

832832
if (!undefinedBoxes.empty()) {
833+
// Use proxy memory manager for output edges
834+
for (auto& box : undefinedBoxes) {
835+
for (auto& edge : edge_clusters[box.id]) {
836+
const auto child = edge->getChild();
837+
if (edge->getStatus() == Edge::Status::NeedAllocation &&
838+
child->getType() == Type::Output) {
839+
auto proxyMemMngr =
840+
std::make_shared<ProxyMemoryMngr>(std::make_shared<DnnlMemoryMngr>(make_unique<MemoryMngrWithReuse>()));
841+
DEBUG_LOG(proxyMemMngr, " ", this);
842+
843+
// Store the output memory managers.
844+
// This lets the infer requests access them.
845+
int count = 0;
846+
for (auto &output : outputNodesMap) {
847+
if (output.second == child) {
848+
outputNodesMemMngrMap[output.first] = proxyMemMngr;
849+
count++;
850+
}
851+
}
852+
IE_ASSERT(count == 1);
853+
}
854+
}
855+
}
856+
IE_ASSERT(outputNodesMemMngrMap.size() <= outputNodesMap.size());
857+
833858
if (!syncNodesInds.empty()) {
834859
//We have to extend the lifespan of tensors that are crossing a sync point border in order to save
835860
//the intermediate computation results from possible loss due to the tensor resize
@@ -882,47 +907,13 @@ void Graph::AllocateWithReuse() {
882907
}
883908
}
884909
for (auto& group : groups) {
885-
MemoryMngrPtr grpMemMngr;
886-
grpMemMngr =
910+
auto grpMemMngr =
887911
std::make_shared<DnnlMemoryMngr>(make_unique<MemoryMngrWithReuse>());
888-
// determine a group with outputs.
889-
size_t isOutGrp = 0;
890-
int64_t outBoxId = -1;
891-
for (auto& box : group) {
892-
if (std::any_of(
893-
edge_clusters[box.id].begin(),
894-
edge_clusters[box.id].end(),
895-
[box](const ov::intel_cpu::EdgePtr edge) {
896-
return edge->getChild()->getType() == Type::Output;
897-
})) {
898-
isOutGrp++;
899-
outBoxId = box.id;
900-
}
901-
}
902-
if (isOutGrp) {
903-
IE_ASSERT(isOutGrp==1); // reuse_io_tensors false
904-
grpMemMngr =
905-
std::make_shared<ProxyMemoryMngr>(grpMemMngr);
906-
DEBUG_LOG(grpMemMngr, " ", this);
907-
908-
// Store the output memory managers.
909-
// This lets the infer requests access them.
910-
for (auto& edge : edge_clusters[outBoxId]) {
911-
const auto child = edge->getChild();
912-
if (child->getType() == Type::Output) {
913-
for (auto &output : outputNodesMap) {
914-
if (output.second == child) outputNodesMemMngrMap[output.first] = std::static_pointer_cast<ProxyMemoryMngr>(grpMemMngr);
915-
}
916-
}
917-
}
918-
}
919912
for (auto& box : group) {
920913
for (auto& edge : edge_clusters[box.id]) {
921914
if (edge->getStatus() == Edge::Status::NeedAllocation) {
922915
edge->allocate(grpMemMngr);
923916
}
924-
if (isOutGrp && "Parameter" != edge->getParent()->getTypeStr())
925-
edge->getParent()->forceUpdateShape = true; // force recheck shape updates for nodes in the output groups.
926917
}
927918
}
928919
}

src/plugins/intel_cpu/src/infer_request.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,19 +275,20 @@ void InferRequestBase::changeDefaultPtr() {
275275
auto itr = outMemMngrMap.find(it.first);
276276
if (itr != outMemMngrMap.end()) {
277277
outputMemMngr = itr->second;
278-
OPENVINO_ASSERT(outputMemMngr, "output memmanager should not be empty.");
278+
OPENVINO_ASSERT(outputMemMngr, "proxy mem manager for output ", it.first, " is empty.");
279279
} else {
280-
OPENVINO_THROW("Cannot find output memmanager for output " + it.first + " !");
280+
canBeInPlace = false;
281+
DEBUG_LOG("no proxy mem manager for output ", it.first, " !");
281282
}
282283

283284
if (canBeInPlace) {
284285
auto tt = std::get<0>(outputsTensor2BlobMap[it.first]);
285286
auto memptr = tt->get_memory();
286287
outputMemMngr->setManager(memptr->getMemoryMngr());
287-
DEBUG_LOG("setTensor ", tt, " graph ", graph, " inferrequest ", this);
288+
DEBUG_LOG("setManager ", memptr->getMemoryMngr(), " graph ", graph, " inferrequest ", this);
288289
} else {
289290
outputMemMngr->setManager(nullptr);
290-
DEBUG_LOG("setTensor nullptr", " graph ", graph, " inferrequest ", this);
291+
DEBUG_LOG("setManager nullptr", " graph ", graph, " inferrequest ", this);
291292
}
292293

293294
continue;

0 commit comments

Comments
 (0)