diff options
author | Tee Jung <tee.ty.jung@openedges.com> | 2019-11-01 05:27:28 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-11-04 09:11:08 +0000 |
commit | fcf6fd562f87595c814d8acbec04194421018c32 (patch) | |
tree | 0d719705d6d05388025457ee9592ed35aa6309c0 | |
parent | d94efa8a7c6476fb50b3434723fda22859c236ad (diff) | |
download | armnn-fcf6fd562f87595c814d8acbec04194421018c32.tar.gz |
Fix crash issue
* armnnOnnxParser builds tensorInfo from graph->value_info,
but PyTorch does not add weight/bias tensor information to graph->value_info,
so the tensorInfo of a const tensor should be extracted from graph->initializer instead
Signed-off-by: Jung Tae-young <tee.ty.jung@openedges.com>
Change-Id: Ib2656dd25abc522012cf413e843fe03949cb2eb0
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 41 |
1 file changed, 30 insertions, 11 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index 6f83dc5b35..9d374aed71 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -187,16 +187,10 @@ std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const s return attribValue; } -armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info) +armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type) { - const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape(); - std::vector<unsigned int> shapeDims; - for (int i = 0; i < onnxShape.dim_size(); ++i) - { - shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value()))); - } DataType type; - switch(info.type().tensor_type().elem_type()) + switch(data_type) { case onnx::TensorProto::FLOAT: { @@ -216,13 +210,35 @@ armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info) boost::format("'%1%' is not a currently supported datatype for tensor %2%." " Supported dataTypes are FLOAT, INT32 and INT64. 
%3%") % onnx::TensorProto::DataType_Name( - static_cast<onnx::TensorProto::DataType>(info.type().tensor_type().elem_type())) % - info.name() % + static_cast<onnx::TensorProto::DataType>(data_type)) % + name % CHECK_LOCATION().AsString() )); } + } + return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type); +} + +armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info) +{ + const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape(); + std::vector<unsigned int> shapeDims; + for (int i = 0; i < onnxShape.dim_size(); ++i) + { + shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value()))); + } + + return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type()); +} +armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor) +{ + std::vector<unsigned int> shapeDims; + for (auto dim: tensor.dims()) + { + shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim))); } - return TensorInfo(TensorShape(static_cast<unsigned int>(shapeDims.size()), shapeDims.data()), type); + + return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type()); } std::string TensorInfoAsString(const TensorInfo& info, @@ -580,6 +596,9 @@ void OnnxParser::LoadGraph() for (auto tensor : m_Graph->initializer()) { m_TensorsInfo[tensor.name()].m_tensor = std::make_unique<const onnx::TensorProto>(tensor); + m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor)); + m_TensorsInfo[tensor.name()].m_dtype = + static_cast<onnx::TensorProto::DataType>(tensor.data_type()); } SetupInputLayers(); |