author      Narumol Prangnawarat <narumol.prangnawarat@arm.com>    2021-09-23 16:12:19 +0100
committer   Narumol Prangnawarat <narumol.prangnawarat@arm.com>    2021-10-07 14:43:09 +0000
commit      452274c86245082ce20563ede12b92af81dba38a (patch)
tree        79718c6cf86acbb21138068c17aae15c4b172306
parent      4d217c02fe2c0a32ff9da69d8fe375a75173c0f3 (diff)
download    armnn-452274c86245082ce20563ede12b92af81dba38a.tar.gz
IVGCVSW-6459 Add support for scalar and flexible output datatypes to the ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Id1e933f6ae55ddc1a57c80c9f6a5757ccb61f018
-rw-r--r--  src/armnn/layers/GatherLayer.cpp        |   8
-rw-r--r--  src/armnnOnnxParser/OnnxParser.cpp      | 165
-rw-r--r--  src/armnnOnnxParser/OnnxParser.hpp      |   8
-rw-r--r--  src/armnnOnnxParser/test/Gather.cpp     |  22
-rw-r--r--  src/armnnOnnxParser/test/Shape.cpp      |  56
-rw-r--r--  src/armnnOnnxParser/test/Unsqueeze.cpp  |  14
6 files changed, 153 insertions(+), 120 deletions(-)
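In an ONNX model a scalar is a value_info or initializer whose dims list is empty; the parser used to pad such shapes to { 1 }, and with this change they map to a genuine scalar TensorInfo instead. A minimal sketch of the new mapping, assuming only the public armnn/Tensor.hpp header:

#include <armnn/Tensor.hpp>

// An empty ONNX shape now becomes a true scalar rather than a 1-D { 1 } tensor.
armnn::TensorInfo InfoForEmptyOnnxShape()
{
    armnn::TensorShape scalarShape(armnn::Dimensionality::Scalar);
    // scalarShape.GetDimensionality() == armnn::Dimensionality::Scalar,
    // while GetNumDimensions() and GetNumElements() both report 1.
    return armnn::TensorInfo(scalarShape, armnn::DataType::Float32);
}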
diff --git a/src/armnn/layers/GatherLayer.cpp b/src/armnn/layers/GatherLayer.cpp
index e8b67b8348..a808c42384 100644
--- a/src/armnn/layers/GatherLayer.cpp
+++ b/src/armnn/layers/GatherLayer.cpp
@@ -37,6 +37,11 @@ std::vector<TensorShape> GatherLayer::InferOutputShapes(const std::vector<Tensor
const TensorShape& params = inputShapes[0];
const TensorShape& indices = inputShapes[1];
+ if (indices.GetDimensionality() == Dimensionality::Scalar && indices.GetNumDimensions() == 1)
+ {
+ return std::vector<TensorShape>({ TensorShape(Dimensionality::Scalar)});
+ }
+
const unsigned int paramsDim = params.GetNumDimensions();
const unsigned int indicesDim = indices.GetNumDimensions();
const unsigned int outputDim = paramsDim - 1 + indicesDim;
@@ -78,7 +83,8 @@ void GatherLayer::ValidateTensorShapesFromInputs()
{GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()});
ARMNN_ASSERT(inferredShapes.size() == 1);
- ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified);
+ ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified ||
+ inferredShapes[0].GetDimensionality() == Dimensionality::Scalar);
ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "GatherLayer");
}
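The early return above handles scalar indices; otherwise the layer keeps the usual ONNX Gather rule, output rank = rank(params) + rank(indices) - 1, with the axis dimension of params replaced by the indices dimensions. A standalone sketch of that rule (a hypothetical helper, not parser code), where an empty vector stands for a scalar shape:

#include <cstddef>
#include <vector>

std::vector<unsigned int> GatherOutputShape(const std::vector<unsigned int>& params,
                                            const std::vector<unsigned int>& indices,
                                            std::size_t axis)
{
    std::vector<unsigned int> out;
    for (std::size_t i = 0; i < axis; ++i)                  { out.push_back(params[i]); }
    for (unsigned int d : indices)                          { out.push_back(d); }
    for (std::size_t i = axis + 1; i < params.size(); ++i)  { out.push_back(params[i]); }
    return out; // params { 8 }, scalar indices, axis 0 -> { }, i.e. a scalar
}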
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 3fcb7ab603..6caf690935 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -262,38 +262,38 @@ std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const s
armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
{
- DataType type;
- switch(data_type)
- {
- case onnx::TensorProto::FLOAT:
- {
- type = DataType::Float32;
- break;
- }
- case onnx::TensorProto::INT32:
- case onnx::TensorProto::INT64:
- {
- type = DataType::Signed32;
+ DataType type;
+ switch(data_type)
+ {
+ case onnx::TensorProto::FLOAT:
+ {
+ type = DataType::Float32;
break;
- }
- default:
- {
- throw ParseException(
- fmt::format("'{}' is not a currently supported datatype for tensor {}."
- " Supported dataTypes are FLOAT, INT32 and INT64. {}",
- onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
- name,
- CHECK_LOCATION().AsString() ));
- }
- }
+ }
+ case onnx::TensorProto::INT32:
+ case onnx::TensorProto::INT64:
+ {
+ type = DataType::Signed32;
+ break;
+ }
+ default:
+ {
+ throw ParseException(
+ fmt::format("'{}' is not a currently supported datatype for tensor {}."
+ " Supported dataTypes are FLOAT, INT32 and INT64. {}",
+ onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
+ name,
+ CHECK_LOCATION().AsString() ));
+ }
+ }
- // To avoid crashes by trivial tensors
- if (shape.empty())
- {
- return TensorInfo(TensorShape(), type);
- }
+ // To avoid crashes by trivial tensors
+ if (shape.empty())
+ {
+ return TensorInfo(TensorShape(Dimensionality::Scalar), type);
+ }
- return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
+ return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
}
armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
@@ -305,11 +305,6 @@ armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
}
- if (shapeDims.empty())
- {
- shapeDims.push_back(1);
- }
-
return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
}
@@ -322,11 +317,6 @@ armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
}
- if (shapeDims.empty())
- {
- shapeDims.push_back(1);
- }
-
return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
}
@@ -376,7 +366,8 @@ void CalcPadding(uint32_t inputSize,
TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
const TensorShape& inShape,
- const std::string& outName)
+ const std::string& outName,
+ DataType dataType = DataType::Float32)
{
std::vector<int> targetDims;
for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
@@ -420,7 +411,7 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
}
TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
- return TensorInfo(outShape, DataType::Float32);
+ return TensorInfo(outShape, dataType);
}
} //namespace
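The defaulted dataType parameter lets ComputeReshapeInfo propagate integer element types instead of hard-coding Float32. A call-site sketch (ComputeReshapeInfo is file-local to OnnxParser.cpp; the shapes here are illustrative):

// Target dims are encoded as the values of a TensorShape, as in ParseUnsqueeze below.
armnn::TensorShape targetDims({ 2, 3 });
armnn::TensorShape inputShape({ 1, 6 });  // six elements in, six elements out
armnn::TensorInfo reshaped = ComputeReshapeInfo(targetDims, inputShape, "reshaped",
                                                armnn::DataType::Signed32);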
@@ -469,7 +460,8 @@ void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node,
std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames,
const IConnectableLayer* layer,
- std::vector<TensorShape> inputShapes)
+ std::vector<TensorShape> inputShapes,
+ const onnx::TensorProto::DataType& dataType)
{
ARMNN_ASSERT(! outNames.empty());
bool needCompute = std::any_of(outNames.begin(),
@@ -478,25 +470,45 @@ std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::strin
{
return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr);
});
- std::vector<TensorInfo> outInfo;
- //if the output info(s) are not here, we need to compute them
- std::vector<TensorShape> inferredShapes;
- if(needCompute)
- {
- inferredShapes = layer->InferOutputShapes(inputShapes);
- ARMNN_ASSERT(inferredShapes.size() == outNames.size());
- }
- for (uint i = 0; i < outNames.size(); ++i)
- {
- if(needCompute)
- {
- m_TensorsInfo[outNames[i]] = OnnxTensor();
- m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
- TensorInfo(inferredShapes[i], DataType::Float32));
- }
+ std::vector<TensorInfo> outInfo;
+ //if the output info(s) are not here, we need to compute them
+ std::vector<TensorShape> inferredShapes;
+ DataType armnnType = DataType::Float32;
+ if(needCompute) {
+ inferredShapes = layer->InferOutputShapes(inputShapes);
+ ARMNN_ASSERT(inferredShapes.size() == outNames.size());
+ switch (dataType) {
+ case onnx::TensorProto::FLOAT: {
+ armnnType = DataType::Float32;
+ break;
+ }
+ case onnx::TensorProto::INT32:
+ case onnx::TensorProto::INT64: {
+ armnnType = DataType::Signed32;
+ break;
+ }
+ default: {
+ throw ParseException(
+ fmt::format("'{}' is not a currently supported datatype for {}."
+ " Supported dataTypes are FLOAT, INT32 and INT64. {}",
+ onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(dataType)),
+ layer->GetName(),
+ CHECK_LOCATION().AsString()));
+ }
+ }
+ }
+ for (uint i = 0; i < outNames.size(); ++i)
+ {
+ if(needCompute)
+ {
+ m_TensorsInfo[outNames[i]] = OnnxTensor();
+ m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
+ TensorInfo(inferredShapes[i], armnnType));
+ m_TensorsInfo[outNames[i]].m_dtype = dataType;
+ }
outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info);
- }
- return outInfo;
+ }
+ return outInfo;
}
OnnxParserImpl::OnnxParserImpl()
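The switch above is now the single point where an ONNX element type is translated while computing output info: FLOAT maps to Float32, and both INT32 and INT64 narrow to Signed32. The same mapping as a hypothetical free function:

#include <armnn/Exceptions.hpp>
#include <armnn/Types.hpp>
#include <onnx/onnx.pb.h>

armnn::DataType MapOnnxElementType(onnx::TensorProto::DataType dt)
{
    switch (dt)
    {
        case onnx::TensorProto::FLOAT: return armnn::DataType::Float32;
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::INT64: return armnn::DataType::Signed32; // 64-bit values narrowed
        default: throw armnn::ParseException("unsupported ONNX element type");
    }
}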
@@ -1480,7 +1492,8 @@ void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node)
IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str());
ARMNN_ASSERT(layer != nullptr);
- auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes);
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes,
+ m_TensorsInfo[node.input(0)].m_dtype);
layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
@@ -1774,9 +1787,10 @@ void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
ARMNN_ASSERT(layer != nullptr);
- TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
- TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
- auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape });
+ const TensorShape& inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+ const TensorShape& indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape },
+ m_TensorsInfo[node.input(0)].m_dtype);
layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
// register the input connection slots for the layer, connections are made after all layers have been created
@@ -1823,16 +1837,11 @@ void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
- // Output must be INT64
- CHECK_VALID_DATATYPE(node.name(), node.output(0),
- m_TensorsInfo[node.output(0)].m_dtype,
- onnx::TensorProto::INT64);
-
IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
ARMNN_ASSERT(layer != nullptr);
TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
- auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
+ auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}, onnx::TensorProto::INT64);
layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
// register the input connection slots for the layer, connections are made after all layers have been created
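Dropping the CHECK_VALID_DATATYPE call is safe because the ONNX specification fixes Shape's output to a rank-1 INT64 tensor, so the parser now supplies onnx::TensorProto::INT64 itself. The resulting Arm NN output info, sketched for a { 1, 3, 1, 5 } input:

armnn::TensorShape inShape({ 1, 3, 1, 5 });
armnn::TensorInfo shapeOutput(armnn::TensorShape({ inShape.GetNumDimensions() }),
                              armnn::DataType::Signed32); // INT64 in ONNX, Signed32 in Arm NN
// shapeOutput has shape { 4 }; at runtime the layer emits the values 1, 3, 1, 5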
@@ -1900,10 +1909,6 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 1, 2);
CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);
- CHECK_VALID_DATATYPE(node.name(), node.input(0),
- m_TensorsInfo[node.input(0)].m_dtype,
- onnx::TensorProto::FLOAT); //input
-
TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
std::vector<uint32_t> dims;
if (node.input_size() == 1 && node.attribute_size() > 0)
@@ -1931,9 +1936,12 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
std::vector<unsigned int> targetShape;
- for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
+ if (inputShape.GetDimensionality() != Dimensionality::Scalar)
{
- targetShape.push_back(inputShape[i]);
+ for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
+ {
+ targetShape.push_back(inputShape[i]);
+ }
}
for(uint i = 0; i < dims.size(); i++)
@@ -1941,9 +1949,10 @@ void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
targetShape.insert(targetShape.begin() + armnn::numeric_cast<int>(dims[i]), 1);
}
- auto outInfo = ComputeReshapeInfo(TensorShape(armnn::numeric_cast<unsigned int>(targetShape.size()),
- targetShape.data()), inputShape, node.output(0));
+ auto outInfo = ComputeReshapeInfo(TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
+ inputShape, node.output(0), m_TensorsInfo[node.input(0)].m_info->GetDataType());
m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
+ m_TensorsInfo[node.output(0)].m_dtype = m_TensorsInfo[node.input(0)].m_dtype;
CreateReshapeLayer(node.input(0), node.output(0), node.name());
}
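For a scalar input the copy loop is skipped entirely, so targetShape is built from the inserted axes alone. Traced by hand, mirroring the loops above:

#include <vector>

std::vector<unsigned int> targetShape;         // empty: a scalar input contributes no dims
const std::vector<unsigned int> axes = { 0 };  // the Unsqueeze axes attribute/input
for (unsigned int a : axes)
{
    targetShape.insert(targetShape.begin() + a, 1u);
}
// scalar input, axes { 0 }       -> targetShape == { 1 }
// input { 2, 3 }, axes { 0, 3 }  -> { 1, 2, 3, 1 } once the input dims are copied first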
diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp
index 6a0fad0ec2..d388f501d4 100644
--- a/src/armnnOnnxParser/OnnxParser.hpp
+++ b/src/armnnOnnxParser/OnnxParser.hpp
@@ -74,9 +74,11 @@ private:
void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list);
- std::vector<armnn::TensorInfo> ComputeOutputInfo(std::vector<std::string> outNames,
- const armnn::IConnectableLayer* layer,
- std::vector<armnn::TensorShape> inputShapes);
+ std::vector<armnn::TensorInfo> ComputeOutputInfo(
+ std::vector<std::string> outNames,
+ const armnn::IConnectableLayer* layer,
+ std::vector<armnn::TensorShape> inputShapes,
+ const onnx::TensorProto::DataType& type = onnx::TensorProto::FLOAT);
void DetectFullyConnected();
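Defaulting the new parameter to onnx::TensorProto::FLOAT keeps every existing call site source-compatible; only dtype-aware nodes pass the fourth argument. Call-site sketch, using the names from the diff:

// Unchanged callers still produce Float32 output infos:
auto floatOut = ComputeOutputInfo({ node.output(0) }, layer, { inputShape });

// Dtype-aware nodes (Concat, Gather, Shape) forward an explicit ONNX type:
auto typedOut = ComputeOutputInfo({ node.output(0) }, layer, { inputShape },
                                  m_TensorsInfo[node.input(0)].m_dtype);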
diff --git a/src/armnnOnnxParser/test/Gather.cpp b/src/armnnOnnxParser/test/Gather.cpp
index 1d214419c4..8fd9021ebc 100644
--- a/src/armnnOnnxParser/test/Gather.cpp
+++ b/src/armnnOnnxParser/test/Gather.cpp
@@ -85,6 +85,14 @@ struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPar
}
};
+struct GatherScalarFixture : GatherMainFixture
+{
+ GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { })
+ {
+ Setup();
+ }
+};
+
struct Gather1dFixture : GatherMainFixture
{
Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 })
@@ -117,16 +125,22 @@ struct Gather4dFixture : GatherMainFixture
}
};
+TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest")
+{
+ RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
+ {{"output", { 1.0f }}});
+}
+
TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest")
{
- RunTest<1, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
- {{"output", {1.0f, 3.0f, 2.0f, 6.0f}}});
+ RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
+ {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}});
}
TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest")
{
- RunTest<2, float>({{"input", {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
- {{"output", {3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
+ RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
+ {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
}
TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest")
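With scalar indices the new GatherScalarTest above reduces to picking a single element out of the input. A plain C++ restatement of what the fixture checks (hypothetical, outside the test harness):

#include <cassert>
#include <vector>

int main()
{
    const std::vector<float> input = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f };
    const unsigned int scalarIndex = 0;       // the fixture's indices data { 0 }
    const float output = input[scalarIndex];  // scalar result of the Gather
    assert(output == 1.0f);                   // matches {{"output", { 1.0f }}}
    return 0;
}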
diff --git a/src/armnnOnnxParser/test/Shape.cpp b/src/armnnOnnxParser/test/Shape.cpp
index b033b2d8bf..e01c4b8b55 100644
--- a/src/armnnOnnxParser/test/Shape.cpp
+++ b/src/armnnOnnxParser/test/Shape.cpp
@@ -5,9 +5,11 @@
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
+#include "OnnxParserTestUtils.hpp"
TEST_SUITE("OnnxParser_Shape")
{
+
struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
ShapeMainFixture(const std::string& inputType,
@@ -31,7 +33,7 @@ struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPars
tensor_type {
elem_type: )" + inputType + R"(
shape {
- )" + ConstructShapeString(inputShape) + R"(
+ )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
}
}
}
@@ -54,91 +56,77 @@ struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxPars
version: 10
})";
}
- std::string ConstructShapeString(const std::vector<int>& shape)
- {
- std::string shapeStr;
- for (int i : shape)
- {
- shapeStr = fmt::format("{} dim {{ dim_value: {} }}", shapeStr, i);
- }
- return shapeStr;
- }
};
-struct ShapeValidFloatFixture : ShapeMainFixture
+struct ShapeFloatFixture : ShapeMainFixture
{
- ShapeValidFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 }) {
+ ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 })
+ {
Setup();
}
};
-struct ShapeValidIntFixture : ShapeMainFixture
+struct ShapeIntFixture : ShapeMainFixture
{
- ShapeValidIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 }) {
+ ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 })
+ {
Setup();
}
};
struct Shape3DFixture : ShapeMainFixture
{
- Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 }) {
+ Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 })
+ {
Setup();
}
};
struct Shape2DFixture : ShapeMainFixture
{
- Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 }) {
+ Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 })
+ {
Setup();
}
};
struct Shape1DFixture : ShapeMainFixture
{
- Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 }) {
+ Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 })
+ {
Setup();
}
};
-struct ShapeInvalidFixture : ShapeMainFixture
-{
- ShapeInvalidFixture() : ShapeMainFixture("1", "1", "4", { 1, 3, 1, 5 }) {}
-};
-
-TEST_CASE_FIXTURE(ShapeValidFloatFixture, "FloatValidShapeTest")
+TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}});
}
-TEST_CASE_FIXTURE(ShapeValidIntFixture, "IntValidShapeTest")
+TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest")
{
- RunTest<2, int>({{"Input", { 0, 1, 2, 3, 4,
+ RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4,
4, 3, 2, 1, 0,
0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}});
}
TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}});
}
TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest")
{
- RunTest<2, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
+ RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
}
TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest")
{
- RunTest<2, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
-}
-
-TEST_CASE_FIXTURE(ShapeInvalidFixture, "IncorrectOutputDataShapeTest")
-{
- CHECK_THROWS_AS(Setup(), armnn::ParseException);
+ RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
}
}
diff --git a/src/armnnOnnxParser/test/Unsqueeze.cpp b/src/armnnOnnxParser/test/Unsqueeze.cpp
index 95a191e46b..7ba87bc680 100644
--- a/src/armnnOnnxParser/test/Unsqueeze.cpp
+++ b/src/armnnOnnxParser/test/Unsqueeze.cpp
@@ -77,6 +77,14 @@ struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture
}
};
+struct UnsqueezeScalarFixture : UnsqueezeFixture
+{
+ UnsqueezeScalarFixture() : UnsqueezeFixture({ 0 }, { }, { 1 })
+ {
+ Setup();
+ }
+};
+
TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest")
{
RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
@@ -107,6 +115,12 @@ TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest")
6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
}
+TEST_CASE_FIXTURE(UnsqueezeScalarFixture, "UnsqueezeScalarTest")
+{
+ RunTest<1, float>({{"Input", { 1.0f }}},
+ {{"Output", { 1.0f }}});
+}
+
struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
UnsqueezeInputAxesFixture()