Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
 src/armnnOnnxParser/OnnxParser.cpp | 165
 1 file changed, 87 insertions(+), 78 deletions(-)
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 3fcb7ab603..6caf690935 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -262,38 +262,38 @@ std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const s
armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
{
- DataType type;
- switch(data_type)
- {
- case onnx::TensorProto::FLOAT:
- {
- type = DataType::Float32;
- break;
- }
- case onnx::TensorProto::INT32:
- case onnx::TensorProto::INT64:
- {
- type = DataType::Signed32;
+ DataType type;
+ switch(data_type)
+ {
+ case onnx::TensorProto::FLOAT:
+ {
+ type = DataType::Float32;
break;
- }
- default:
- {
- throw ParseException(
- fmt::format("'{}' is not a currently supported datatype for tensor {}."
- " Supported dataTypes are FLOAT, INT32 and INT64. {}",
- onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
- name,
- CHECK_LOCATION().AsString() ));
- }
- }
+ }
+ case onnx::TensorProto::INT32:
+ case onnx::TensorProto::INT64:
+ {
+ type = DataType::Signed32;
+ break;
+ }
+ default:
+ {
+ throw ParseException(
+ fmt::format("'{}' is not a currently supported datatype for tensor {}."
+ " Supported dataTypes are FLOAT, INT32 and INT64. {}",
+ onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
+ name,
+ CHECK_LOCATION().AsString() ));
+ }
+ }
- // To avoid crashes by trivial tensors
- if (shape.empty())
- {
- return TensorInfo(TensorShape(), type);
- }
+ // To avoid crashes by trivial tensors
+ if (shape.empty())
+ {
+ return TensorInfo(TensorShape(Dimensionality::Scalar), type);
+ }
- return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
+ return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
}
armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
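// Aside: a minimal standalone sketch of the datatype mapping used in the hunk
// above, with stand-in enums instead of the real onnx/armnn headers (assumed
// names, for illustration only). FLOAT maps to Float32, INT32 and INT64 both
// map to the 32-bit signed type, and anything else is rejected.
#include <stdexcept>
#include <string>

enum class OnnxDataType { FLOAT, INT32, INT64, STRING };   // stand-in
enum class ArmnnDataType { Float32, Signed32 };            // stand-in

ArmnnDataType MapDataType(OnnxDataType type, const std::string& tensorName)
{
    switch (type)
    {
        case OnnxDataType::FLOAT:
            return ArmnnDataType::Float32;
        case OnnxDataType::INT32:
        case OnnxDataType::INT64:
            return ArmnnDataType::Signed32;    // INT64 values are narrowed to 32 bits
        default:
            throw std::runtime_error("unsupported datatype for tensor " + tensorName);
    }
}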
@@ -305,11 +305,6 @@ armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
}
- if (shapeDims.empty())
- {
- shapeDims.push_back(1);
- }
-
return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
}
@@ -322,11 +317,6 @@ armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
}
- if (shapeDims.empty())
- {
- shapeDims.push_back(1);
- }
-
return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
}
@@ -376,7 +366,8 @@ void CalcPadding(uint32_t inputSize,
TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
const TensorShape& inShape,
- const std::string& outName)
+ const std::string& outName,
+ DataType dataType = DataType::Float32)
{
std::vector<int> targetDims;
for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
@@ -420,7 +411,7 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
}
TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
- return TensorInfo(outShape, DataType::Float32);
+ return TensorInfo(outShape, dataType);
}
} //namespace
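// Aside: a standalone sketch of the "stretch" handling inside
// ComputeReshapeInfo above (assumed behaviour, reconstructed from this hunk).
// A single -1 entry in the target shape is replaced by
// inputElements / (product of the remaining target dims); the real function
// additionally validates element counts and builds an armnn::TensorShape.
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<unsigned int> InferReshapeDims(const std::vector<int>& targetDims,
                                           std::size_t inputElements)
{
    std::vector<unsigned int> outDims;
    int stretchIndex = -1;
    std::size_t knownProduct = 1;
    for (std::size_t i = 0; i < targetDims.size(); ++i)
    {
        if (targetDims[i] == -1)
        {
            if (stretchIndex != -1)
            {
                throw std::runtime_error("at most one -1 is allowed in a target shape");
            }
            stretchIndex = static_cast<int>(i);
            outDims.push_back(0); // placeholder, filled in below
        }
        else
        {
            outDims.push_back(static_cast<unsigned int>(targetDims[i]));
            knownProduct *= static_cast<std::size_t>(targetDims[i]);
        }
    }
    if (stretchIndex != -1)
    {
        outDims[static_cast<std::size_t>(stretchIndex)] =
            static_cast<unsigned int>(inputElements / knownProduct);
    }
    return outDims;
}
// e.g. InferReshapeDims({-1, 4}, 12) yields {3, 4}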
@@ -469,7 +460,8 @@ void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node,
std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames,
const IConnectableLayer* layer,
- std::vector<TensorShape> inputShapes)
+ std::vector<TensorShape> inputShapes,
+ const onnx::TensorProto::DataType& dataType)
{
ARMNN_ASSERT(! outNames.empty());
bool needCompute = std::any_of(outNames.begin(),
@@ -478,25 +470,45 @@ std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::strin
{
return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr);
});
- std::vector<TensorInfo> outInfo;
- //if the output info(s) are not here, we need to compute them
- std::vector<TensorShape> inferredShapes;
- if(needCompute)
- {
- inferredShapes = layer->InferOutputShapes(inputShapes);
- ARMNN_ASSERT(inferredShapes.size() == outNames.size());
- }
- for (uint i = 0; i < outNames.size(); ++i)
- {
- if(needCompute)
- {
- m_TensorsInfo[outNames[i]] = OnnxTensor();
- m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
- TensorInfo(inferredShapes[i], DataType::Float32));
- }
+ std::vector<TensorInfo> outInfo;
+ //if the output info(s) are not here, we need to compute them
+ std::vector<TensorShape> inferredShapes;
+ DataType armnnType = DataType::Float32;
+ if(needCompute) {
+ inferredShapes = layer->InferOutputShapes(inputShapes);
+ ARMNN_ASSERT(inferredShapes.size() == outNames.size());
+ switch (dataType) {
+ case onnx::TensorProto::FLOAT: {
+ armnnType = DataType::Float32;
+ break;
+ }
+ case onnx::TensorProto::INT32:
+ case onnx::TensorProto::INT64: {
+ armnnType = DataType::Signed32;
+ break;
+ }
+ default: {
+ throw ParseException(
+ fmt::format("'{}' is not a currently supported datatype for {}."
+ " Supported dataTypes are FLOAT, INT32 and INT64. {}",
+ onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(dataType)),
+ layer->GetName(),
+ CHECK_LOCATION().AsString()));
+ }
+ }
+ }
+ for (uint i = 0; i < outNames.size(); ++i)
+ {
+ if(needCompute)
+ {
+ m_TensorsInfo[outNames[i]] = OnnxTensor();
+ m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
+ TensorInfo(inferredShapes[i], armnnType));
+ m_TensorsInfo[outNames[i]].m_dtype = dataType;
+ }
outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info);
- }
- return outInfo;
+ }
+ return outInfo;
}
OnnxParserImpl::OnnxParserImpl()
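// Aside: the control flow of ComputeOutputInfo above, reduced to a standalone
// sketch with stand-in types (assumed names; the real code uses m_TensorsInfo
// and IConnectableLayer::InferOutputShapes). Shapes are inferred and cached
// only when at least one output name has no TensorInfo yet, and after this
// change the cached entries carry the mapped datatype rather than a
// hard-coded Float32.
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct CachedInfo { std::vector<unsigned int> shape; int dtype; };
std::map<std::string, std::unique_ptr<CachedInfo>> cache; // stand-in for m_TensorsInfo

bool NeedCompute(const std::vector<std::string>& outNames)
{
    return std::any_of(outNames.begin(), outNames.end(),
                       [](const std::string& name)
                       { return cache.count(name) == 0 || cache[name] == nullptr; });
}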
@@ -1480,7 +1492,8 @@ void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node)
IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str());
ARMNN_ASSERT(layer != nullptr);
- auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes);
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes,
+ m_TensorsInfo[node.input(0)].m_dtype);
layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
@@ -1774,9 +1787,10 @@ void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
ARMNN_ASSERT(layer != nullptr);
- TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
- TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
- auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape });
+ const TensorShape& inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+ const TensorShape& indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape },
+ m_TensorsInfo[node.input(0)].m_dtype);
layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
// register the input connection slots for the layer, connections are made after all layers have been created
@@ -1823,16 +1837,11 @@ void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
- // Output must be INT64
- CHECK_VALID_DATATYPE(node.name(), node.output(0),
- m_TensorsInfo[node.output(0)].m_dtype,
- onnx::TensorProto::INT64);
-
IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
ARMNN_ASSERT(layer != nullptr);
TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
- auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}, onnx::TensorProto::INT64);
layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
// register the input connection slots for the layer, connections are made after all layers have been created
@@ -1900,10 +1909,6 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 1, 2);
CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);
- CHECK_VALID_DATATYPE(node.name(), node.input(0),
- m_TensorsInfo[node.input(0)].m_dtype,
- onnx::TensorProto::FLOAT); //input
-
TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
std::vector<uint32_t> dims;
if (node.input_size() == 1 && node.attribute_size() > 0)
@@ -1931,9 +1936,12 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
std::vector<unsigned int> targetShape;
- for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
+ if (inputShape.GetDimensionality() != Dimensionality::Scalar)
{
- targetShape.push_back(inputShape[i]);
+ for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
+ {
+ targetShape.push_back(inputShape[i]);
+ }
}
for(uint i = 0; i < dims.size(); i++)
@@ -1941,9 +1949,10 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
targetShape.insert(targetShape.begin() + armnn::numeric_cast<int>(dims[i]), 1);
}
- auto outInfo = ComputeReshapeInfo(TensorShape(armnn::numeric_cast<unsigned int>(targetShape.size()),
- targetShape.data()), inputShape, node.output(0));
+ auto outInfo = ComputeReshapeInfo(TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
+ inputShape, node.output(0), m_TensorsInfo[node.input(0)].m_info->GetDataType());
m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
+ m_TensorsInfo[node.output(0)].m_dtype = m_TensorsInfo[node.input(0)].m_dtype;
CreateReshapeLayer(node.input(0), node.output(0), node.name());
}
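The ParseUnsqueeze change above can be summarised by the following standalone sketch, using plain std::vector as a stand-in for armnn::TensorShape: a scalar input contributes no dimensions, and a 1 is inserted at each requested axis.

#include <cstddef>
#include <cstdint>
#include <vector>

// Build the unsqueezed shape: copy the input dims (none for a scalar input),
// then insert a 1 at each axis. Axes are assumed to be sorted ascending.
std::vector<uint32_t> UnsqueezeShape(const std::vector<uint32_t>& inputDims,
                                     const std::vector<uint32_t>& axes)
{
    std::vector<uint32_t> target = inputDims; // empty when the input is a scalar
    for (uint32_t axis : axes)
    {
        target.insert(target.begin() + static_cast<std::ptrdiff_t>(axis), 1);
    }
    return target;
}
// e.g. UnsqueezeShape({3, 4}, {0, 3}) -> {1, 3, 4, 1}
//      UnsqueezeShape({},     {0})    -> {1}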