about summary refs log tree commit diff
path: root/src/armnnTfParser
diff options
context:
space:
mode:
author	Conor Kennedy <conor.kennedy@arm.com>	2018-12-05 11:05:54 +0000
committer	Aron Virginas-Tar <aron.virginas-tar@arm.com>	2018-12-05 13:36:01 +0000
commitc2130a070e6a9196d193c93a02b5f118810dd59a (patch)
tree45994da8a86619e4ed942a6619df284954329a1f /src/armnnTfParser
parentf6ba747c0802d87ba30aecd598f0603f9bd18576 (diff)
downloadarmnn-c2130a070e6a9196d193c93a02b5f118810dd59a.tar.gz
IVGCVSW-2193 ExpandDims operation implementation
* Add ExpandDims operation to TfParser.cpp

Change-Id: Ifa756ae0667c11e3b6daec8f6dd4e54cac88d16a
Diffstat (limited to 'src/armnnTfParser')
-rw-r--r--	src/armnnTfParser/TfParser.cpp       | 106
-rw-r--r--	src/armnnTfParser/TfParser.hpp       |   1
-rw-r--r--	src/armnnTfParser/test/ExpandDims.cpp | 112
3 files changed, 219 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 4ddcdce1c7..53cdfa37a2 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -159,6 +159,17 @@ float ReadMandatoryNodeFloatAttribute(const tensorflow::NodeDef& nodeDef, const
return attribValue;
}
+int32_t ReadMandatoryNodeInt32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
+{
+ int32_t attribValue = 0u;
+ ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kI,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = static_cast<int32_t>(attrValue.i());
+ });
+ return attribValue;
+}
+
uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
{
uint32_t attribValue = 0u;
@@ -349,6 +360,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "Identity", &TfParser::ParseIdentity },
{ "Conv2D", &TfParser::ParseConv2D },
{ "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D },
+ { "ExpandDims", &TfParser::ParseExpandDims },
{ "FusedBatchNorm", &TfParser::ParseFusedBatchNorm },
{ "ConcatV2", &TfParser::ParseConcat },
{ "LRN", &TfParser::ParseLrn },
@@ -1224,6 +1236,100 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo)
+{
+ BOOST_ASSERT(nodeDef.op() == "ExpandDims");
+
+ if (inputTensorInfo.GetNumDimensions() > 4) {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Unsupported number of dimensions: %1% for input shape for ExpandDims %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ std::int32_t expandDim = ReadMandatoryNodeInt32Attribute(nodeDef, "Tdim");
+
+ std::int32_t inputDimSize = boost::numeric_cast<int32_t>(inputTensorInfo.GetNumDimensions());
+ std::vector<uint32_t> outputDims;
+
+ // expandDim operation requires: -1-input.dims() <= dim <= input.dims()
+ if (expandDim >= -1 - inputDimSize && expandDim <= inputDimSize)
+ {
+ // add current input shape to outputDims
+ for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i) {
+ auto currentDimension = inputTensorInfo.GetShape()[i];
+ outputDims.push_back(currentDimension);
+ }
+
+ // insert a dimension of 1 at index 'expandDim' of inputs shape
+ if (expandDim >= 0)
+ {
+ auto getPosition = std::next(outputDims.begin() + 0, expandDim);
+ outputDims.insert(getPosition, 1);
+ }
+
+ // if negative number for 'expandDim' then count backwards from the last element
+ // and insert 1 dimension at index 'expandDim'
+ if (expandDim < 0)
+ {
+ auto outputDimSize = boost::numeric_cast<uint32_t>(outputDims.size() + 1);
+ auto getPosition = std::next(outputDims.begin() + outputDimSize, expandDim);
+ outputDims.insert(getPosition, 1);
+ }
+ }
+ else
+ {
+ throw InvalidArgumentException(
+ boost::str(
+ boost::format(
+ "Cannot expand dimension %1% in input tensor with %2% dimension %3%")
+ % expandDim
+ % inputDimSize
+ % CHECK_LOCATION().AsString()));
+ }
+
+ if (outputDims.size() > 4)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Unsupported number of dimensions: %1% for output shape for ExpandDims %2% %3%")
+ % outputDims.size()
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ TensorShape outShape = TensorShape(static_cast<unsigned int>(outputDims.size()),
+ outputDims.data());
+
+ TensorInfo outTensorInfo = inputTensorInfo;
+ outTensorInfo.SetShape(outShape);
+
+ return outTensorInfo;
+}
+
+ParsedTfOperationPtr TfParser::ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+
+ IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo();
+
+ TensorInfo outputInfo;
+ outputInfo = OutputShapeOfExpandDims(nodeDef, inputTensorInfo);
+
+ ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = outputInfo.GetShape();
+ IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, nodeDef.name().c_str());
+ prevLayerOutputSlot.Connect(layer->GetInputSlot(0));
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index 1c29ce2717..da78f48f54 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -131,6 +131,7 @@ private:
ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseExpandDims(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
diff --git a/src/armnnTfParser/test/ExpandDims.cpp b/src/armnnTfParser/test/ExpandDims.cpp
new file mode 100644
index 0000000000..57d472d41d
--- /dev/null
+++ b/src/armnnTfParser/test/ExpandDims.cpp
@@ -0,0 +1,112 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ExpandDimsFixture(const std::string& expandDim)
+ {
+ m_Prototext =
+ "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "node { \n"
+ " name: \"ExpandDims\" \n"
+ " op: \"ExpandDims\" \n"
+ " input: \"graphInput\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"Tdim\" \n"
+ " value { \n";
+ m_Prototext += "i:" + expandDim;
+ m_Prototext +=
+ " } \n"
+ " } \n"
+ "} \n";
+
+ SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims");
+ }
+};
+
+struct ExpandZeroDim : ExpandDimsFixture
+{
+ ExpandZeroDim() : ExpandDimsFixture("0") {}
+};
+
+struct ExpandTwoDim : ExpandDimsFixture
+{
+ ExpandTwoDim() : ExpandDimsFixture("2") {}
+};
+
+struct ExpandThreeDim : ExpandDimsFixture
+{
+ ExpandThreeDim() : ExpandDimsFixture("3") {}
+};
+
+struct ExpandMinusOneDim : ExpandDimsFixture
+{
+ ExpandMinusOneDim() : ExpandDimsFixture("-1") {}
+};
+
+struct ExpandMinusThreeDim : ExpandDimsFixture
+{
+ ExpandMinusThreeDim() : ExpandDimsFixture("-3") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+ armnn::TensorShape({1, 2, 3, 5})));
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+ armnn::TensorShape({2, 3, 1, 5})));
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+ armnn::TensorShape({2, 3, 5, 1})));
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+ armnn::TensorShape({2, 3, 5, 1})));
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
+ armnn::TensorShape({2, 1, 3, 5})));
+}
+
+BOOST_AUTO_TEST_SUITE_END()