diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 81 |
1 files changed, 67 insertions, 14 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 90bd992a2b..0087ef83bf 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -2,40 +2,27 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #include "TfParser.hpp" -#include <armnn/INetwork.hpp> -#include <armnn/Utils.hpp> #include <armnn/TypesUtils.hpp> -#include <armnn/Exceptions.hpp> #include <armnn/Descriptors.hpp> #include <GraphTopologicalSort.hpp> #include <ParserHelper.hpp> #include <Permute.hpp> -#include <VerificationHelpers.hpp> #include <DataLayoutIndexed.hpp> #include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/text_format.h> #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" -#include <boost/assert.hpp> #include <boost/format.hpp> #include <boost/core/ignore_unused.hpp> -#include <boost/log/trivial.hpp> -#include <boost/numeric/conversion/cast.hpp> #include <boost/polymorphic_cast.hpp> -#include <memory> -#include <sstream> #include <numeric> -#include <functional> using namespace armnnUtils; using namespace armnn; @@ -141,6 +128,17 @@ int32_t ReadMandatoryNodeInt32Attribute(const tensorflow::NodeDef& nodeDef, cons return attribValue; } +bool ReadMandatoryNodeBoolAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name) +{ + bool attribValue = false; + ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kB, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = static_cast<bool>(attrValue.b()); + }); + return attribValue; +} + uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name) { uint32_t attribValue = 0u; @@ -338,6 +336,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope { "ConcatV2", &TfParser::ParseConcat }, { "LRN", &TfParser::ParseLrn }, { "MatMul", &TfParser::ParseMatMul }, + { "Mean", &TfParser::ParseMean }, { "Mul", &TfParser::ParseMul }, { "Placeholder", &TfParser::ParsePlaceholder }, { "RealDiv", &TfParser::ParseRealDiv }, @@ -2349,6 +2348,60 @@ ParsedTfOperationPtr TfParser::ParseMatMul(const tensorflow::NodeDef& nodeDef, c return std::make_unique<ParsedMatMulTfOperation>(this, nodeDef); } +ParsedTfOperationPtr TfParser::ParseMean(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + + if (inputs.size() != 2) + { + throw ParseException( + boost::str(boost::format("Mean expects two inputs!. Got %1% for Node %2% %3%") + % inputs.size() + % nodeDef.name() + % CHECK_LOCATION().AsString())); + } + + bool keepDims = ReadMandatoryNodeBoolAttribute(nodeDef, "keep_dims"); + + ParsedConstTfOperation<int32_t>* axisNode = + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[1].m_IndexedValue); + + const TensorInfo& axisTensorInfo = axisNode->GetTensorInfo(); + + ConstTensor axisTensor(axisTensorInfo, axisNode->GetStorage()); + const int* axisData = static_cast<const int*>(axisTensor.GetMemoryArea()); + + TensorInfo outputTensorInfo; + MeanDescriptor meanDescriptor; + meanDescriptor.m_KeepDims = keepDims; + + // Negative axis values are supported so that the process requires + // to convert them into the corresponding positive ones. + // Duplicate values are also removed. + std::vector<int> rawAxisVector(axisData, axisData + axisTensorInfo.GetNumElements()); + std::set<unsigned int> positiveAxisSet; + int rank = static_cast<int>(inputTensorInfo.GetNumDimensions()); + + std::transform(rawAxisVector.begin(), rawAxisVector.end(), + std::inserter(positiveAxisSet, positiveAxisSet.begin()), + [rank](int i) -> unsigned int { return static_cast<unsigned int>((i + rank) % rank); }); + + CalculateReducedOutputTensoInfo(inputTensorInfo, axisTensorInfo, positiveAxisSet, keepDims, outputTensorInfo); + + if (inputTensorInfo.GetNumDimensions() > positiveAxisSet.size()) + { + meanDescriptor.m_Axis.assign(positiveAxisSet.begin(), positiveAxisSet.end()); + } + + IConnectableLayer* layer = m_Network->AddMeanLayer(meanDescriptor, nodeDef.name().c_str()); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + inputSlot.Connect(layer->GetInputSlot(0)); + + return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); +} + /// An ParsedTfOperation for a Mul node. /// Creation of the armnn Mul layer is deferred until it is actually needed, because Mul nodes /// are also used for the first part of a leaky relu activation function (Mul followed by Maximum) |