aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2019-04-24 14:20:21 +0100
committerPablo Tello <pablo.tello@arm.com>2019-04-26 10:24:21 +0100
commit3dcc1c68104a6b79b0c5dd5d5172011aebf3e2e5 (patch)
treede9487ca92757c0e66652c14ddbaf4605ed3f372
parentd49b70fc4b9b6fe23d42399bde23abdf4d2ee9c7 (diff)
downloadarmnn-3dcc1c68104a6b79b0c5dd5d5172011aebf3e2e5.tar.gz
MLCE-111: ONNX parser raw data bug
Fixed bug in ONNX parser: unable to load raw data from the binary models. Change-Id: Iec60d2f90b78ffe6910fdec6e6bd2eb05802ffd0 Signed-off-by: Pablo Tello <pablo.tello@arm.com>
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp21
1 files changed, 14 insertions, 7 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 6fe7cc6732..a62383b563 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -419,20 +419,27 @@ std::pair<ConstTensor, std::unique_ptr<float[]>> OnnxParser::CreateConstTensor(c
onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
auto srcData = onnxTensor.float_data().data();
- if(tensorInfo.GetNumElements() != static_cast<uint>(onnxTensor.float_data_size()))
+ std::unique_ptr<float[]> tensorData(new float[tensorInfo.GetNumElements()]);
+ const size_t tensorSizeInBytes = tensorInfo.GetNumBytes();
+ // Copy the value list entries into the destination
+ if (!onnxTensor.has_raw_data())
{
- throw ParseException(boost::str(
- boost::format("The number of data provided (%1%) does not match the tensor '%2%' number of elements"
+ if(tensorInfo.GetNumElements() != static_cast<uint>(onnxTensor.float_data_size()))
+ {
+ throw ParseException(boost::str(
+ boost::format("The number of data provided (%1%) does not match the tensor '%2%' number of elements"
" (%3%) %4%")
% onnxTensor.float_data_size()
% name
% tensorInfo.GetNumElements()
% CHECK_LOCATION().AsString()));
+ }
+ ::memcpy(tensorData.get(), srcData, tensorSizeInBytes);
+ }
+ else
+ {
+ ::memcpy(tensorData.get(), onnxTensor.raw_data().c_str(), tensorSizeInBytes);
}
- std::unique_ptr<float[]> tensorData(new float[tensorInfo.GetNumElements()]);
-
- // Copy the value list entries into the destination
- ::memcpy(tensorData.get(),srcData, tensorInfo.GetNumBytes());
// Const tensors requires at least a list of values
if (tensorInfo.GetNumElements() == 0)