aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorConor Kennedy <conor.kennedy@arm.com>2018-12-05 11:05:54 +0000
committerAron Virginas-Tar <aron.virginas-tar@arm.com>2018-12-05 13:36:01 +0000
commitc2130a070e6a9196d193c93a02b5f118810dd59a (patch)
tree45994da8a86619e4ed942a6619df284954329a1f /src/armnnTfParser/TfParser.cpp
parentf6ba747c0802d87ba30aecd598f0603f9bd18576 (diff)
downloadarmnn-c2130a070e6a9196d193c93a02b5f118810dd59a.tar.gz
IVGCVSW-2193 ExpandDims operation implementation
* Add ExpandDims operation to TfParser.cpp Change-Id: Ifa756ae0667c11e3b6daec8f6dd4e54cac88d16a
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r--src/armnnTfParser/TfParser.cpp106
1 files changed, 106 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 4ddcdce1c7..53cdfa37a2 100644
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -159,6 +159,17 @@ float ReadMandatoryNodeFloatAttribute(const tensorflow::NodeDef& nodeDef, const
return attribValue;
}
+int32_t ReadMandatoryNodeInt32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
+{
+ int32_t attribValue = 0u;
+ ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kI,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = static_cast<int32_t>(attrValue.i());
+ });
+ return attribValue;
+}
+
uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
{
uint32_t attribValue = 0u;
@@ -349,6 +360,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "Identity", &TfParser::ParseIdentity },
{ "Conv2D", &TfParser::ParseConv2D },
{ "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D },
+ { "ExpandDims", &TfParser::ParseExpandDims },
{ "FusedBatchNorm", &TfParser::ParseFusedBatchNorm },
{ "ConcatV2", &TfParser::ParseConcat },
{ "LRN", &TfParser::ParseLrn },
@@ -1224,6 +1236,100 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo)
+{
+ BOOST_ASSERT(nodeDef.op() == "ExpandDims");
+
+ if (inputTensorInfo.GetNumDimensions() > 4) {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Unsupported number of dimensions: %1% for input shape for ExpandDims %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ std::int32_t expandDim = ReadMandatoryNodeInt32Attribute(nodeDef, "Tdim");
+
+ std::int32_t inputDimSize = boost::numeric_cast<int32_t>(inputTensorInfo.GetNumDimensions());
+ std::vector<uint32_t> outputDims;
+
+ // expandDim operation requires: -1-input.dims() <= dim <= input.dims()
+ if (expandDim >= -1 - inputDimSize && expandDim <= inputDimSize)
+ {
+ // add current input shape to outputDims
+ for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i) {
+ auto currentDimension = inputTensorInfo.GetShape()[i];
+ outputDims.push_back(currentDimension);
+ }
+
+ // insert a dimension of 1 at index 'expandDim' of inputs shape
+ if (expandDim >= 0)
+ {
+ auto getPosition = std::next(outputDims.begin() + 0, expandDim);
+ outputDims.insert(getPosition, 1);
+ }
+
+ // if negative number for 'expandDim' then count backwards from the last element
+ // and insert 1 dimension at index 'expandDim'
+ if (expandDim < 0)
+ {
+ auto outputDimSize = boost::numeric_cast<uint32_t>(outputDims.size() + 1);
+ auto getPosition = std::next(outputDims.begin() + outputDimSize, expandDim);
+ outputDims.insert(getPosition, 1);
+ }
+ }
+ else
+ {
+ throw InvalidArgumentException(
+ boost::str(
+ boost::format(
+ "Cannot expand dimension %1% in input tensor with %2% dimension %3%")
+ % expandDim
+ % inputDimSize
+ % CHECK_LOCATION().AsString()));
+ }
+
+ if (outputDims.size() > 4)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Unsupported number of dimensions: %1% for output shape for ExpandDims %2% %3%")
+ % outputDims.size()
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ TensorShape outShape = TensorShape(static_cast<unsigned int>(outputDims.size()),
+ outputDims.data());
+
+ TensorInfo outTensorInfo = inputTensorInfo;
+ outTensorInfo.SetShape(outShape);
+
+ return outTensorInfo;
+}
+
+ParsedTfOperationPtr TfParser::ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+
+ IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo();
+
+ TensorInfo outputInfo;
+ outputInfo = OutputShapeOfExpandDims(nodeDef, inputTensorInfo);
+
+ ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = outputInfo.GetShape();
+ IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, nodeDef.name().c_str());
+ prevLayerOutputSlot.Connect(layer->GetInputSlot(0));
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{