@@ -830,6 +830,31 @@ void Graph::AllocateWithReuse() {
     }
 
     if (!undefinedBoxes.empty()) {
+        // Use a proxy memory manager for output edges
+        for (auto& box : undefinedBoxes) {
+            for (auto& edge : edge_clusters[box.id]) {
+                const auto child = edge->getChild();
+                if (edge->getStatus() == Edge::Status::NeedAllocation &&
+                    child->getType() == Type::Output) {
+                    auto proxyMemMngr =
+                        std::make_shared<ProxyMemoryMngr>(std::make_shared<DnnlMemoryMngr>(make_unique<MemoryMngrWithReuse>()));
+                    DEBUG_LOG(proxyMemMngr, " ", this);
+
+                    // Store the output memory managers
+                    // so that the infer requests can access them.
+                    int count = 0;
+                    for (auto& output : outputNodesMap) {
+                        if (output.second == child) {
+                            outputNodesMemMngrMap[output.first] = proxyMemMngr;
+                            count++;
+                        }
+                    }
+                    IE_ASSERT(count == 1);
+                }
+            }
+        }
+        IE_ASSERT(outputNodesMemMngrMap.size() <= outputNodesMap.size());
+
         if (!syncNodesInds.empty()) {
             // We have to extend the lifespan of tensors that are crossing a sync point border in order to save
             // the intermediate computation results from possible loss due to the tensor resize
@@ -882,47 +907,13 @@ void Graph::AllocateWithReuse() {
             }
         }
         for (auto& group : groups) {
-            MemoryMngrPtr grpMemMngr;
-            grpMemMngr =
+            auto grpMemMngr =
                 std::make_shared<DnnlMemoryMngr>(make_unique<MemoryMngrWithReuse>());
-            // determine a group with outputs.
-            size_t isOutGrp = 0;
-            int64_t outBoxId = -1;
-            for (auto& box : group) {
-                if (std::any_of(
-                        edge_clusters[box.id].begin(),
-                        edge_clusters[box.id].end(),
-                        [box](const ov::intel_cpu::EdgePtr edge) {
-                            return edge->getChild()->getType() == Type::Output;
-                        })) {
-                    isOutGrp++;
-                    outBoxId = box.id;
-                }
-            }
-            if (isOutGrp) {
-                IE_ASSERT(isOutGrp == 1);  // reuse_io_tensors false
-                grpMemMngr =
-                    std::make_shared<ProxyMemoryMngr>(grpMemMngr);
-                DEBUG_LOG(grpMemMngr, " ", this);
-
-                // Store the output memory managers
-                // so that the infer requests can access them.
-                for (auto& edge : edge_clusters[outBoxId]) {
-                    const auto child = edge->getChild();
-                    if (child->getType() == Type::Output) {
-                        for (auto& output : outputNodesMap) {
-                            if (output.second == child) outputNodesMemMngrMap[output.first] = std::static_pointer_cast<ProxyMemoryMngr>(grpMemMngr);
-                        }
-                    }
-                }
-            }
             for (auto& box : group) {
                 for (auto& edge : edge_clusters[box.id]) {
                     if (edge->getStatus() == Edge::Status::NeedAllocation) {
                         edge->allocate(grpMemMngr);
                     }
-                    if (isOutGrp && "Parameter" != edge->getParent()->getTypeStr())
-                        edge->getParent()->forceUpdateShape = true;  // force recheck of shape updates for nodes in the output groups.
                 }
             }
         }
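
Editor's note: for readers outside the OpenVINO CPU plugin, this commit replaces the old group-level output handling (the removed `isOutGrp`/`outBoxId` logic) with a per-edge one: every output edge that needs allocation is wrapped in a `ProxyMemoryMngr` and recorded in `outputNodesMemMngrMap`, so an infer request can later retarget that edge to user-provided memory without touching the graph. Below is a minimal, self-contained sketch of the proxy pattern; `IMemMngr`, `HeapMngr`, `ExternalMngr`, and `ProxyMngr` are illustrative stand-ins, not the real OpenVINO classes.

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <memory>

// Minimal allocator interface (illustrative stand-in for the plugin's memory manager interface).
struct IMemMngr {
    virtual ~IMemMngr() = default;
    virtual void* getRawPtr() const = 0;
    virtual void resize(std::size_t size) = 0;
};

// Plain owning allocator (stand-in for MemoryMngrWithReuse).
class HeapMngr : public IMemMngr {
public:
    void* getRawPtr() const override { return m_data.get(); }
    void resize(std::size_t size) override {
        if (size > m_capacity) {           // grow-only, like a reuse-oriented allocator
            m_data.reset(std::malloc(size));
            m_capacity = size;
        }
    }
private:
    std::unique_ptr<void, void (*)(void*)> m_data{nullptr, std::free};
    std::size_t m_capacity = 0;
};

// Non-owning view over user memory (stand-in for a buffer handed in by the infer request).
class ExternalMngr : public IMemMngr {
public:
    ExternalMngr(void* ptr, std::size_t size) : m_ptr(ptr), m_size(size) {}
    void* getRawPtr() const override { return m_ptr; }
    void resize(std::size_t size) override {
        if (size > m_size) std::abort();   // caller must supply a large enough buffer
    }
private:
    void* m_ptr;
    std::size_t m_size;
};

// Forwarding proxy (stand-in for ProxyMemoryMngr): the graph allocates the output
// edge through it once; the infer request can later swap the target underneath.
class ProxyMngr : public IMemMngr {
public:
    explicit ProxyMngr(std::shared_ptr<IMemMngr> target) : m_target(std::move(target)) {}
    void* getRawPtr() const override { return m_target->getRawPtr(); }
    void resize(std::size_t size) override { m_target->resize(size); }
    void setTarget(std::shared_ptr<IMemMngr> target) { m_target = std::move(target); }
private:
    std::shared_ptr<IMemMngr> m_target;
};

int main() {
    // AllocateWithReuse() side: the output edge gets a proxy over an internal allocator.
    auto proxy = std::make_shared<ProxyMngr>(std::make_shared<HeapMngr>());
    proxy->resize(64);
    std::printf("internal buffer: %p\n", proxy->getRawPtr());

    // Infer-request side: redirect the same edge to a user buffer; the edge itself is untouched.
    static char userBuf[64];
    proxy->setTarget(std::make_shared<ExternalMngr>(userBuf, sizeof userBuf));
    std::printf("user buffer:     %p\n", proxy->getRawPtr());
    return 0;
}

Since the proxy now sits on exactly the output edges rather than on a whole reuse group, the removed `forceUpdateShape` workaround for output groups presumably becomes unnecessary, which is consistent with the commit deleting it.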