Skip to content

Commit 89be56f

Browse files
authored
fix: Refactor read_model() to accept XML path (#646)
1 parent 1104926 commit 89be56f

File tree

1 file changed

+7
-29
lines changed

1 file changed

+7
-29
lines changed

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -198,40 +198,18 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
198198
auto obj = core.import_model(model_stream, hw_target, device_config);
199199
exe = OVExeNetwork(obj, hw_target);
200200
} else {
201-
// Get path to bin file
202-
std::string bin_file;
201+
// If the model is XML, we need to load it with the XML content in read_model()
202+
// where weights from bin file is directly consumed
203+
std::string xml_file_name = name;
203204
if (name.size() >= 5 && name.substr(name.size() - 5) == ".onnx") {
204-
bin_file = name;
205-
bin_file.replace(name.size() - 5, 5, ".bin");
205+
xml_file_name = name;
206+
xml_file_name.replace(name.size() - 5, 5, ".xml");
206207
} else {
207208
throw std::runtime_error("Invalid model name. Make sure *.onnx, *.xml, and *.bin carry the same name.");
208209
}
209210

210-
// Read the model XML into a string
211-
std::stringstream xml_stream;
212-
xml_stream << model_stream.rdbuf();
213-
std::string xml_content = xml_stream.str();
214-
215-
// Read model.bin into a vector
216-
std::ifstream bin_stream;
217-
bin_stream.open(bin_file, std::ios::binary);
218-
if (!bin_stream.is_open()) {
219-
throw std::runtime_error("Failed to open " + bin_file);
220-
}
221-
222-
bin_stream.seekg(0, std::ios::end);
223-
std::streamsize size = bin_stream.tellg();
224-
bin_stream.seekg(0, std::ios::beg);
225-
std::vector<uint8_t> bin_data(size);
226-
if (!bin_stream.read(reinterpret_cast<char*>(bin_data.data()), size)) {
227-
throw std::runtime_error("Failed to read binary data from " + bin_file);
228-
}
229-
230-
// Create an ov::Tensor for weights
231-
ov::Tensor weights_tensor(ov::element::u8, {bin_data.size()}, bin_data.data());
232-
233-
// Load the model explicitly with XML content and weights
234-
std::shared_ptr<ov::Model> model = core.read_model(xml_content, weights_tensor);
211+
// Load the model explicitly with XML contents
212+
std::shared_ptr<ov::Model> model = core.read_model(xml_file_name);
235213

236214
if (enable_causallm) {
237215
exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);

0 commit comments

Comments
 (0)