aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorFerran Balaguer <ferran.balaguer@arm.com>2019-01-11 19:29:18 +0000
committerFerran Balaguer Arm <ferran.balaguer@arm.com>2019-01-14 09:45:15 +0000
commit51dd62f5725e8a97f3f6957fbc2b899493eb7bb3 (patch)
treef8cce612850d49d798686cce5ad2ab7545b6e0b7 /src/armnnTfParser/TfParser.cpp
parent992d6dc57d8463729910b688f0fb5825d0d3ccf2 (diff)
downloadarmnn-51dd62f5725e8a97f3f6957fbc2b899493eb7bb3.tar.gz
IVGCVSW-1656 Add Mean support to Tf Parser
Change-Id: I3d31d6b72be1984acdb51fd9e7b5488a7aa5d832
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp81
1 files changed, 67 insertions, 14 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 90bd992a2b..0087ef83bf 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -2,40 +2,27 @@
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#include "TfParser.hpp"
-#include <armnn/INetwork.hpp>
-#include <armnn/Utils.hpp>
#include <armnn/TypesUtils.hpp>
-#include <armnn/Exceptions.hpp>
#include <armnn/Descriptors.hpp>
#include <GraphTopologicalSort.hpp>
#include <ParserHelper.hpp>
#include <Permute.hpp>
-#include <VerificationHelpers.hpp>
#include <DataLayoutIndexed.hpp>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/framework/tensor.pb.h"
-#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include <boost/assert.hpp>
#include <boost/format.hpp>
#include <boost/core/ignore_unused.hpp>
-#include <boost/log/trivial.hpp>
-#include <boost/numeric/conversion/cast.hpp>
#include <boost/polymorphic_cast.hpp>
-#include <memory>
-#include <sstream>
#include <numeric>
-#include <functional>
using namespace armnnUtils;
using namespace armnn;
@@ -141,6 +128,17 @@ int32_t ReadMandatoryNodeInt32Attribute(const tensorflow::NodeDef& nodeDef, cons
return attribValue;
}
+bool ReadMandatoryNodeBoolAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
+{
+ bool attribValue = false;
+ ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kB,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = static_cast<bool>(attrValue.b());
+ });
+ return attribValue;
+}
+
uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
{
uint32_t attribValue = 0u;
@@ -338,6 +336,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "ConcatV2", &TfParser::ParseConcat },
{ "LRN", &TfParser::ParseLrn },
{ "MatMul", &TfParser::ParseMatMul },
+ { "Mean", &TfParser::ParseMean },
{ "Mul", &TfParser::ParseMul },
{ "Placeholder", &TfParser::ParsePlaceholder },
{ "RealDiv", &TfParser::ParseRealDiv },
@@ -2349,6 +2348,60 @@ ParsedTfOperationPtr TfParser::ParseMatMul(const tensorflow::NodeDef& nodeDef, c
return std::make_unique<ParsedMatMulTfOperation>(this, nodeDef);
}
+ParsedTfOperationPtr TfParser::ParseMean(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+ IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+
+ if (inputs.size() != 2)
+ {
+ throw ParseException(
+ boost::str(boost::format("Mean expects two inputs!. Got %1% for Node %2% %3%")
+ % inputs.size()
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ bool keepDims = ReadMandatoryNodeBoolAttribute(nodeDef, "keep_dims");
+
+ ParsedConstTfOperation<int32_t>* axisNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[1].m_IndexedValue);
+
+ const TensorInfo& axisTensorInfo = axisNode->GetTensorInfo();
+
+ ConstTensor axisTensor(axisTensorInfo, axisNode->GetStorage());
+ const int* axisData = static_cast<const int*>(axisTensor.GetMemoryArea());
+
+ TensorInfo outputTensorInfo;
+ MeanDescriptor meanDescriptor;
+ meanDescriptor.m_KeepDims = keepDims;
+
+ // Negative axis values are supported so that the process requires
+ // to convert them into the corresponding positive ones.
+ // Duplicate values are also removed.
+ std::vector<int> rawAxisVector(axisData, axisData + axisTensorInfo.GetNumElements());
+ std::set<unsigned int> positiveAxisSet;
+ int rank = static_cast<int>(inputTensorInfo.GetNumDimensions());
+
+ std::transform(rawAxisVector.begin(), rawAxisVector.end(),
+ std::inserter(positiveAxisSet, positiveAxisSet.begin()),
+ [rank](int i) -> unsigned int { return static_cast<unsigned int>((i + rank) % rank); });
+
+ CalculateReducedOutputTensoInfo(inputTensorInfo, axisTensorInfo, positiveAxisSet, keepDims, outputTensorInfo);
+
+ if (inputTensorInfo.GetNumDimensions() > positiveAxisSet.size())
+ {
+ meanDescriptor.m_Axis.assign(positiveAxisSet.begin(), positiveAxisSet.end());
+ }
+
+ IConnectableLayer* layer = m_Network->AddMeanLayer(meanDescriptor, nodeDef.name().c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ inputSlot.Connect(layer->GetInputSlot(0));
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
/// An ParsedTfOperation for a Mul node.
/// Creation of the armnn Mul layer is deferred until it is actually needed, because Mul nodes
/// are also used for the first part of a leaky relu activation function (Mul followed by Maximum)