aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTee Jung <tee.ty.jung@openedges.com>2019-11-01 05:27:28 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-11-04 09:11:08 +0000
commitfcf6fd562f87595c814d8acbec04194421018c32 (patch)
tree0d719705d6d05388025457ee9592ed35aa6309c0
parentd94efa8a7c6476fb50b3434723fda22859c236ad (diff)
downloadarmnn-fcf6fd562f87595c814d8acbec04194421018c32.tar.gz
Fix crash issue
* armnnOnnxParser makes tensorInfo from graph->value_info but PyTorch does not add weight/bias tensor information to graph->value_info so tensorInfo of const tensor should be extracted from graph->initializer Signed-off-by: Jung Tae-young tee.ty.jung@openedges.com Change-Id: Ib2656dd25abc522012cf413e843fe03949cb2eb0
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp41
1 files changed, 30 insertions, 11 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 6f83dc5b35..9d374aed71 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -187,16 +187,10 @@ std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const s
return attribValue;
}
-armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
+armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
{
- const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape();
- std::vector<unsigned int> shapeDims;
- for (int i = 0; i < onnxShape.dim_size(); ++i)
- {
- shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
- }
DataType type;
- switch(info.type().tensor_type().elem_type())
+ switch(data_type)
{
case onnx::TensorProto::FLOAT:
{
@@ -216,13 +210,35 @@ armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
boost::format("'%1%' is not a currently supported datatype for tensor %2%."
" Supported dataTypes are FLOAT, INT32 and INT64. %3%") %
onnx::TensorProto::DataType_Name(
- static_cast<onnx::TensorProto::DataType>(info.type().tensor_type().elem_type())) %
- info.name() %
+ static_cast<onnx::TensorProto::DataType>(data_type)) %
+ name %
CHECK_LOCATION().AsString() ));
}
+ }
+ return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
+}
+
+armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
+{
+ const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape();
+ std::vector<unsigned int> shapeDims;
+ for (int i = 0; i < onnxShape.dim_size(); ++i)
+ {
+ shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
+ }
+
+ return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
+}
+armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
+{
+ std::vector<unsigned int> shapeDims;
+ for (auto dim: tensor.dims())
+ {
+ shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
}
- return TensorInfo(TensorShape(static_cast<unsigned int>(shapeDims.size()), shapeDims.data()), type);
+
+ return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
}
std::string TensorInfoAsString(const TensorInfo& info,
@@ -580,6 +596,9 @@ void OnnxParser::LoadGraph()
for (auto tensor : m_Graph->initializer())
{
m_TensorsInfo[tensor.name()].m_tensor = std::make_unique<const onnx::TensorProto>(tensor);
+ m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
+ m_TensorsInfo[tensor.name()].m_dtype =
+ static_cast<onnx::TensorProto::DataType>(tensor.data_type());
}
SetupInputLayers();