aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2020-09-08 08:57:40 +0100
committerKeithARM <keith.davis@arm.com>2020-09-10 09:23:30 +0000
commit1f3b49be73d1fadf06f20c912aa160a5ab53a6a8 (patch)
tree8bd9d08025f23c2054c21f1d9d76a29435fb6d74
parent54940191dfe3a405dcc0fdf6516849082ae62cc7 (diff)
downloadarmnn-1f3b49be73d1fadf06f20c912aa160a5ab53a6a8.tar.gz
IVGCVSW-5197 Add support for 2nd input to ExpandDims of TfParser
* ParseExpandDims did not support to pass the axis parameter as a second input tensor * Added related unit tests Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I8217950f0b42beaf5b9eaebdcad04267e4443ba3
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp77
-rw-r--r--src/armnnTfParser/test/ExpandDims.cpp201
2 files changed, 272 insertions, 6 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 38202fcf94..0d7c371eae 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -24,7 +24,7 @@
#include <boost/format.hpp>
#include <boost/numeric/conversion/cast.hpp>
-#include <armnn/utility/PolymorphicDowncast.hpp>
+#include <fmt/core.h>
#include <numeric>
using namespace armnnUtils;
@@ -1464,7 +1464,9 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
-TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo)
+TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef,
+ TensorInfo inputTensorInfo,
+ std::int32_t expandDim)
{
ARMNN_ASSERT(nodeDef.op() == "ExpandDims");
@@ -1478,8 +1480,6 @@ TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInf
% CHECK_LOCATION().AsString()));
}
- std::int32_t expandDim = ReadMandatoryNodeInt32Attribute(nodeDef, "Tdim");
-
std::int32_t inputDimSize = boost::numeric_cast<int32_t>(inputTensorInfo.GetNumDimensions());
std::vector<uint32_t> outputDims;
@@ -1542,13 +1542,78 @@ TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInf
ParsedTfOperationPtr TfParser::ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
{
IgnoreUnused(graphDef);
- std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+ // Number of inputs can either
+ // be 1 - that indicates that the axis parameter is passed as an attribute of the operation
+ // or 2 - which means that the axis parameter is passed as a second input
+ std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+ const std::size_t numInputs = nodes.size();
+ std::vector<OutputOfParsedTfOperation> inputs;
+ std::int32_t expandDim; // axis or dim parameter. Describes which dimension to expand.
+ if (numInputs == 1)
+ {
+ inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+ expandDim = ReadMandatoryNodeInt32Attribute(nodeDef, "Tdim");
+ }
+ else
+ {
+ inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+
+ // make sure data type is int32
+ IOutputSlot& prevLayerOutputSlot = inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index);
+ TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo();
+
+ if (inputTensorInfo.GetDataType()!=armnn::DataType::Signed32)
+ {
+ throw ParseException(
+ fmt::format(
+ "The axis parameter of ExpandDims operation given as second input is not of type int32. "
+ "Input {0} Node {1} {2}",
+ inputs[1].m_IndexedValue->GetNode().name(),
+ nodeDef.name(),
+ CHECK_LOCATION().AsString()));
+ }
+
+ // ensure the second input is a constant value
+ if (!HasParsedConstTensor<int32_t>(inputs[1].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException(
+ fmt::format(
+ "ArmNN only supports ExpandDims layers with constant axis/dim parameter. "
+ "Input {0} Node {1} {2}",
+ inputs[1].m_IndexedValue->GetNode().name(),
+ nodeDef.name(),
+ CHECK_LOCATION().AsString()));
+ }
+
+ // make sure the second input is scalar or contains only a single value
+ // (we don't support expand dims for multiple axis but we don't care what shape the
+ // given tensor has as long as there is only a single value in it
+ // e.g. a tensor like this [[[1]]] is completely fine)
+ if (inputTensorInfo.GetNumElements() != 1)
+ {
+ throw ParseException(
+ fmt::format(
+ "The axis parameter of ExpandDims operation given as second input is not "
+ "allowed to hold more than one value. "
+ "Input {0} Node {1} {2}",
+ inputs[1].m_IndexedValue->GetNode().name(),
+ nodeDef.name(),
+ CHECK_LOCATION().AsString()));
+ }
+
+ ParsedConstTfOperation<int32_t>* expandDimsNode =
+ PolymorphicDowncast<ParsedConstTfOperation<int32_t>*>(inputs[1].m_IndexedValue);
+
+ memcpy(&expandDim, expandDimsNode->GetStorage(), sizeof(expandDim));
+ }
+
+ // First input is the vector that should be expanded by another dimension
IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo();
TensorInfo outputInfo;
- outputInfo = OutputShapeOfExpandDims(nodeDef, inputTensorInfo);
+ outputInfo = OutputShapeOfExpandDims(nodeDef, inputTensorInfo, expandDim);
ReshapeDescriptor reshapeDesc;
reshapeDesc.m_TargetShape = outputInfo.GetShape();
diff --git a/src/armnnTfParser/test/ExpandDims.cpp b/src/armnnTfParser/test/ExpandDims.cpp
index 57d472d41d..ad95641cd1 100644
--- a/src/armnnTfParser/test/ExpandDims.cpp
+++ b/src/armnnTfParser/test/ExpandDims.cpp
@@ -109,4 +109,205 @@ BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
armnn::TensorShape({2, 1, 3, 5})));
}
+struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ExpandDimsAsInputFixture(const std::string& expandDim,
+ const bool wrongDataType = false,
+ const std::string& numElements = "1")
+ {
+ std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32";
+ std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim);
+
+ m_Prototext = R"(
+ node {
+ name: "a"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "b"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: )" + dataType + R"(
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: )" + dataType + R"(
+ tensor_shape {
+ dim {
+ size: )" + numElements + R"(
+ }
+ }
+ )" + val + R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "ExpandDims"
+ op: "ExpandDims"
+ input: "a"
+ input: "b"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ versions {
+ producer: 134
+ })";
+ }
+};
+
+struct ExpandDimAsInput : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInput() : ExpandDimsAsInputFixture("0")
+ {
+ Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
+ }
+};
+
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+ armnn::TensorShape({1, 1, 4})));
+}
+
+struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ // Axis parameter is of wrong data type (float instead of int32)
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
+}
+
+struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture
+{
+ ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input
+ // Axis parameter is of wrong shape
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
+}
+
+struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ExpandDimsAsNotConstInputFixture()
+ {
+ m_Prototext = R"(
+ node {
+ name: "a"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 4
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "b"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "ExpandDims"
+ op: "ExpandDims"
+ input: "a"
+ input: "b"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ versions {
+ producer: 134
+ })";
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture)
+{
+ // Axis parameter that describes which axis/dim should be expanded is passed as a second input.
+ // But is not a constant tensor --> not supported
+ BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }),
+ armnn::ParseException);
+}
+
BOOST_AUTO_TEST_SUITE_END()