diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 106 |
1 files changed, 106 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 4ddcdce1c7..53cdfa37a2 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -159,6 +159,17 @@ float ReadMandatoryNodeFloatAttribute(const tensorflow::NodeDef& nodeDef, const return attribValue; } +int32_t ReadMandatoryNodeInt32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name) +{ + int32_t attribValue = 0u; + ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kI, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = static_cast<int32_t>(attrValue.i()); + }); + return attribValue; +} + uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name) { uint32_t attribValue = 0u; @@ -349,6 +360,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope { "Identity", &TfParser::ParseIdentity }, { "Conv2D", &TfParser::ParseConv2D }, { "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D }, + { "ExpandDims", &TfParser::ParseExpandDims }, { "FusedBatchNorm", &TfParser::ParseFusedBatchNorm }, { "ConcatV2", &TfParser::ParseConcat }, { "LRN", &TfParser::ParseLrn }, @@ -1224,6 +1236,100 @@ ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& n return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); } +TensorInfo OutputShapeOfExpandDims(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo) +{ + BOOST_ASSERT(nodeDef.op() == "ExpandDims"); + + if (inputTensorInfo.GetNumDimensions() > 4) { + throw ParseException( + boost::str( + boost::format( + "Unsupported number of dimensions: %1% for input shape for ExpandDims %2% %3%") + % inputTensorInfo.GetNumDimensions() + % nodeDef.name() + % CHECK_LOCATION().AsString())); + } + + std::int32_t expandDim = ReadMandatoryNodeInt32Attribute(nodeDef, "Tdim"); + + std::int32_t inputDimSize = boost::numeric_cast<int32_t>(inputTensorInfo.GetNumDimensions()); + std::vector<uint32_t> outputDims; + + // expandDim operation requires: -1-input.dims() <= dim <= input.dims() + if (expandDim >= -1 - inputDimSize && expandDim <= inputDimSize) + { + // add current input shape to outputDims + for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i) { + auto currentDimension = inputTensorInfo.GetShape()[i]; + outputDims.push_back(currentDimension); + } + + // insert a dimension of 1 at index 'expandDim' of inputs shape + if (expandDim >= 0) + { + auto getPosition = std::next(outputDims.begin() + 0, expandDim); + outputDims.insert(getPosition, 1); + } + + // if negative number for 'expandDim' then count backwards from the last element + // and insert 1 dimension at index 'expandDim' + if (expandDim < 0) + { + auto outputDimSize = boost::numeric_cast<uint32_t>(outputDims.size() + 1); + auto getPosition = std::next(outputDims.begin() + outputDimSize, expandDim); + outputDims.insert(getPosition, 1); + } + } + else + { + throw InvalidArgumentException( + boost::str( + boost::format( + "Cannot expand dimension %1% in input tensor with %2% dimension %3%") + % expandDim + % inputDimSize + % CHECK_LOCATION().AsString())); + } + + if (outputDims.size() > 4) + { + throw ParseException( + boost::str( + boost::format( + "Unsupported number of dimensions: %1% for output shape for ExpandDims %2% %3%") + % outputDims.size() + % nodeDef.name() + % CHECK_LOCATION().AsString())); + } + + TensorShape outShape = TensorShape(static_cast<unsigned int>(outputDims.size()), + outputDims.data()); + + TensorInfo outTensorInfo = inputTensorInfo; + outTensorInfo.SetShape(outShape); + + return outTensorInfo; +} + +ParsedTfOperationPtr TfParser::ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + + IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo(); + + TensorInfo outputInfo; + outputInfo = OutputShapeOfExpandDims(nodeDef, inputTensorInfo); + + ReshapeDescriptor reshapeDesc; + reshapeDesc.m_TargetShape = outputInfo.GetShape(); + IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, nodeDef.name().c_str()); + prevLayerOutputSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); +} + ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { |