Diffstat (limited to 'src/armnnTfParser')
-rw-r--r-- | src/armnnTfParser/README.md | 5
-rw-r--r-- | src/armnnTfParser/TensorFlowSupport.md | 111
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 2200
-rw-r--r-- | src/armnnTfParser/TfParser.hpp | 199
-rw-r--r-- | src/armnnTfParser/test/Activations.cpp | 113
-rw-r--r-- | src/armnnTfParser/test/Addition.cpp | 78
-rw-r--r-- | src/armnnTfParser/test/BiasAdd.cpp | 104
-rw-r--r-- | src/armnnTfParser/test/BroadcastForAdd.cpp | 149
-rw-r--r-- | src/armnnTfParser/test/Concat.cpp | 183
-rw-r--r-- | src/armnnTfParser/test/ConcatOfConcats.cpp | 316
-rw-r--r-- | src/armnnTfParser/test/Constant.cpp | 321
-rw-r--r-- | src/armnnTfParser/test/Convolution2d.cpp | 322
-rw-r--r-- | src/armnnTfParser/test/DepthwiseConvolution2d.cpp | 166
-rw-r--r-- | src/armnnTfParser/test/FullyConnected.cpp | 579
-rw-r--r-- | src/armnnTfParser/test/FusedBatchNorm.cpp | 175
-rw-r--r-- | src/armnnTfParser/test/Identity.cpp | 161
-rw-r--r-- | src/armnnTfParser/test/LocalResponseNormalization.cpp | 121
-rw-r--r-- | src/armnnTfParser/test/MultiOutput.cpp | 144
-rw-r--r-- | src/armnnTfParser/test/Multiplication.cpp | 172
-rw-r--r-- | src/armnnTfParser/test/PassThru.cpp | 52
-rw-r--r-- | src/armnnTfParser/test/Pooling.cpp | 112
-rw-r--r-- | src/armnnTfParser/test/Reshape.cpp | 86
-rw-r--r-- | src/armnnTfParser/test/ResizeBilinear.cpp | 114
-rw-r--r-- | src/armnnTfParser/test/Shape.cpp | 94
-rw-r--r-- | src/armnnTfParser/test/Softmax.cpp | 55
-rw-r--r-- | src/armnnTfParser/test/Squeeze.cpp | 108
-rw-r--r-- | src/armnnTfParser/test/TestDependencies.cpp | 296
-rw-r--r-- | src/armnnTfParser/test/TestMultiInputsOutputs.cpp | 92
28 files changed, 6628 insertions, 0 deletions
diff --git a/src/armnnTfParser/README.md b/src/armnnTfParser/README.md
new file mode 100644
index 0000000000..fe3f2b8950
--- /dev/null
+++ b/src/armnnTfParser/README.md
@@ -0,0 +1,5 @@
+# The Arm NN TensorFlow parser
+
+`armnnTfParser` is a library for loading neural networks defined by TensorFlow protobuf files into the Arm NN runtime.
+
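+A minimal usage sketch is shown below (the model path, input name, input shape and output name are placeholders, and `CreateNetworkFromBinaryFile` is assumed to be the loading entry point declared in `ITfParser.hpp`):
+
+```cpp
+#include "armnnTfParser/ITfParser.hpp"
+#include <armnn/INetwork.hpp>
+#include <armnn/Tensor.hpp>
+
+int main()
+{
+    // Create the TensorFlow parser.
+    armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+
+    // Load a frozen TensorFlow GraphDef (binary protobuf), giving the input tensor name
+    // and shape, and the name of the output tensor to be produced.
+    armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(
+        "model.pb",
+        { { "input", armnn::TensorShape({ 1, 224, 224, 3 }) } },
+        { "output" });
+
+    // The resulting INetwork can then be optimized and loaded into the Arm NN runtime.
+    return 0;
+}
+```
+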
+For more information about the TensorFlow operators that are supported, and the networks that have been tested, see [TensorFlowSupport.md](./TensorFlowSupport.md) \ No newline at end of file
diff --git a/src/armnnTfParser/TensorFlowSupport.md b/src/armnnTfParser/TensorFlowSupport.md
new file mode 100644
index 0000000000..d052a70d49
--- /dev/null
+++ b/src/armnnTfParser/TensorFlowSupport.md
@@ -0,0 +1,111 @@
+# TensorFlow operators that the Arm NN SDK supports
+
+This reference guide provides a list of TensorFlow operators the Arm NN SDK currently supports.
+
+The Arm NN SDK TensorFlow parser currently only supports fp32 operators.
+
+These are the TensorFlow operators that the Arm NN SDK currently supports:
+
+**avg_pool**
+
+See the TensorFlow [avg_pool documentation](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool) for more information.
+
+**bias_add**
+
+See the TensorFlow [bias_add documentation](https://www.tensorflow.org/api_docs/python/tf/nn/bias_add) for more information.
+
+**conv2d**
+
+See the TensorFlow [conv2d documentation](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d) for more information.
+
+**identity**
+
+See the TensorFlow [identity documentation](https://www.tensorflow.org/api_docs/python/tf/identity) for more information.
+
+**local_response_normalization**
+
+See the TensorFlow [local_response_normalization documentation](https://www.tensorflow.org/api_docs/python/tf/nn/local_response_normalization) for more information.
+
+**max_pool**
+
+See the TensorFlow [max_pool documentation](https://www.tensorflow.org/api_docs/python/tf/nn/max_pool) for more information.
+
+**relu**
+
+See the TensorFlow [relu documentation](https://www.tensorflow.org/api_docs/python/tf/nn/relu) for more information.
+
+**relu6**
+
+See the TensorFlow [relu6 documentation](https://www.tensorflow.org/api_docs/python/tf/nn/relu6) for more information.
+
+**shape**
+
+See the TensorFlow [shape documentation](https://www.tensorflow.org/api_docs/python/tf/shape) for more information.
+
+**sigmoid**
+
+See the TensorFlow [sigmoid documentation](https://www.tensorflow.org/api_docs/python/tf/sigmoid) for more information.
+
+**softplus**
+
+See the TensorFlow [softplus documentation](https://www.tensorflow.org/api_docs/python/tf/nn/softplus) for more information.
+
+**squeeze**
+
+See the TensorFlow [squeeze documentation](https://www.tensorflow.org/api_docs/python/tf/squeeze) for more information.
+
+**tanh**
+
+See the TensorFlow [tanh documentation](https://www.tensorflow.org/api_docs/python/tf/tanh) for more information.
+
+The Arm NN SDK TensorFlow parser currently partially supports:
+
+**add**
+
+The parser does not support all forms of [broadcast composition](https://www.tensorflow.org/performance/xla/broadcasting); only broadcasting of scalars and 1D tensors is supported. See the TensorFlow [add operator documentation](https://www.tensorflow.org/api_docs/python/tf/add) for more information.
+
+**depthwise_conv2d_native**
+
+The parser only supports a dilation rate of (1,1,1,1). See the TensorFlow [depthwise_conv2d_native documentation](https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d_native) for more information.
+
+**fused_batch_norm**
+
+The parser does not support training outputs. See the TensorFlow [fused_batch_norm documentation](https://www.tensorflow.org/api_docs/python/tf/nn/fused_batch_norm) for more information.
+
+**matmul**
+
+The parser only supports constant weights in a fully connected layer. See the TensorFlow [matmul documentation](https://www.tensorflow.org/api_docs/python/tf/matmul) for more information.
+
+**multiply**
+
+The parser does not support all forms of [broadcast composition](https://www.tensorflow.org/performance/xla/broadcasting); only broadcasting of scalars and 1D tensors is supported. See the TensorFlow [multiply documentation](https://www.tensorflow.org/api_docs/python/tf/multiply) for more information. Broadcasting is not supported on the NEON backend.
+
+**placeholder**
+
+The parser only supports the NHWC data format in the input layer. See the TensorFlow [placeholder documentation](https://www.tensorflow.org/api_docs/python/tf/placeholder) for more information.
+
+**reshape**
+
+The parser does not support reshaping to or from 4D. See the TensorFlow [reshape documentation](https://www.tensorflow.org/api_docs/python/tf/reshape) for more information.
+
+**resize_images**
+
+The parser only supports `ResizeMethod.BILINEAR`. See the TensorFlow [resize_images documentation](https://www.tensorflow.org/api_docs/python/tf/image/resize_images) for more information.
+
+**softmax**
+
+The parser only supports 2D inputs and does not support selecting the `softmax` dimension. See the TensorFlow [softmax documentation](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) for more information.
+
+
+
+Arm tests these operators with the following TensorFlow fp32 neural networks:
+
+* Cifar10.
+
+* Lenet.
+
+* mobilenet_v1_1.0_224. The Arm NN SDK only supports the non-quantized (non `_quant`) version of the network. See the [MobileNet_v1 documentation](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md) for more information on `_quant` networks.
+
+* inception_v3. The Arm NN SDK only supports the official inception_v3 transformed model, and only with GPU acceleration; NEON acceleration is not supported at the moment. See the TensorFlow documentation on [preparing models for mobile deployment](https://www.tensorflow.org/mobile/prepare_models) for more information on how to transform the inception_v3 network.
+
+More machine learning operators will be supported in future releases.
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
new file mode 100644
index 0000000000..7c8e01b112
--- /dev/null
+++ b/src/armnnTfParser/TfParser.cpp
@@ -0,0 +1,2200 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#include "TfParser.hpp"
+
+#include <armnn/INetwork.hpp>
+#include <armnn/Utils.hpp>
+#include <armnn/TypesUtils.hpp>
+#include <armnn/Exceptions.hpp>
+#include <armnn/Descriptors.hpp>
+
+#include <GraphTopologicalSort.hpp>
+#include <Permute.hpp>
+
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include <google/protobuf/text_format.h>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+
+#include <boost/assert.hpp>
+#include <boost/format.hpp>
+#include <boost/core/ignore_unused.hpp>
+#include <boost/log/trivial.hpp>
+#include <boost/numeric/conversion/cast.hpp>
+#include <boost/polymorphic_cast.hpp>
+
+#include <memory>
+#include <sstream>
+#include <numeric>
+#include <functional>
+
+using namespace armnn;
+
+namespace armnnTfParser
+{
+namespace
+{
+
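+// Dimension mappings used to swizzle tensors between TensorFlow's NHWC layout and Arm NN's NCHW layout.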
+const PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
+const PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
+
+IConnectableLayer* AddSwizzleLayer(INetwork& network, IOutputSlot& input, const PermutationVector& mapping,
+ const std::string& name)
+{
+ // Add swizzle layer
+ IConnectableLayer* const layer = network.AddPermuteLayer(mapping, name.c_str());
+
+ // Connect input to swizzle layer
+ input.Connect(layer->GetInputSlot(0));
+
+ // Setup swizzled output
+ const TensorInfo outInfo = armnnUtils::Permuted(input.GetTensorInfo(), mapping);
+ layer->GetOutputSlot(0).SetTensorInfo(outInfo);
+
+ return layer;
+}
+
+IConnectableLayer* SwizzleInDeswizzleOut(INetwork& network, IOutputSlot& input, IConnectableLayer& layer,
+ const std::string& name)
+{
+ // Add swizzle layer
+ IConnectableLayer* const swizzleLayer = AddSwizzleLayer(network, input, NHWCToArmNN, "swizzle_for-" + name);
+
+ // Connect swizzledInput to layer
+ swizzleLayer->GetOutputSlot(0).Connect(layer.GetInputSlot(0));
+
+ // Add deswizzle layer
+ IConnectableLayer* const deswizzleLayer = AddSwizzleLayer(network, layer.GetOutputSlot(0), ArmNNToNHWC,
+ "deswizzle_for-" + name);
+
+ return deswizzleLayer;
+}
+
+template <typename Callable>
+void ReadMandatoryNodeAttributeImpl(const tensorflow::NodeDef& nodeDef,
+ const std::string& attribName,
+ tensorflow::AttrValue::ValueCase expectedValueCase,
+ Callable callable)
+{
+ auto iter = nodeDef.attr().find(attribName);
+ if (iter != nodeDef.attr().end())
+ {
+ const auto& attrValue = iter->second;
+ if (attrValue.value_case() == expectedValueCase)
+ {
+ callable(attrValue);
+ }
+ else
+ {
+ throw ParseException(boost::str(boost::format(
+ "Attribute %1% of node %2% expected to have %3% as tensorflow::AttrValue::ValueCase, "
+ "but found %4% instead")
+ % attribName
+ % nodeDef.name()
+ % static_cast<int>(expectedValueCase)
+ % static_cast<int>(attrValue.value_case())));
+ }
+ }
+ else
+ {
+ throw ParseException(boost::str(boost::format("Could not find required attribute %1% in node %2%")
+ % attribName % nodeDef.name()));
+ }
+}
+
+template <typename Callable>
+void ReadOptionalNodeAttributeImpl(const tensorflow::NodeDef& nodeDef,
+ const std::string& attribName,
+ tensorflow::AttrValue::ValueCase expectedValueCase,
+ Callable callable)
+{
+ auto iter = nodeDef.attr().find(attribName);
+ if (iter != nodeDef.attr().end())
+ {
+ const auto& attrValue = iter->second;
+ if (attrValue.value_case() == expectedValueCase)
+ {
+ callable(attrValue);
+ }
+ else
+ {
+ throw ParseException(boost::str(boost::format(
+ "Attribute %1% of node %2% expected to have %3% as tensorflow::AttrValue::ValueCase, "
+ "but found %4% instead")
+ % attribName
+ % nodeDef.name()
+ % static_cast<int>(expectedValueCase)
+ % static_cast<int>(attrValue.value_case())));
+ }
+ }
+}
+
+float ReadMandatoryNodeFloatAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
+{
+ float attribValue = 0.0f;
+ ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kF,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = attrValue.f();
+ });
+ return attribValue;
+}
+
+uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
+{
+ uint32_t attribValue = 0u;
+ ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kI,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = static_cast<uint32_t>(attrValue.i());
+ });
+ return attribValue;
+}
+
+std::string ReadMandatoryNodeStringAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
+{
+ std::string attribValue = "";
+ ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kS,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = attrValue.s();
+ });
+ return attribValue;
+}
+
+std::vector<uint32_t> ReadMandatoryNodeUint32ListAttribute(const tensorflow::NodeDef& nodeDef,
+ const std::string& name)
+{
+ std::vector<uint32_t> attriList;
+ ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kList,
+ [&attriList](const tensorflow::AttrValue& attrValue)
+ {
+ for (int attriNum = 0; attriNum < attrValue.list().i_size(); ++attriNum)
+ {
+ attriList.push_back(static_cast<uint32_t>(attrValue.list().i(attriNum)));
+ }
+ });
+
+ return attriList;
+}
+
+std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const tensorflow::NodeDef& nodeDef,
+ const std::string& name)
+{
+ std::vector<uint32_t> attriList;
+ ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kList,
+ [&attriList](const tensorflow::AttrValue& attrValue)
+ {
+ for (int attriNum = 0; attriNum < attrValue.list().i_size(); ++attriNum)
+ {
+ attriList.push_back(static_cast<uint32_t>(attrValue.list().i(attriNum)));
+ }
+ });
+
+ return attriList;
+}
+
+bool ReadOptionalNodeBoolAttribute(const tensorflow::NodeDef& nodeDef,
+ const std::string& name,
+ bool defaultValue = false)
+{
+ bool attribValue = defaultValue;
+ ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kB,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = attrValue.b();
+ });
+ return attribValue;
+}
+
+tensorflow::DataType ReadMandatoryNodeTypeAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name)
+{
+ tensorflow::DataType attribValue = tensorflow::DT_INVALID;
+ ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kType,
+ [&attribValue](const tensorflow::AttrValue& attrValue)
+ {
+ attribValue = attrValue.type();
+ });
+ return attribValue;
+}
+
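+// Computes the TensorInfo produced by reshaping to targetDims; a single -1 entry is inferred from the input's element count.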
+TensorInfo PrepareReshape(const TensorInfo& input, const std::vector<int32_t>& targetDims)
+{
+ std::vector<unsigned int> outDims(targetDims.begin(), targetDims.end());
+ const auto stretchDim = std::find(targetDims.begin(), targetDims.end(), -1);
+
+ if (stretchDim != targetDims.end())
+ {
+ if (std::find(std::next(stretchDim), targetDims.end(), -1) != targetDims.end())
+ {
+ throw ParseException("At most one component of shape can be -1");
+ }
+
+ auto targetNumElements = boost::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(), targetDims.end(),
+ -1, std::multiplies<int32_t>()));
+ auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
+ outDims[stretchIndex] = input.GetNumElements() / targetNumElements;
+ }
+
+ TensorInfo reshapeInfo = input;
+ reshapeInfo.SetShape(TensorShape{ static_cast<unsigned int>(outDims.size()), outDims.data() });
+
+ return reshapeInfo;
+}
+
+// We need the input0Slot to guide the reshape for input1Slot
+IOutputSlot* BroadcastForAddandMul(IOutputSlot* input0Slot, IOutputSlot* input1Slot, bool isNHWC, INetwork& m_Network,
+ const tensorflow::NodeDef& nodeDef)
+{
+ const TensorInfo& input1Info = input1Slot->GetTensorInfo();
+ const TensorInfo inputTensorInfo = input0Slot->GetTensorInfo();
+ const unsigned int matchDim = inputTensorInfo.GetNumDimensions() - (isNHWC ? 1 : 3);
+ std::array<unsigned int, MaxNumOfTensorDimensions> reshapedDimensions;
+ std::fill_n(reshapedDimensions.begin(), inputTensorInfo.GetNumDimensions(), 1);
+ reshapedDimensions[matchDim] = input1Info.GetShape()[0];
+
+ armnn::TensorInfo reshapedInfo = input1Info;
+ reshapedInfo.SetShape(TensorShape{ inputTensorInfo.GetNumDimensions(), reshapedDimensions.data() });
+
+ const std::string reshapeLayerName = "reshape_for-" + nodeDef.name();
+ ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = reshapedInfo.GetShape();
+ IConnectableLayer* const reshapeLayer = m_Network.AddReshapeLayer(reshapeDesc, reshapeLayerName.c_str());
+
+ input1Slot->Connect(reshapeLayer->GetInputSlot(0));
+ reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedInfo);
+
+ input1Slot = &reshapeLayer->GetOutputSlot(0);
+
+ return input1Slot;
+}
+
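+// Splits a TensorFlow input name of the form "node:index" into the node name and output index (defaulting to 0).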
+OutputId ParseOutputId(const std::string & name)
+{
+ unsigned int outputNum = 0;
+ size_t colonPos = name.find_last_of(":");
+ if (colonPos != std::string::npos)
+ {
+ int n = std::stoi(name.substr(colonPos+1));
+ if (n<0 || n>100)
+ {
+ throw ParseException("Output tensor id is out of range for "+name);
+ }
+ outputNum = static_cast<unsigned int>(n);
+ }
+ return OutputId(name.substr(0,colonPos),outputNum);
+}
+
+} // namespace
+
+const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_OperationNameToParsingFunctions = {
+ { "Const", &TfParser::ParseConst },
+ { "Add", &TfParser::ParseAdd },
+ { "BiasAdd", &TfParser::ParseBiasAdd },
+ { "Identity", &TfParser::ParseIdentity },
+ { "Conv2D", &TfParser::ParseConv2D },
+ { "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D },
+ { "FusedBatchNorm", &TfParser::ParseFusedBatchNorm },
+ { "ConcatV2", &TfParser::ParseConcat },
+ { "LRN", &TfParser::ParseLrn },
+ { "MatMul", &TfParser::ParseMatMul },
+ { "Mul", &TfParser::ParseMul },
+ { "Placeholder", &TfParser::ParsePlaceholder },
+ { "Relu", &TfParser::ParseRelu },
+ { "Relu6", &TfParser::ParseRelu6 },
+ { "Reshape", &TfParser::ParseReshape },
+ { "ResizeBilinear", &TfParser::ParseResizeBilinear },
+ { "Shape", &TfParser::ParseShape },
+ { "Squeeze", &TfParser::ParseSqueeze },
+ { "Sigmoid", &TfParser::ParseSigmoid },
+ { "Softmax", &TfParser::ParseSoftmax },
+ { "Softplus", &TfParser::ParseSoftplus },
+ { "Tanh", &TfParser::ParseTanh },
+ { "MaxPool", &TfParser::ParseMaxPool },
+ { "AvgPool", &TfParser::ParseAvgPool },
+};
+
+ITfParser* ITfParser::CreateRaw()
+{
+ return new TfParser();
+}
+
+ITfParserPtr ITfParser::Create()
+{
+ return ITfParserPtr(CreateRaw(), &ITfParser::Destroy);
+}
+
+void ITfParser::Destroy(ITfParser* parser)
+{
+ delete parser;
+}
+
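+// Computes the front/back padding needed to implement TensorFlow's SAME padding scheme along a single dimension.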
+inline void CalculateSamePadding(uint32_t inputSize, uint32_t stride,
+ uint32_t filterSize, bool samePadding,
+ uint32_t* paddingFront, uint32_t* paddingBack) {
+ *paddingFront = 0;
+ *paddingBack = 0;
+
+ if (samePadding) {
+ uint32_t outputSize = (inputSize + stride - 1) / stride;
+ uint32_t temp = (outputSize - 1) * stride + filterSize;
+ if (temp > inputSize) {
+ *paddingFront = (temp - inputSize) / 2;
+ *paddingBack = (temp - inputSize) - *paddingFront;
+ }
+ }
+}
+
+void CalcPadding(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t& outPadHead, uint32_t& outPadTail,
+ bool samePadding)
+{
+ CalculateSamePadding(input, stride, kernel, samePadding, &outPadHead, &outPadTail);
+}
+
+/// An abstract base class which represents a single TensorFlow operation (node)
+/// that has been (potentially partially) converted to Armnn.
+/// It may not yet have been fully converted into actual Armnn layers.
+class ParsedTfOperation
+{
+public:
+ ParsedTfOperation(TfParser* parser, const tensorflow::NodeDef& node)
+ : m_Parser(parser)
+ , m_Node(node)
+ {
+ }
+
+ virtual ~ParsedTfOperation() {};
+
+ const tensorflow::NodeDef& GetNode() const { return m_Node; }
+
+ /// Gets the ArmNN IOutputSlot corresponding to the given output index of the Tensorflow operation.
+ /// This may result in the creation of Armnn layers if this was deferred (e.g. see ParsedConstTfOperation).
+ virtual IOutputSlot& ResolveArmnnOutputSlot(unsigned int tfOutputIndex) = 0;
+
+ /// If this operation is an Identity, this will follow it and return the 'parent' operation (recursively).
+ virtual ParsedTfOperation* ResolveIdentityOperations()
+ {
+ return this;
+ }
+
+protected:
+ TfParser* m_Parser;
+ const tensorflow::NodeDef& m_Node;
+};
+
+/// A ParsedTfOperation where the Armnn equivalent is a single layer,
+/// with output slots that correspond directly to the Tf node outputs.
+class SingleLayerParsedTfOperation : public ParsedTfOperation
+{
+public:
+ SingleLayerParsedTfOperation(TfParser* parser, const tensorflow::NodeDef& node, IConnectableLayer* layer)
+ : ParsedTfOperation(parser, node)
+ , m_Layer(layer)
+ {
+ }
+
+ IOutputSlot& ResolveArmnnOutputSlot(unsigned int tfOutputIndex) override
+ {
+ BOOST_ASSERT(m_Layer);
+ // Assume one-to-one mapping between Tf and armnn output slots.
+ unsigned int armnnOutputSlotIdx = tfOutputIndex;
+ if (armnnOutputSlotIdx >= m_Layer->GetNumOutputSlots())
+ {
+ throw ParseException(
+ boost::str(boost::format("The requested output slot #%1% "
+ "for %2% does not exist") % armnnOutputSlotIdx % m_Layer->GetName()));
+ }
+ return m_Layer->GetOutputSlot(armnnOutputSlotIdx);
+ }
+
+protected:
+ IConnectableLayer* m_Layer;
+};
+
+/// A SingleLayerParsedTfOperation for deferred layer creation
+class DeferredSingleLayerParsedTfOperation : public SingleLayerParsedTfOperation
+{
+public:
+ DeferredSingleLayerParsedTfOperation(TfParser* parser, const tensorflow::NodeDef& node)
+ : SingleLayerParsedTfOperation(parser, node, nullptr)
+ {
+ }
+
+ IOutputSlot& ResolveArmnnOutputSlot(unsigned int tfOutputIndex) override
+ {
+ if (!m_Layer)
+ {
+ CreateLayerDeferred();
+ }
+ return SingleLayerParsedTfOperation::ResolveArmnnOutputSlot(tfOutputIndex);
+ }
+
+private:
+ virtual void CreateLayerDeferred() = 0;
+};
+
+
+TfParser::TfParser()
+ : m_Network(nullptr, nullptr)
+{
+}
+
+
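+// Follows a chain of Identity nodes and returns the first non-Identity node that feeds it.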
+const tensorflow::NodeDef* TfParser::ResolveIdentityNode(const tensorflow::NodeDef* nodeDef)
+{
+ if (nodeDef->op() != "Identity")
+ {
+ return nodeDef;
+ }
+
+ if (nodeDef->input_size() != 1)
+ {
+ throw ParseException("Identity node does not have correct amount of inputs!");
+ }
+
+ auto it = m_NodesByName.find(nodeDef->input(0));
+ if (it != m_NodesByName.end())
+ {
+ const tensorflow::NodeDef* inputNode = it->second;
+ return ResolveIdentityNode(inputNode);
+ }
+ else
+ {
+ throw ParseException("Cannot find what the Identity node is linked to!");
+ }
+}
+
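+// Finds the NodeDefs connected as inputs to the given node, together with the output index used on each connection.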
+std::vector<OutputOfConstNodeDef>
+TfParser::GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const
+{
+ std::vector<OutputOfConstNodeDef> ret;
+
+ ret.reserve(boost::numeric_cast<size_t>(nodeDef.input_size()));
+ for (int j = 0; j < nodeDef.input_size(); ++j)
+ {
+ OutputId outputId = ParseOutputId(nodeDef.input(j));
+ auto inputIt = m_NodesByName.find(outputId.m_IndexedValue);
+ if (inputIt == m_NodesByName.end())
+ {
+ throw ParseException(
+ "Can't find node '" + nodeDef.input(j) +
+ "', which is listed as an input of '" + nodeDef.name() + "'");
+ }
+ ret.push_back(OutputOfConstNodeDef(inputIt->second,outputId.m_Index));
+ }
+
+ return ret;
+}
+
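+// Finds the already-parsed operations that feed the given node, checking that the expected number of inputs is present.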
+std::vector<OutputOfParsedTfOperation>
+TfParser::GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef,
+ std::size_t expectedNumInputs)
+{
+ // Fetch the tensorflow nodes connected as inputs and validate the size.
+ std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+ const std::size_t numInputs = nodes.size();
+ if (numInputs != expectedNumInputs)
+ {
+ throw ParseException(boost::str(boost::format("Unexpected number of inputs for node %1%. "
+ "Expected %2%, found %3%") % nodeDef.name() % expectedNumInputs % numInputs));
+ }
+ // Fetch the corresponding ParsedTfOperation operations
+ std::vector<OutputOfParsedTfOperation> result;
+ for (auto&& node : nodes)
+ {
+ auto it = m_ParsedTfOperations.find(node.m_IndexedValue->name());
+ if (it == m_ParsedTfOperations.end())
+ {
+ throw ParseException("Node with name '" + node.m_IndexedValue->name() + "' has not been parsed");
+ }
+ ParsedTfOperation* parsedOp = it->second.get();
+ // Transparently 'skip' any Identity operations. This simplifies the logic inside the ParseXXX() functions.
+ parsedOp = parsedOp->ResolveIdentityOperations();
+ result.push_back(OutputOfParsedTfOperation(parsedOp,node.m_Index));
+ }
+ return result;
+}
+
+ParsedTfOperationPtr TfParser::ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+
+ // If one of the inputs is a MatMul and the other is a const, then we handle both nodes together as FullyConnected
+ if (inputs[0].m_IndexedValue->GetNode().op() == "MatMul" &&
+ HasParsedConstTensor<float>(inputs[1].m_IndexedValue->GetNode().name()))
+ {
+ IConnectableLayer* layer =
+ AddFullyConnectedLayer(inputs[0].m_IndexedValue->GetNode(),
+ &nodeDef,nodeDef.name().c_str());
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+ }
+ else if (HasParsedConstTensor<float>(inputs[0].m_IndexedValue->GetNode().name()) &&
+ inputs[1].m_IndexedValue->GetNode().op() == "MatMul")
+ {
+ IConnectableLayer* layer =
+ AddFullyConnectedLayer(inputs[1].m_IndexedValue->GetNode(),
+ &nodeDef,nodeDef.name().c_str());
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+ }
+ else
+ {
+ // Otherwise it's just a regular addition
+ return AddAdditionLayer(nodeDef);
+ }
+}
+
+ParsedTfOperationPtr TfParser::ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ return AddAdditionLayer(nodeDef, true);
+}
+
+/// A ParsedTfOperation which forwards to another (used for Identity nodes).
+class ParsedIdentityTfOperation : public ParsedTfOperation
+{
+public:
+ ParsedIdentityTfOperation(TfParser* parser, const tensorflow::NodeDef& node, ParsedTfOperation* representative)
+ : ParsedTfOperation(parser, node)
+ , m_Representative(representative)
+ {
+ }
+
+ virtual IOutputSlot& ResolveArmnnOutputSlot(unsigned int tfOutputIndex) override
+ {
+ BOOST_ASSERT(m_Representative);
+ return m_Representative->ResolveArmnnOutputSlot(tfOutputIndex);
+ }
+
+ virtual ParsedTfOperation* ResolveIdentityOperations() override
+ {
+ return m_Representative->ResolveIdentityOperations();
+ }
+
+private:
+ ParsedTfOperation* m_Representative;
+};
+
+ParsedTfOperationPtr TfParser::ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+ // Any requests for the output slots of this node should be forwarded to the node connected as input.
+ return std::make_unique<ParsedIdentityTfOperation>(this, nodeDef, inputs[0].m_IndexedValue);
+}
+
+/// A ParsedTfOperation for a Const node.
+/// Creation of the armnn ConstLayer is deferred until it is actually needed, because Const nodes are mostly used
+/// for weight inputs to MatMul/Conv2D nodes and in these cases armnn doesn't need a ConstLayer.
+template <typename T>
+class ParsedConstTfOperation : public DeferredSingleLayerParsedTfOperation
+{
+public:
+ ParsedConstTfOperation(TfParser* parser, const tensorflow::NodeDef& node,
+ const T* tensorData, const TensorInfo& tensorInfo)
+ : DeferredSingleLayerParsedTfOperation(parser, node),
+ m_Storage(tensorData, tensorData + tensorInfo.GetNumElements()),
+ m_TensorInfo(tensorInfo)
+ {
+ BOOST_ASSERT(tensorInfo.GetDataType() == GetDataType<T>());
+ }
+
+ void CreateLayerDeferred() override
+ {
+ BOOST_ASSERT(m_Layer == nullptr);
+ m_Layer = m_Parser->m_Network->AddConstantLayer(ConstTensor(m_TensorInfo, m_Storage), m_Node.name().c_str());
+ m_Layer->GetOutputSlot(0).SetTensorInfo(m_TensorInfo);
+ }
+
+ ConstTensor GetConstTensor(bool swizzleForConvolutionWeights, std::vector<T>& outputTensorData) const
+ {
+ // Mappings from TensorFlow filter tensors to the ArmNN filter tensors.
+ // Tensorflow weights are [H, W, In, Out]
+ // ArmNN weights are [Out, In, H, W]
+ static const PermutationVector HWIOToOIHW = {2, 3, 1, 0};
+
+ const TensorInfo outInfo = swizzleForConvolutionWeights
+ ? armnnUtils::Permuted(m_TensorInfo, HWIOToOIHW)
+ : m_TensorInfo;
+
+ outputTensorData.resize(m_TensorInfo.GetNumElements());
+
+ // Copy or swizzle from the permanent storage into the storage the caller provided.
+ if (swizzleForConvolutionWeights)
+ {
+ armnnUtils::Permute(outInfo.GetShape(), HWIOToOIHW, m_Storage.data(), outputTensorData.data());
+ }
+ else
+ {
+ memcpy(outputTensorData.data(), m_Storage.data(), m_TensorInfo.GetNumBytes());
+ }
+ // Update the result to point to the user provided storage
+ ConstTensor constTensor(outInfo, outputTensorData);
+ return constTensor;
+ }
+
+private:
+ ///< Manages the lifetime of the tensor data.
+ std::vector<T> m_Storage;
+ ///< Describes the layout of the tensor and points to the data in m_Storage.
+ TensorInfo m_TensorInfo;
+};
+
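+// Maps a TensorFlow DataType onto the equivalent Arm NN DataType (only float32 and int32 are handled).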
+DataType ConvertTfTensorDataType(const tensorflow::DataType tfDataType)
+{
+ switch (tfDataType)
+ {
+ case tensorflow::DT_FLOAT:
+ return DataType::Float32;
+ break;
+ case tensorflow::DT_INT32:
+ return DataType::Signed32;
+ break;
+ default:
+ throw ParseException(boost::str(
+ boost::format("Unknown DataType %1% for node")
+ % tensorflow::DataType_Name(tfDataType)));
+ }
+}
+
+struct ParseTfTensorValueList
+{
+ template<typename DataType>
+ static void Parse(
+ const tensorflow::TensorProto& tfTensor,
+ unsigned int dstElements,
+ std::vector<int8_t>& outputData);
+
+ template <typename DataType>
+ static void ReadData(const void* srcData, unsigned int numSrcElements,
+ std::vector<int8_t>& dstData, unsigned int numDstElements)
+ {
+ // If there are no entries in the list, perform no action
+ if (numSrcElements == 0)
+ {
+ return;
+ }
+
+ // If no size was provided, use the length of the value list
+ if (numDstElements == 0)
+ {
+ numDstElements = numSrcElements;
+ }
+
+ // Allocate memory
+ dstData.resize(std::max(numSrcElements, numDstElements) * sizeof(DataType));
+
+ const DataType* srcTensor = reinterpret_cast<const DataType*>(srcData);
+ DataType* dstTensor = reinterpret_cast<DataType*>(dstData.data());
+
+ // Copy the value list entries into the destination
+ std::copy(srcTensor, srcTensor + numSrcElements, dstTensor);
+
+ if (numDstElements > numSrcElements)
+ {
+ // Use the last element in the list to fill the remaining entries
+ std::fill(dstTensor + numSrcElements, dstTensor + numDstElements, srcTensor[numSrcElements - 1]);
+ }
+ }
+
+};
+
+template <>
+void ParseTfTensorValueList::Parse<float>(const tensorflow::TensorProto& tfTensor,
+ unsigned int dstElements, std::vector<int8_t>& outputData)
+{
+ ReadData<float>(tfTensor.float_val().data(), static_cast<unsigned int>(tfTensor.float_val_size()),
+ outputData, dstElements);
+}
+
+template <>
+void ParseTfTensorValueList::Parse<int32_t>(const tensorflow::TensorProto& tfTensor,
+ unsigned int dstElements, std::vector<int8_t>& outputData)
+{
+ ReadData<int32_t>(tfTensor.int_val().data(), static_cast<unsigned int>(tfTensor.int_val_size()),
+ outputData, dstElements);
+}
+
+template <template<typename> class OperatorType, typename T = int8_t>
+struct MakeTfOperation
+{
+ template<typename DataType, class... Args>
+ inline static std::unique_ptr<OperatorType<DataType>> Parse(TfParser* parser, const tensorflow::NodeDef& node,
+ Args&&... args)
+ {
+ return std::make_unique<OperatorType<DataType>>(parser, node, std::forward<Args>(args)...);
+ }
+};
+
+template <>
+struct MakeTfOperation<ParsedConstTfOperation>
+{
+ template<typename DataType, class... Args>
+ inline static std::unique_ptr<ParsedConstTfOperation<DataType>> Parse(TfParser* parser,
+ const tensorflow::NodeDef& node, const std::vector<int8_t>& tensorData, const TensorInfo& tensorInfo)
+ {
+ return std::make_unique<ParsedConstTfOperation<DataType>>(parser, node,
+ reinterpret_cast<const DataType*>(tensorData.data()), tensorInfo);
+ }
+};
+
+template <class FuncType>
+struct InvokeParseFunction
+{
+ template<class ResType, class... Args>
+ inline static ResType Result(DataType dataType, Args&&... args)
+ {
+ if (dataType == DataType::Float32)
+ {
+ return FuncType::template Parse<float>(std::forward<Args>(args)...);
+ }
+ else if (dataType == DataType::Signed32)
+ {
+ return FuncType::template Parse<int32_t>(std::forward<Args>(args)...);
+ }
+
+ return ResType();
+ }
+
+ template<class... Args>
+ inline static void Result(DataType dataType, Args&&... args)
+ {
+ if (dataType == DataType::Float32)
+ {
+ FuncType::template Parse<float>(std::forward<Args>(args)...);
+ }
+ else if (dataType == DataType::Signed32)
+ {
+ FuncType::template Parse<int32_t>(std::forward<Args>(args)...);
+ }
+ }
+};
+
+ParsedTfOperationPtr TfParser::ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ BOOST_ASSERT(nodeDef.op() == "Const");
+
+ if (nodeDef.attr().count("value") == 0)
+ {
+ throw ParseException(boost::str(
+ boost::format("Value not found for Const node - %1%")
+ % nodeDef.name()));
+ }
+
+ const tensorflow::TensorProto& tfTensor = nodeDef.attr().at("value").tensor();
+ const tensorflow::TensorShapeProto& tfTensorShape = tfTensor.tensor_shape();
+ const tensorflow::DataType tfDataType = ReadMandatoryNodeTypeAttribute(nodeDef, "dtype");
+
+ const auto GetDimensionSize = [](auto& d) { return d.size(); };
+
+ std::vector<unsigned int> dimensionSizes;
+ std::transform(tfTensorShape.dim().begin(), tfTensorShape.dim().end(),
+ std::back_inserter(dimensionSizes), GetDimensionSize);
+
+ // Calculate number of elements
+ const DataType dataType = ConvertTfTensorDataType(tfDataType);
+ unsigned int numElements = 0U;
+
+ if (!dimensionSizes.empty())
+ {
+ numElements = std::accumulate(dimensionSizes.begin(), dimensionSizes.end(),
+ 1U, std::multiplies<unsigned int>());
+ }
+
+ std::vector<int8_t> tensorData;
+
+ // Get tensor data from the list of values attribute
+ if (tfTensor.tensor_content().empty())
+ {
+ InvokeParseFunction<ParseTfTensorValueList>::Result<void>(dataType, tfTensor, numElements, tensorData);
+
+ // If the tensor shape is not defined, but there is a value list, then interpret the data as a 1D
+ // tensor of the provided number of elements
+ if (numElements == 0)
+ {
+ const unsigned int tfNumElements = static_cast<unsigned int>(tensorData.size()) / GetDataTypeSize(dataType);
+ dimensionSizes.push_back(tfNumElements);
+ }
+ }
+ // Get tensor data from tensor content attribute
+ else
+ {
+ tensorData.assign(tfTensor.tensor_content().begin(), tfTensor.tensor_content().end());
+
+ // Check if a tensor shape is defined for the tensor content
+ if (numElements == 0)
+ {
+ throw ParseException(boost::str(
+ boost::format("No tensor shape found for Const node - %1%")
+ % nodeDef.name()));
+ }
+ }
+
+ // Const node requires at least a list of values or a content attribute
+ if (tensorData.empty())
+ {
+ throw ParseException(boost::str(
+ boost::format("No tensor data found for Const node - %1%")
+ % nodeDef.name()));
+ }
+
+ const TensorInfo tensorInfo(static_cast<unsigned int>(dimensionSizes.size()), dimensionSizes.data(), dataType);
+
+ // If we have a list of values, then the length of the list must be
+ // less than or equal to the number of elements implied by the shape argument
+ if (tensorData.size() > tensorInfo.GetNumBytes())
+ {
+ throw ParseException(boost::str(
+ boost::format("Number of elements (%1%) should be less than or equal \
+ to the number of elements implied by the shape argument (%2%) for Const node - %3%")
+ % (tensorData.size() / GetDataTypeSize(dataType))
+ % tensorInfo.GetNumElements()
+ % nodeDef.name()));
+ }
+
+ return InvokeParseFunction<MakeTfOperation<ParsedConstTfOperation>>::Result<ParsedTfOperationPtr>(
+ dataType, this, nodeDef, tensorData, tensorInfo);
+}
+
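+// Returns true if the named node has already been parsed as a ParsedConstTfOperation holding data of the given type.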
+template<typename Type>
+bool TfParser::HasParsedConstTensor(const std::string & nodeName) const
+{
+ auto it = m_ParsedTfOperations.find(nodeName);
+ if (it == m_ParsedTfOperations.end() ||
+ dynamic_cast<ParsedConstTfOperation<Type>*>(it->second.get()) == nullptr)
+ {
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+}
+
+ParsedTfOperationPtr TfParser::ParseConv2D(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+ IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+
+ if (!HasParsedConstTensor<float>(inputs[1].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports Convolution layers with constant weights");
+ }
+ ParsedConstTfOperation<float>* weightNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[1].m_IndexedValue);
+
+ std::string paddingString = ReadMandatoryNodeStringAttribute(nodeDef, "padding");
+ std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
+ std::vector<uint32_t> strides = ReadMandatoryNodeUint32ListAttribute(nodeDef, "strides");
+
+ // read the dilations, if present - only [1,1,1,1] (the default) is supported
+ std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(nodeDef, "dilations");
+ if (!dilations.empty())
+ {
+ for (auto dilation : dilations)
+ {
+ if (dilation != 1u)
+ {
+ throw ParseException("ArmNN only supports Convolution layers with dilations [1,1,1,1]");
+ }
+ }
+ }
+
+ Convolution2dDescriptor desc;
+ desc.m_BiasEnabled = false;
+
+ if (dataFormat == "NHWC")
+ {
+ desc.m_StrideX = strides[2];
+ desc.m_StrideY = strides[1];
+ // Swizzle input to supported memory layout
+ inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN);
+ }
+ else if (dataFormat == "NCHW")
+ {
+ desc.m_StrideX = strides[3];
+ desc.m_StrideY = strides[2];
+ }
+ else
+ {
+ throw ParseException("Unsupported data format passed for Conv2D. Only NHWC and NCHW supported");
+ }
+
+ uint32_t inputHeight = inputTensorInfo.GetShape()[2];
+ uint32_t inputWidth = inputTensorInfo.GetShape()[3];
+
+ std::vector<float> outputTensorData;
+
+ ConstTensor weightTensor = weightNode->GetConstTensor(true, outputTensorData);
+
+ uint32_t weightHeight = weightTensor.GetShape()[2];
+ uint32_t weightWidth = weightTensor.GetShape()[3];
+
+ bool padding = false;
+ TensorInfo outputInfo;
+ if (paddingString == "SAME")
+ {
+ padding = true;
+ outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
+ weightTensor.GetShape()[0],
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputHeight) /
+ static_cast<float>(desc.m_StrideY))),
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputWidth) /
+ static_cast<float>(desc.m_StrideX)))
+ }, DataType::Float32);
+ }
+ else if (paddingString == "VALID")
+ {
+ padding = false;
+ outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
+ weightTensor.GetShape()[0],
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputHeight - weightHeight + 1) /
+ static_cast<float>(desc.m_StrideY))),
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputWidth - weightWidth + 1) /
+ static_cast<float>(desc.m_StrideX)))
+ }, DataType::Float32);
+ }
+ else
+ {
+ throw ParseException("Only 'SAME' and 'VALID' padding supported");
+ }
+
+ CalcPadding(inputHeight, weightHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, padding);
+ CalcPadding(inputWidth, weightWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, padding);
+
+ IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc, weightTensor, nodeDef.name().c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ if (dataFormat == "NHWC")
+ {
+ layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
+ }
+ else
+ {
+ inputSlot.Connect(layer->GetInputSlot(0));
+ }
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+ IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+
+ if (!HasParsedConstTensor<float>(inputs[1].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports Depthwise Convolution layers with constant weights");
+ }
+ ParsedConstTfOperation<float>* weightNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[1].m_IndexedValue);
+
+
+ std::string paddingString = ReadMandatoryNodeStringAttribute(nodeDef, "padding");
+ std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
+ std::vector<uint32_t> strides = ReadMandatoryNodeUint32ListAttribute(nodeDef, "strides");
+
+ DepthwiseConvolution2dDescriptor desc;
+ desc.m_BiasEnabled = false;
+
+ if (dataFormat == "NHWC")
+ {
+ desc.m_StrideX = strides[2];
+ desc.m_StrideY = strides[1];
+ // Swizzle input to supported memory layout
+ inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN);
+ }
+ else if (dataFormat == "NCHW")
+ {
+ desc.m_StrideX = strides[3];
+ desc.m_StrideY = strides[2];
+ }
+ else
+ {
+ throw ParseException("Unsupported data format passed for DepthwiseConv2dNative. Only NHWC and NCHW supported");
+ }
+
+ uint32_t inputHeight = inputTensorInfo.GetShape()[2];
+ uint32_t inputWidth = inputTensorInfo.GetShape()[3];
+
+ std::vector<float> outputTensorData;
+
+ ConstTensor weightTensor = weightNode->GetConstTensor(true, outputTensorData);
+
+ uint32_t weightHeight = weightTensor.GetShape()[2];
+ uint32_t weightWidth = weightTensor.GetShape()[3];
+
+ bool padding = false;
+ TensorInfo outputInfo;
+ if (paddingString == "SAME")
+ {
+ padding = true;
+ outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
+ weightTensor.GetShape()[0] * weightTensor.GetShape()[1],
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputHeight) /
+ static_cast<float>(desc.m_StrideY))),
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputWidth) /
+ static_cast<float>(desc.m_StrideX)))
+ }, DataType::Float32);
+ }
+ else if (paddingString == "VALID")
+ {
+ padding = false;
+ outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
+ weightTensor.GetShape()[0] * weightTensor.GetShape()[1],
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputHeight - weightHeight + 1) /
+ static_cast<float>(desc.m_StrideY))),
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputWidth - weightWidth + 1) /
+ static_cast<float>(desc.m_StrideX)))
+ }, DataType::Float32);
+ }
+ else
+ {
+ throw ParseException("Only 'SAME' and 'VALID' padding supported");
+ }
+
+ CalcPadding(inputHeight, weightHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, padding);
+ CalcPadding(inputWidth, weightWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, padding);
+
+ IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(desc, weightTensor, nodeDef.name().c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ if (dataFormat == "NHWC")
+ {
+ layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
+ }
+ else
+ {
+ inputSlot.Connect(layer->GetInputSlot(0));
+ }
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 5);
+
+ if (!HasParsedConstTensor<float>(inputs[1].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports FusedBatchNormalization layers with constant scale");
+ }
+ ParsedConstTfOperation<float>* scaleNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[1].m_IndexedValue);
+
+ if (!HasParsedConstTensor<float>(inputs[2].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports FusedBatchNormalization layers with constant offset");
+ }
+ ParsedConstTfOperation<float>* offsetNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[2].m_IndexedValue);
+
+ if (!HasParsedConstTensor<float>(inputs[3].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports FusedBatchNormalization layers with constant mean");
+ }
+ ParsedConstTfOperation<float>* meanNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[3].m_IndexedValue);
+
+ if (!HasParsedConstTensor<float>(inputs[4].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports FusedBatchNormalization layers with constant variance");
+ }
+ ParsedConstTfOperation<float>* varianceNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<float> *>(inputs[4].m_IndexedValue);
+
+ // The descriptor only has the epsilon attribute
+ BatchNormalizationDescriptor desc;
+ desc.m_Eps = ReadMandatoryNodeFloatAttribute(nodeDef, "epsilon");
+
+ // data for the parsed tensor args (scale, offset, mean, variance) must be stored locally until the layer is added
+ std::vector<float> scaleTensorData;
+ ConstTensor scaleTensor = scaleNode->GetConstTensor(false, scaleTensorData);
+
+ std::vector<float> offsetTensorData;
+ ConstTensor offsetTensor = offsetNode->GetConstTensor(false, offsetTensorData);
+
+ std::vector<float> meanTensorData;
+ ConstTensor meanTensor = meanNode->GetConstTensor(false, meanTensorData);
+
+ std::vector<float> varianceTensorData;
+ ConstTensor varianceTensor = varianceNode->GetConstTensor(false, varianceTensorData);
+
+ IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc,
+ meanTensor,
+ varianceTensor,
+ offsetTensor,
+ scaleTensor,
+ nodeDef.name().c_str());
+
+ IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+
+ const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
+
+ if (dataFormat == "NHWC")
+ {
+ const TensorInfo outputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
+ }
+ else
+ {
+ layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo());
+ inputSlot.Connect(layer->GetInputSlot(0));
+ }
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+ // In TensorFlow, the last input of the Concat node is the axis for concatenation
+ unsigned int numInputs = static_cast<unsigned int>(nodes.size());
+ unsigned int numConcatView = numInputs - 1;
+
+ OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), MaxNumOfTensorDimensions);
+ std::vector<unsigned int>mergeDimSizes(MaxNumOfTensorDimensions, 0u);
+
+ unsigned int mergeDim = 0;
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
+
+ // The last input is the axis for concatenation
+ if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports Concat with constant axis");
+ }
+ ParsedConstTfOperation<int32_t>* shapeNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue);
+
+ std::vector<int32_t> axisTensorData;
+ ConstTensor axisTensor = shapeNode->GetConstTensor(false, axisTensorData);
+
+ // This concatDim indicates the data format: 3 is NHWC, 1 is NCHW
+ const unsigned int concatDimInput = static_cast<unsigned int>(axisTensorData[0]);
+
+ // Armnn supports concatenation along the channel dimension for the NHWC and NCHW data formats
+ if (concatDimInput == 0 || concatDimInput == 2)
+ {
+ throw ParseException("The dimension for concatenation is not supported by Armnn");
+ }
+
+ // This is the only concatDim we support in Armnn
+ const unsigned int concatDim = 1;
+ for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
+ {
+ IOutputSlot& inputSlot =
+ inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
+ TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+
+ if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
+ {
+ throw ParseException("The number of dimensions for input tensors of the concatenation op should be 4");
+ }
+
+ if (concatDimInput == 3)
+ {
+ inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
+ }
+
+ for (unsigned int dim = 0; dim < MaxNumOfTensorDimensions; ++dim)
+ {
+ mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
+ }
+
+ for (unsigned int j = 0; j < concatDim; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(viewIndex, j, 0);
+ }
+
+ concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim);
+ mergeDim += mergeDimSizes[concatDim];
+
+ for (unsigned int j = concatDim+1; j < MaxNumOfTensorDimensions; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(viewIndex, j, 0);
+ }
+ }
+
+ mergeDimSizes[concatDim] = mergeDim;
+ armnn::IConnectableLayer *layer = m_Network->AddMergerLayer(concatDescriptor, nodeDef.name().c_str());
+
+ layer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo(MaxNumOfTensorDimensions, mergeDimSizes.data(),
+ DataType::Float32));
+
+ for (unsigned int v = 0; v < numConcatView; ++v)
+ {
+ IOutputSlot& inputSlot = inputs[v].m_IndexedValue->ResolveArmnnOutputSlot(inputs[v].m_Index);
+ if (concatDimInput == 3)
+ {
+ IConnectableLayer* const swizzleLayer = AddSwizzleLayer(*m_Network, inputSlot, NHWCToArmNN,
+ "swizzle_for-" + nodeDef.name());
+ swizzleLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(v));
+ }
+ else
+ {
+ inputSlot.Connect(layer->GetInputSlot(v));
+ }
+ }
+
+ if (concatDimInput == 3)
+ {
+ IConnectableLayer* const deswizzleLayer = AddSwizzleLayer(*m_Network, layer->GetOutputSlot(0), ArmNNToNHWC,
+ "deswizzle_for-" + nodeDef.name());
+ layer = deswizzleLayer;
+ }
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseShape(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ // Note: The Shape layer is handled in a special way, because:
+ // 1. ArmNN doesn't support the int32 tensors that Shape outputs
+ // 2. ArmNN works with statically shaped tensors, which are known at parse time
+ // 3. Because of 1. and 2. we treat the output of Shape as a temporary const int32
+ // tensor, which may be used as an input to other ops, most likely a Reshape
+
+ const tensorflow::DataType tfDataType = ReadMandatoryNodeTypeAttribute(nodeDef, "out_type");
+ if (tfDataType != tensorflow::DT_INT32)
+ {
+ throw ParseException("Armnn only supports DT_INT32 as out_type");
+ }
+
+ const std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+ IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ const TensorInfo& prevLayerTensorInfo = prevLayerOutputSlot.GetTensorInfo();
+ unsigned int prevLayerDimensions = prevLayerTensorInfo.GetNumDimensions();
+
+ std::vector<int32_t> shapeTensorData;
+ shapeTensorData.reserve(prevLayerDimensions);
+
+ for (unsigned int i=0; i<prevLayerDimensions; ++i)
+ {
+ shapeTensorData.push_back(static_cast<int32_t>(prevLayerTensorInfo.GetShape()[i]));
+ }
+
+ TensorInfo shapeTensorInfo(1, &prevLayerDimensions, DataType::Signed32);
+
+ return std::make_unique<ParsedConstTfOperation<int32_t>>(this,
+ nodeDef,
+ &shapeTensorData[0],
+ shapeTensorInfo);
+}
+
+ParsedTfOperationPtr TfParser::ParseReshape(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+ ParsedTfOperation* inputNode = inputs[0].m_IndexedValue;
+
+ if (!HasParsedConstTensor<int32_t>(inputs[1].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports Reshape layers with constant shapes");
+ }
+ ParsedConstTfOperation<int32_t>* shapeNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[1].m_IndexedValue);
+
+ armnn::IOutputSlot& prevLayerOutputSlot = inputNode->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo();
+
+ std::vector<int32_t> shapeTensorData;
+ ConstTensor shapeTensor = shapeNode->GetConstTensor(false, shapeTensorData);
+ const TensorInfo outputTensorInfo = PrepareReshape(inputTensorInfo, shapeTensorData);
+
+ TensorShape targetShape = outputTensorInfo.GetShape();
+ ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = targetShape;
+
+ IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, nodeDef.name().c_str());
+ prevLayerOutputSlot.Connect(layer->GetInputSlot(0));
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseResizeBilinear(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+
+ if (!HasParsedConstTensor<int32_t>(inputs[1].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException("ArmNN only supports ResizeBilinear layers with constant sizes");
+ }
+ ParsedConstTfOperation<int32_t>* sizeNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[1].m_IndexedValue);
+
+ // Check the align_corners attribute is not set
+ if (ReadOptionalNodeBoolAttribute(nodeDef, "align_corners", false))
+ {
+ throw ParseException("ArmNN only supports ResizeBilinear layers with align_corners set to false");
+ }
+
+ // data for the parsed tensor args (size) must be stored locally
+ std::vector<int32_t> sizeTensorData;
+ ConstTensor sizeTensor = sizeNode->GetConstTensor(false, sizeTensorData);
+
+ // The descriptor only has target height and width attributes, which we get from the size tensor
+ ResizeBilinearDescriptor desc;
+ desc.m_TargetHeight = static_cast<uint32_t> (sizeTensorData[0]);
+ desc.m_TargetWidth = static_cast<uint32_t> (sizeTensorData[1]);
+
+ IConnectableLayer* layer = m_Network->AddResizeBilinearLayer(desc, nodeDef.name().c_str());
+
+ IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+ // The input shape is always in NHWC format and will be swizzled below; for now,
+ // get the batch and channels to make up the ArmNN output shape with the target size
+ unsigned int outBatch = inputTensorInfo.GetShape()[0];
+ unsigned int outChannels = inputTensorInfo.GetShape()[3];
+ unsigned int outHeight = desc.m_TargetHeight;
+ unsigned int outWidth = desc.m_TargetWidth;
+ TensorShape outShape({outBatch, outChannels, outHeight, outWidth});
+ // The output DataType is always Float32, regardless of the input DataType
+ const TensorInfo outputTensorInfo(outShape, armnn::DataType::Float32);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ // TensorFlow ResizeBilinear input is always in NHWC format, so add swizzle and deswizzle layers
+ layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
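+// Computes the output TensorInfo of a Squeeze node by removing the requested dimensions, or every dimension of size 1 if none are specified.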
+TensorInfo OutputShapeOfSqueeze(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo)
+{
+ BOOST_ASSERT(nodeDef.op() == "Squeeze");
+ tensorflow::DataType tfDataType = ReadMandatoryNodeTypeAttribute(nodeDef, "T");
+
+ DataType type;
+ if (tfDataType == tensorflow::DT_FLOAT)
+ {
+ type = DataType::Float32;
+ }
+ else if (tfDataType == tensorflow::DT_INT32)
+ {
+ type = DataType::Signed32;
+ }
+ else
+ {
+ throw ParseException(boost::str(
+ boost::format("Unsupported DataType %1% for Squeeze operation")
+ % tensorflow::DataType_Name(tfDataType)));
+ }
+
+ std::vector<uint32_t> squeezeDims = ReadOptionalNodeUint32ListAttribute(nodeDef, "squeeze_dims");
+ if (squeezeDims.empty())
+ {
+ for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
+ {
+ if (inputTensorInfo.GetShape()[i] == 1)
+ {
+ squeezeDims.push_back(i);
+ }
+ }
+ }
+
+ std::vector<uint32_t> outputDims;
+ for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
+ {
+ bool includeDimension = (std::find(squeezeDims.begin(), squeezeDims.end(), i) == squeezeDims.end());
+ if (includeDimension)
+ {
+ outputDims.push_back(inputTensorInfo.GetShape()[i]);
+ }
+ }
+
+ if (outputDims.size() > 4)
+ {
+ throw ParseException("Unsupported shape for Squeeze");
+ }
+
+ TensorInfo outTensorInfo = TensorInfo(boost::numeric_cast<unsigned int>(outputDims.size()),
+ outputDims.data(),
+ type);
+
+ return outTensorInfo;
+}
+
+ParsedTfOperationPtr TfParser::ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+
+ IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo();
+
+ TensorInfo outputInfo;
+ outputInfo = OutputShapeOfSqueeze(nodeDef, inputTensorInfo);
+
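+ // ArmNN has no dedicated Squeeze layer, so the Squeeze is implemented as a Reshape to the squeezed shape.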
+ ReshapeDescriptor reshapeDesc;
+ reshapeDesc.m_TargetShape = outputInfo.GetShape();
+ IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, nodeDef.name().c_str());
+ prevLayerOutputSlot.Connect(layer->GetInputSlot(0));
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+
+ NormalizationDescriptor normalizationDescriptor;
+ normalizationDescriptor.m_NormMethodType = NormalizationAlgorithmMethod::LocalBrightness;
+ normalizationDescriptor.m_NormChannelType = NormalizationAlgorithmChannel::Across;
+ normalizationDescriptor.m_Alpha = ReadMandatoryNodeFloatAttribute(nodeDef, "alpha");
+ normalizationDescriptor.m_Beta = ReadMandatoryNodeFloatAttribute(nodeDef, "beta");
+ normalizationDescriptor.m_K = ReadMandatoryNodeFloatAttribute(nodeDef, "bias");
+ normalizationDescriptor.m_NormSize = ReadMandatoryNodeUint32Attribute(nodeDef, "depth_radius");
+
+ // The window size must be an odd value. For a window size of (2 * n + 1), TensorFlow defines depth_radius = n.
+ normalizationDescriptor.m_NormSize = normalizationDescriptor.m_NormSize * 2 + 1;
+
+ IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+
+ IConnectableLayer* layer = m_Network->AddNormalizationLayer(normalizationDescriptor,
+ nodeDef.name().c_str());
+
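+ // The normalization layer runs in the ArmNN (NCHW) layout, so its output tensor info is the
+ // input info permuted from NHWC.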
+ const TensorInfo permutedInfo = armnnUtils::Permuted(prevLayerOutputSlot.GetTensorInfo(), NHWCToArmNN);
+ layer->GetOutputSlot(0).SetTensorInfo(permutedInfo);
+
+ layer = SwizzleInDeswizzleOut(*m_Network, prevLayerOutputSlot, *layer, nodeDef.name());
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+/// A ParsedTfOperation for a MatMul node.
+/// Creation of the armnn FullyConnected layer is deferred until it is actually needed, because MatMul nodes are
+/// often used for the first part of a biased FullyConnected (MatMul followed by Add) and in these cases armnn doesn't
+/// need a separate layer for the MatMul.
+class ParsedMatMulTfOperation : public DeferredSingleLayerParsedTfOperation
+{
+public:
+ ParsedMatMulTfOperation(TfParser* parser, const tensorflow::NodeDef& node)
+ : DeferredSingleLayerParsedTfOperation(parser, node)
+ {
+ }
+
+ void CreateLayerDeferred() override
+ {
+ BOOST_ASSERT(m_Layer == nullptr);
+ m_Layer = m_Parser->AddFullyConnectedLayer(m_Node, nullptr, m_Node.name().c_str());
+ }
+};
+
+ParsedTfOperationPtr TfParser::ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ // Defer the creation of the layer (see ParsedMatMulTfOperation).
+ return std::make_unique<ParsedMatMulTfOperation>(this, nodeDef);
+}
+
+ParsedTfOperationPtr TfParser::ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+
+ IConnectableLayer* const layer = m_Network->AddMultiplicationLayer(nodeDef.name().c_str());
+ IOutputSlot* input0Slot = &inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ IOutputSlot* input1Slot = &inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index);
+
+ auto const input0NumDims = input0Slot->GetTensorInfo().GetNumDimensions();
+ auto const input1NumDims = input1Slot->GetTensorInfo().GetNumDimensions();
+
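+ // If the two inputs have different ranks, expand the lower-rank input so that it can be
+ // broadcast against the other one (the data is assumed to be in NHWC order).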
+ if (input0NumDims < input1NumDims)
+ {
+ const bool isNHWC = true;
+ input0Slot = BroadcastForAddandMul(input1Slot, input0Slot, isNHWC, *m_Network, nodeDef);
+ }
+ if (input1NumDims < input0NumDims)
+ {
+ const bool isNHWC = true;
+ input1Slot = BroadcastForAddandMul(input0Slot, input1Slot, isNHWC, *m_Network, nodeDef);
+ }
+
+ input0Slot->Connect(layer->GetInputSlot(0));
+ input1Slot->Connect(layer->GetInputSlot(1));
+
+ if (input0NumDims < input1NumDims)
+ {
+ layer->GetOutputSlot(0).SetTensorInfo(input1Slot->GetTensorInfo());
+ }
+ else
+ {
+ layer->GetOutputSlot(0).SetTensorInfo(input0Slot->GetTensorInfo());
+ }
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParsePlaceholder(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 0);
+
+ const LayerBindingId layerId = boost::numeric_cast<LayerBindingId>(m_NetworkInputsBindingInfo.size());
+
+ auto it = m_InputShapes.find(nodeDef.name());
+ if (it == m_InputShapes.end())
+ {
+ throw ParseException("Missing input shape for Placeholder '" + nodeDef.name() + "'");
+ }
+ TensorInfo tensorInfo(it->second, DataType::Float32);
+
+ IConnectableLayer* const layer = m_Network->AddInputLayer(layerId, nodeDef.name().c_str());
+
+ layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+ TrackInputBinding(layer, layerId, tensorInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseRelu(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ ActivationDescriptor activationDesc;
+ activationDesc.m_Function = ActivationFunction::ReLu;
+ return AddActivationLayer(nodeDef, activationDesc);
+}
+
+ParsedTfOperationPtr TfParser::ParseRelu6(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ ActivationDescriptor activationDesc;
+ activationDesc.m_Function = ActivationFunction::BoundedReLu;
+ activationDesc.m_A = 6.0f;
+ activationDesc.m_B = 0.0f;
+
+ return AddActivationLayer(nodeDef, activationDesc);
+}
+
+ParsedTfOperationPtr TfParser::ParseSigmoid(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ ActivationDescriptor activationDesc;
+ activationDesc.m_Function = ActivationFunction::Sigmoid;
+
+ return AddActivationLayer(nodeDef, activationDesc);
+}
+
+ParsedTfOperationPtr TfParser::ParseSoftmax(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+
+ SoftmaxDescriptor softmaxDescriptor;
+ IConnectableLayer* const layer = m_Network->AddSoftmaxLayer(softmaxDescriptor, nodeDef.name().c_str());
+
+ IOutputSlot& prevLayerSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ prevLayerSlot.Connect(layer->GetInputSlot(0));
+ layer->GetOutputSlot(0).SetTensorInfo(prevLayerSlot.GetTensorInfo());
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ ActivationDescriptor activationDesc;
+ activationDesc.m_Function = ActivationFunction::SoftReLu;
+
+ return AddActivationLayer(nodeDef, activationDesc);
+}
+
+ParsedTfOperationPtr TfParser::ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ ActivationDescriptor activationDesc;
+ activationDesc.m_Function = ActivationFunction::TanH;
+ activationDesc.m_A = 1.0f;
+ activationDesc.m_B = 1.0f;
+
+ return AddActivationLayer(nodeDef, activationDesc);
+}
+
+ParsedTfOperationPtr TfParser::AddActivationLayer(const tensorflow::NodeDef& nodeDef,
+ ActivationDescriptor& activationDesc)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+
+ IConnectableLayer* const layer = m_Network->AddActivationLayer(activationDesc, nodeDef.name().c_str());
+
+ IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ prevLayerOutputSlot.Connect(layer->GetInputSlot(0));
+ layer->GetOutputSlot(0).SetTensorInfo(prevLayerOutputSlot.GetTensorInfo());
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::ParseMaxPool(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ return ParsePooling2d(nodeDef, graphDef, PoolingAlgorithm::Max);
+}
+
+ParsedTfOperationPtr TfParser::ParseAvgPool(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ return ParsePooling2d(nodeDef, graphDef, PoolingAlgorithm::Average);
+}
+
+ParsedTfOperationPtr TfParser::ParsePooling2d(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef, PoolingAlgorithm pooltype)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 1);
+
+ if (inputs.size() != 1)
+ {
+ throw ParseException("2D Pooling expects one input!");
+ }
+
+ IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+
+ std::string paddingString = ReadMandatoryNodeStringAttribute(nodeDef, "padding");
+ std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
+ std::vector<uint32_t> strides = ReadMandatoryNodeUint32ListAttribute(nodeDef, "strides");
+ std::vector<uint32_t> ksize = ReadMandatoryNodeUint32ListAttribute(nodeDef, "ksize"); // size of pool windows
+
+ Pooling2dDescriptor pooling2dDescriptor;
+ pooling2dDescriptor.m_PoolType = pooltype;
+ pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude;
+ pooling2dDescriptor.m_OutputShapeRounding = OutputShapeRounding::Floor;
+
+ if (dataFormat == "NHWC")
+ {
+ pooling2dDescriptor.m_StrideX = strides[2];
+ pooling2dDescriptor.m_StrideY = strides[1];
+ pooling2dDescriptor.m_PoolWidth = ksize[2];
+ pooling2dDescriptor.m_PoolHeight = ksize[1];
+ // Swizzle input to supported memory layout
+ inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN);
+ }
+ else if (dataFormat == "NCHW")
+ {
+ pooling2dDescriptor.m_StrideX = strides[3];
+ pooling2dDescriptor.m_StrideY = strides[2];
+ pooling2dDescriptor.m_PoolWidth = ksize[3];
+ pooling2dDescriptor.m_PoolHeight = ksize[2];
+ }
+ else
+ {
+ throw ParseException("Only NHWC or NCHW supported for Pooling2d");
+ }
+
+ uint32_t inputHeight = inputTensorInfo.GetShape()[2];
+ uint32_t inputWidth = inputTensorInfo.GetShape()[3];
+
+ bool padding = false;
+ TensorInfo outputInfo;
+ if (paddingString == "SAME")
+ {
+ padding = true;
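+ // With SAME padding the output spatial size is ceil(inputSize / stride).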
+ outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
+ inputTensorInfo.GetShape()[1],
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputHeight) /
+ static_cast<float>(pooling2dDescriptor.m_StrideY))),
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputWidth) /
+ static_cast<float>(pooling2dDescriptor.m_StrideX)))
+ }, DataType::Float32);
+ }
+ else if (paddingString == "VALID")
+ {
+ padding = false;
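+ // With VALID padding the output spatial size is ceil((inputSize - poolSize + 1) / stride).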
+ outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0],
+ inputTensorInfo.GetShape()[1],
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputHeight - pooling2dDescriptor.m_PoolHeight + 1) /
+ static_cast<float>(pooling2dDescriptor.m_StrideY))),
+ static_cast<uint32_t>(ceil(
+ static_cast<float>(inputWidth - pooling2dDescriptor.m_PoolWidth + 1) /
+ static_cast<float>(pooling2dDescriptor.m_StrideX)))
+ }, DataType::Float32);
+ }
+ else
+ {
+ throw ParseException("Only 'SAME' and 'VALID' padding supported");
+ }
+
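+ // Work out the explicit left/right and top/bottom padding values implied by SAME or VALID.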
+ CalcPadding(inputWidth, pooling2dDescriptor.m_PoolWidth, pooling2dDescriptor.m_StrideX,
+ pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding);
+ CalcPadding(inputHeight, pooling2dDescriptor.m_PoolHeight, pooling2dDescriptor.m_StrideY,
+ pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding);
+
+
+ IConnectableLayer* layer = m_Network->AddPooling2dLayer(pooling2dDescriptor, nodeDef.name().c_str());
+ if (layer == nullptr)
+ {
+ throw ParseException("Failed to add pooling2d layer");
+ }
+
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ if (dataFormat == "NHWC")
+ {
+ layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name());
+ }
+ else
+ {
+ inputSlot.Connect(layer->GetInputSlot(0));
+ }
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+ParsedTfOperationPtr TfParser::AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd)
+{
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+
+ IOutputSlot* input0Slot = &inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ IOutputSlot* input1Slot = &inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index);
+
+ const TensorInfo& input0Info = input0Slot->GetTensorInfo();
+ const TensorInfo& input1Info = input1Slot->GetTensorInfo();
+
+ if (isBiasAdd)
+ {
+ // BiasAdd takes bias as a 1D tensor. We need to add a reshape layer to create a 4D tensor
+ // with the same data in the correct dimension for broadcast in addition.
+ if(input1Info.GetNumDimensions() != 1)
+ {
+ throw ParseException("Unsupported bias for BiasAdd. It should be a 1D vector.");
+ }
+
+ const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format");
+ const bool isNHWC = (dataFormat == "NHWC");
+ const bool isNCHW = (dataFormat == "NCHW");
+
+ if (!isNHWC && !isNCHW)
+ {
+ throw ParseException("Only NHWC or NCHW supported for BiasAdd");
+ }
+
+ input1Slot = BroadcastForAddandMul(input0Slot, input1Slot, isNHWC, *m_Network, nodeDef);
+ }
+ else
+ {
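+ // For a plain Add, a 1D input is broadcast against the other (higher-rank) input,
+ // assuming NHWC data.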
+ if (input0Info.GetNumDimensions() == 1)
+ {
+ const bool isNHWC = true;
+ input0Slot = BroadcastForAddandMul(input1Slot, input0Slot, isNHWC, *m_Network, nodeDef);
+ }
+
+ if (input1Info.GetNumDimensions() == 1)
+ {
+ const bool isNHWC = true;
+ input1Slot = BroadcastForAddandMul(input0Slot, input1Slot, isNHWC, *m_Network, nodeDef);
+ }
+ }
+
+ IConnectableLayer* const layer = m_Network->AddAdditionLayer(nodeDef.name().c_str());
+
+ input0Slot->Connect(layer->GetInputSlot(0));
+ input1Slot->Connect(layer->GetInputSlot(1));
+
+ if (input0Info.GetNumDimensions() == 1 && isBiasAdd == false)
+ {
+ layer->GetOutputSlot(0).SetTensorInfo(input1Slot->GetTensorInfo());
+ }
+ else
+ {
+ layer->GetOutputSlot(0).SetTensorInfo(input0Slot->GetTensorInfo());
+ }
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
+IConnectableLayer* TfParser::AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef,
+ const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName)
+{
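+ // Finds the weight (and optional bias) Const nodes feeding the MatMul/Add pair and builds
+ // a single ArmNN FullyConnected layer from them.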
+ // find bias const (if applicable)
+ ParsedConstTfOperation<float>* biasNode = nullptr;
+ if (addNodeDef != nullptr)
+ {
+ std::vector<OutputOfParsedTfOperation> addInputs = GetInputParsedTfOperationsChecked(*addNodeDef, 2);
+ // find our inputs
+ if (HasParsedConstTensor<float>(addInputs[0].m_IndexedValue->GetNode().name()))
+ {
+ biasNode = boost::polymorphic_downcast<ParsedConstTfOperation<float>*>(addInputs[0].m_IndexedValue);
+ }
+ else if (HasParsedConstTensor<float>(addInputs[1].m_IndexedValue->GetNode().name()))
+ {
+ biasNode = boost::polymorphic_downcast<ParsedConstTfOperation<float>*>(addInputs[1].m_IndexedValue);
+ }
+ else
+ {
+ throw ParseException("ArmNN only supports fully connected layers with constant bias");
+ }
+ }
+
+ // find matmul inputs
+ ParsedConstTfOperation<float>* weightNode = nullptr;
+ ParsedTfOperation* inputNode = nullptr;
+ unsigned int inputIdx = 0;
+ std::vector<OutputOfParsedTfOperation> mulInputs = GetInputParsedTfOperationsChecked(matMulNodeDef, 2);
+ if (HasParsedConstTensor<float>(mulInputs[0].m_IndexedValue->GetNode().name()))
+ {
+ weightNode = boost::polymorphic_downcast<ParsedConstTfOperation<float>*>(mulInputs[0].m_IndexedValue);
+ inputNode = mulInputs[1].m_IndexedValue;
+ inputIdx = mulInputs[1].m_Index;
+ }
+ else if (HasParsedConstTensor<float>(mulInputs[1].m_IndexedValue->GetNode().name()))
+ {
+ weightNode = boost::polymorphic_downcast<ParsedConstTfOperation<float>*>(mulInputs[1].m_IndexedValue);
+ inputNode = mulInputs[0].m_IndexedValue;
+ inputIdx = mulInputs[0].m_Index;
+ }
+ else
+ {
+ throw ParseException("ArmNN only supports fully connected layers with constant weights");
+ }
+
+ std::vector<float> weightTensorData;
+ // handle weight
+ ConstTensor weights = weightNode->GetConstTensor(false, weightTensorData);
+
+ FullyConnectedDescriptor desc;
+ desc.m_BiasEnabled = addNodeDef != nullptr;
+
+ IConnectableLayer* layer = nullptr;
+ // make the layer
+ if (addNodeDef != nullptr)
+ {
+ std::vector<float> biasTensorData;
+ ConstTensor biases = biasNode->GetConstTensor(false, biasTensorData);
+
+ if (weights.GetShape()[1] != biases.GetShape()[0])
+ {
+ throw ParseException("shape of matmul and bias do not match");
+ }
+
+ layer = m_Network->AddFullyConnectedLayer(desc, weights, biases, armnnLayerName);
+ }
+ else
+ {
+ layer = m_Network->AddFullyConnectedLayer(desc, weights, armnnLayerName);
+ }
+
+ BOOST_ASSERT(layer != nullptr);
+
+ inputNode->ResolveArmnnOutputSlot(inputIdx).Connect(layer->GetInputSlot(0));
+ unsigned int batches = inputNode->ResolveArmnnOutputSlot(inputIdx).GetTensorInfo().GetShape()[0];
+
+ // handle output
+ TensorInfo outputInfo({ batches, weights.GetShape()[1] }, DataType::Float32);
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+ return layer;
+}
+
+void TfParser::LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ // get the type of the node (assume float)
+ tensorflow::DataType type = tensorflow::DT_FLOAT;
+ if (nodeDef.attr().count("T") != 0)
+ {
+ auto attr = nodeDef.attr().at("T");
+ type = attr.type();
+ }
+ else if (nodeDef.attr().count("dtype") != 0)
+ {
+ auto attr = nodeDef.attr().at("dtype");
+ type = attr.type();
+ }
+
+ if (type != tensorflow::DT_FLOAT && nodeDef.op() != "Const")
+ {
+ throw ParseException("Currently only FLOAT is supported for tensorflow nodes (apart from Const)");
+ }
+
+ const std::string& operation = nodeDef.op();
+ auto it = ms_OperationNameToParsingFunctions.find(operation);
+ if (it != ms_OperationNameToParsingFunctions.end())
+ {
+ auto func = it->second;
+ ParsedTfOperationPtr parsedTfOperation = (this->*func)(nodeDef, graphDef);
+ ParsedTfOperation* parsedTfOperationRaw = parsedTfOperation.get();
+
+ // Store the parsed operation so that dependent layers can connect to it
+ auto it = m_ParsedTfOperations.find(nodeDef.name());
+ if (it != m_ParsedTfOperations.end())
+ {
+ throw ParseException(boost::str(boost::format("Name %1% used by more than one node") % nodeDef.name()));
+ }
+ m_ParsedTfOperations[nodeDef.name()] = std::move(parsedTfOperation);
+
+ // If this node was requested as an output from the network then add an ArmNN output layer
+ if (std::find(m_RequestedOutputs.begin(), m_RequestedOutputs.end(), nodeDef.name()) !=
+ m_RequestedOutputs.end())
+ {
+ auto outId = ParseOutputId(nodeDef.name());
+ const LayerBindingId layerId = boost::numeric_cast<LayerBindingId>(m_NetworkOutputsBindingInfo.size());
+ IOutputSlot& prevSlot = parsedTfOperationRaw->ResolveArmnnOutputSlot(outId.m_Index);
+
+ TensorInfo tensorInfo = prevSlot.GetTensorInfo();
+
+ IConnectableLayer* outputLayer = m_Network->AddOutputLayer(layerId, nodeDef.name().c_str());
+
+ prevSlot.Connect(outputLayer->GetInputSlot(0));
+
+ TrackOutputBinding(outputLayer, layerId, tensorInfo);
+ }
+ }
+ else
+ {
+ throw ParseException(boost::str(
+ boost::format("Unsupported operation %1% in tensorflow::GraphDef") % operation));
+ }
+}
+
+void TfParser::LoadGraphDef(const tensorflow::GraphDef& graphDef)
+{
+ // add all nodes to our map
+ m_NodesByName.clear();
+ m_NetworkInputsBindingInfo.clear();
+ m_NetworkOutputsBindingInfo.clear();
+
+ for (int i = 0; i < graphDef.node_size(); ++i)
+ {
+ const tensorflow::NodeDef& node = graphDef.node(i);
+ m_NodesByName[node.name()] = &node;
+ }
+
+ // Find the output nodes the user requested
+ std::vector<const tensorflow::NodeDef*> targetNodes;
+ for (const std::string& requestedOutputName : m_RequestedOutputs)
+ {
+ auto nodeIt = m_NodesByName.find(requestedOutputName);
+ if (nodeIt == m_NodesByName.end())
+ {
+ throw ParseException("Couldn't find requested output node '" + requestedOutputName + "' in graph");
+ }
+ targetNodes.push_back(nodeIt->second);
+ }
+
+ // Sort them into a linear ordering such that all inputs of a node are before the node itself
+ std::vector<const tensorflow::NodeDef*> sortedNodes;
+ if (!armnnUtils::GraphTopologicalSort<const tensorflow::NodeDef*>(
+ targetNodes,
+ [this](const tensorflow::NodeDef* node)
+ {
+ auto outputs = GetTfInputNodes(*node);
+ std::vector<const tensorflow::NodeDef*> nodesOnly;
+ for (const auto & o : outputs) {
+ nodesOnly.push_back(o.m_IndexedValue);
+ }
+ return nodesOnly;
+ },
+ sortedNodes))
+ {
+ throw ParseException("Cycle detected in graph");
+ }
+
+ // Parse each node in order, knowing that all inputs of a node will be processed before the node itself
+ for (const auto& it : sortedNodes)
+ {
+ const tensorflow::NodeDef& currentNode = *it;
+ LoadNodeDef(currentNode, graphDef);
+ }
+}
+
+INetworkPtr TfParser::CreateNetworkFromTextFile(const char* graphFile,
+ const std::map<std::string, TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs)
+{
+ FILE* fd = fopen(graphFile, "r");
+
+ if (fd == nullptr)
+ {
+ std::stringstream error;
+ error << "Graph file " << graphFile << " failed to open";
+ throw FileNotFoundException(error.str());
+ }
+
+ // Parse the file into a message
+ tensorflow::GraphDef graphDef;
+ auto input = new google::protobuf::io::FileInputStream(fileno(fd));
+ bool success = google::protobuf::TextFormat::Parse(input, &graphDef);
+ delete input;
+ fclose(fd);
+
+ if (!success)
+ {
+ std::stringstream error;
+ error << "Failed to parse graph file";
+ throw ParseException(error.str());
+ }
+
+ return CreateNetworkFromGraphDef(graphDef, inputShapes, requestedOutputs);
+}
+
+INetworkPtr TfParser::CreateNetworkFromString(const char* protoText,
+ const std::map<std::string, TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs)
+{
+ // Parse the string into a message
+ tensorflow::GraphDef graphDef;
+ bool success = google::protobuf::TextFormat::ParseFromString(protoText, &graphDef);
+
+ if (!success)
+ {
+ std::stringstream error;
+ error << "Failed to parse graph file";
+ throw ParseException(error.str());
+ }
+
+ return CreateNetworkFromGraphDef(graphDef, inputShapes, requestedOutputs);
+}
+
+INetworkPtr TfParser::CreateNetworkFromBinaryFile(const char* graphFile,
+ const std::map<std::string, TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs)
+{
+ FILE* fd = fopen(graphFile, "rb");
+
+ if (fd == nullptr)
+ {
+ std::stringstream error;
+ error << "Graph file " << graphFile << " failed to open";
+ throw FileNotFoundException(error.str());
+ }
+
+ // Parse the file into a message
+ tensorflow::GraphDef graphDef;
+
+ google::protobuf::io::FileInputStream inStream(fileno(fd));
+ google::protobuf::io::CodedInputStream codedStream(&inStream);
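+ // Lift protobuf's default message size limit so that large graph files can be parsed.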
+ codedStream.SetTotalBytesLimit(INT_MAX, INT_MAX);
+ bool success = graphDef.ParseFromCodedStream(&codedStream);
+ fclose(fd);
+
+ if (!success)
+ {
+ std::stringstream error;
+ error << "Failed to parse protobuf file" << graphFile;
+ throw ParseException(error.str());
+ }
+
+ return CreateNetworkFromGraphDef(graphDef, inputShapes, requestedOutputs);
+}
+
+INetworkPtr TfParser::CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef,
+ const std::map<std::string, TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs)
+{
+ m_Network = INetwork::Create();
+
+ m_InputShapes = inputShapes;
+ if (requestedOutputs.size() == 0)
+ {
+ throw ParseException("requestedOutputs must have at least one entry");
+ }
+ m_RequestedOutputs = requestedOutputs;
+
+ try
+ {
+ LoadGraphDef(graphDef);
+ }
+ catch (const ParseException&)
+ {
+ Cleanup();
+ throw;
+ }
+
+ Cleanup();
+
+ return std::move(m_Network);
+}
+
+void TfParser::Cleanup()
+{
+ // cleanup, in case we reuse this parser
+ m_InputShapes.clear();
+ m_RequestedOutputs.clear();
+ m_NodesByName.clear();
+ m_ParsedTfOperations.clear();
+}
+
+BindingPointInfo TfParser::GetNetworkInputBindingInfo(const std::string& name) const
+{
+ return GetBindingInfo(name, "input", m_NetworkInputsBindingInfo);
+}
+
+BindingPointInfo TfParser::GetNetworkOutputBindingInfo(const std::string& name) const
+{
+ return GetBindingInfo(name, "output", m_NetworkOutputsBindingInfo);
+}
+
+std::pair<LayerBindingId, TensorInfo> TfParser::GetBindingInfo(const std::string& layerName,
+ const char* bindingPointDesc,
+ const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo)
+{
+ auto it = nameToBindingInfo.find(layerName);
+ if (it == nameToBindingInfo.end())
+ {
+ throw InvalidArgumentException(boost::str(boost::format("Unknown %1% '%2%'") % bindingPointDesc % layerName));
+ }
+ return it->second;
+}
+
+void TfParser::TrackInputBinding(IConnectableLayer* layer, LayerBindingId id, const TensorInfo& tensorInfo)
+{
+ return TrackBindingPoint(layer, id, tensorInfo, "input", m_NetworkInputsBindingInfo);
+}
+
+void TfParser::TrackOutputBinding(IConnectableLayer* layer, LayerBindingId id, const TensorInfo& tensorInfo)
+{
+ return TrackBindingPoint(layer, id, tensorInfo, "output", m_NetworkOutputsBindingInfo);
+}
+
+void TfParser::TrackBindingPoint(IConnectableLayer* layer,
+ LayerBindingId id,
+ const TensorInfo& tensorInfo,
+ const char* bindingPointDesc,
+ std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo)
+{
+ const std::string layerName = layer->GetName();
+ auto it = nameToBindingInfo.find(layerName);
+ if (it == nameToBindingInfo.end())
+ {
+ nameToBindingInfo[layerName] = std::make_pair(id, tensorInfo);
+ }
+ else
+ {
+ throw ParseException(boost::str(
+ boost::format("Id %1% used by more than one %2% layer") % id % bindingPointDesc));
+ }
+}
+
+} // namespace armnnTfParser
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
new file mode 100644
index 0000000000..c5b4bce8ac
--- /dev/null
+++ b/src/armnnTfParser/TfParser.hpp
@@ -0,0 +1,199 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "armnnTfParser/ITfParser.hpp"
+
+#include "armnn/Types.hpp"
+#include "armnn/Tensor.hpp"
+#include "armnn/INetwork.hpp"
+
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+namespace armnn
+{
+class TensorInfo;
+}
+
+namespace tensorflow
+{
+class GraphDef;
+class NodeDef;
+}
+
+namespace armnnTfParser
+{
+
+using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
+
+class ParsedTfOperation;
+using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>;
+
+///
+/// WithOutputTensorIndex wraps a value and an index. The purpose of
+/// this template is to signify that, in TensorFlow, the input name of
+/// a layer follows the convention 'inputTensorName:#index', where the
+/// #index can be omitted, in which case it implicitly refers to output 0
+/// of the referenced layer. By supporting this notation we can handle
+/// layers with multiple outputs, such as Split.
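+/// For example, "conv1:1" refers to the second output of the node named "conv1",
+/// while plain "conv1" is shorthand for "conv1:0".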
+///
+template <typename T>
+struct WithOutputTensorIndex
+{
+ T m_IndexedValue;
+ unsigned int m_Index;
+
+ WithOutputTensorIndex(const T & value, unsigned int index)
+ : m_IndexedValue{value}
+ , m_Index{index} {}
+
+ WithOutputTensorIndex(T && value, unsigned int index)
+ : m_IndexedValue{value}
+ , m_Index{index} {}
+};
+
+using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>;
+using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>;
+using OutputId = WithOutputTensorIndex<std::string>;
+
+class TfParser : public ITfParser
+{
+public:
+ /// Create the network from a protobuf text file on disk
+ virtual armnn::INetworkPtr CreateNetworkFromTextFile(
+ const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs) override;
+
+ /// Create the network from a protobuf binary file on disk
+ virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
+ const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs) override;
+
+ /// Create the network directly from protobuf text in a string. Useful for debugging/testing
+ virtual armnn::INetworkPtr CreateNetworkFromString(
+ const char* protoText,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs) override;
+
+ /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
+ virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
+
+ /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
+ virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
+
+public:
+ TfParser();
+
+private:
+ template <typename T>
+ friend class ParsedConstTfOperation;
+ friend class ParsedMatMulTfOperation;
+
+ /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*
+ armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef,
+ const std::map<std::string, armnn::TensorShape>& inputShapes,
+ const std::vector<std::string>& requestedOutputs);
+
+ /// Sets up the parser's node map and then parses all nodes in topologically sorted order
+ void LoadGraphDef(const tensorflow::GraphDef& graphDef);
+
+ /// Parses a given node, assuming that all of the nodes it depends on have already been parsed
+ void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+
+ /// Resolves Identity nodes to the node that feeds them, e.g. when an Identity is the input of a Conv2D layer
+ const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef);
+ /// Finds the nodes connected as inputs of the given node in the graph.
+ std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const;
+ /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph,
+ /// and throws an exception if the number of inputs does not match the expected one.
+ /// This will automatically resolve any identity nodes. The result vector contains the parsed operation
+ /// together with the output tensor index to make the connection unambiguous.
+ std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef,
+ std::size_t expectedNumInputs);
+
+ ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+
+ /// Checks whether a pre-parsed const tensor with the given name and type is available
+ template<typename Type>
+ bool HasParsedConstTensor(const std::string & nodeName) const;
+
+ ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef,
+ armnn::PoolingAlgorithm pooltype);
+ ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
+ ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false);
+ armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef,
+ const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName);
+
+ static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
+ const char* bindingPointDesc,
+ const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
+
+ void TrackInputBinding(armnn::IConnectableLayer* layer,
+ armnn::LayerBindingId id,
+ const armnn::TensorInfo& tensorInfo);
+
+ void TrackOutputBinding(armnn::IConnectableLayer* layer,
+ armnn::LayerBindingId id,
+ const armnn::TensorInfo& tensorInfo);
+
+ static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
+ const armnn::TensorInfo& tensorInfo,
+ const char* bindingPointDesc,
+ std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
+
+ void Cleanup();
+
+ /// The network we're building. Gets cleared after it is passed to the user
+ armnn::INetworkPtr m_Network;
+
+ using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef);
+
+ /// map of TensorFlow operation names to parsing member functions
+ static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions;
+
+ std::map<std::string, armnn::TensorShape> m_InputShapes;
+ std::vector<std::string> m_RequestedOutputs;
+
+ /// map of nodes extracted from the GraphDef to speed up parsing
+ std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName;
+
+ std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations;
+
+ /// maps input layer names to their corresponding ids and tensor infos
+ std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
+
+ /// maps output layer names to their corresponding ids and tensor infos
+ std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
+};
+}
diff --git a/src/armnnTfParser/test/Activations.cpp b/src/armnnTfParser/test/Activations.cpp
new file mode 100644
index 0000000000..72ed64d653
--- /dev/null
+++ b/src/armnnTfParser/test/Activations.cpp
@@ -0,0 +1,113 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+
+struct ActivationFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit ActivationFixture(const char* activationFunction)
+ {
+ m_Prototext = "node {\n"
+ " name: \"Placeholder\"\n"
+ " op: \"Placeholder\"\n"
+ " attr {\n"
+ " key: \"dtype\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"shape\"\n"
+ " value {\n"
+ " shape {\n"
+ " unknown_rank: true\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n"
+ "node {\n"
+ " name: \"";
+ m_Prototext.append(activationFunction);
+ m_Prototext.append("\"\n"
+ " op: \"");
+ m_Prototext.append(activationFunction);
+ m_Prototext.append("\"\n"
+ " input: \"Placeholder\"\n"
+ " attr {\n"
+ " key: \"T\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ "}\n");
+
+ SetupSingleInputSingleOutput({ 1, 7 }, "Placeholder", activationFunction);
+ }
+};
+
+
+struct ReLuFixture : ActivationFixture
+{
+ ReLuFixture() : ActivationFixture("Relu") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseReLu, ReLuFixture)
+{
+ RunTest<2>({ -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f },
+ { 0.0f, 0.0f, 1.25f, 0.0f, 0.0f, 0.5f, 0.0f });
+}
+
+
+struct ReLu6Fixture : ActivationFixture
+{
+ ReLu6Fixture() : ActivationFixture("Relu6") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseReLu6, ReLu6Fixture)
+{
+ RunTest<2>({ -1.0f, -0.5f, 7.25f, -3.0f, 0.0f, 0.5f, -0.75f },
+ { 0.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.5f, 0.0f });
+}
+
+
+struct SigmoidFixture : ActivationFixture
+{
+ SigmoidFixture() : ActivationFixture("Sigmoid") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseSigmoid, SigmoidFixture)
+{
+ RunTest<2>({ -0.1f, -0.2f, -0.3f, -0.4f, 0.1f, 0.2f, 0.3f },
+ { 0.4750208f, 0.45016602f, 0.42555749f, 0.40131235f, 0.52497917f, 0.54983395f, 0.57444251f });
+}
+
+
+struct SoftplusFixture : ActivationFixture
+{
+ SoftplusFixture() : ActivationFixture("Softplus") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseSoftplus, SoftplusFixture)
+{
+ RunTest<2>({ -0.1f, -0.2f, -0.3f, -0.4f, 0.1f, 0.2f, 0.3f },
+ { 0.64439666f, 0.59813893f, 0.55435526f, 0.51301527f, 0.74439669f, 0.7981388f, 0.85435522f });
+}
+
+
+struct TanhFixture : ActivationFixture
+{
+ TanhFixture() : ActivationFixture("Tanh") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseTanh, TanhFixture)
+{
+ RunTest<2>({ -0.1f, -0.2f, -0.3f, -0.4f, 0.1f, 0.2f, 0.3f },
+ { -0.09966799f, -0.19737528f, -0.29131261f, -0.379949f, 0.09966799f, 0.19737528f, 0.29131261f });
+}
+
+
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Addition.cpp b/src/armnnTfParser/test/Addition.cpp
new file mode 100644
index 0000000000..c9e69268c6
--- /dev/null
+++ b/src/armnnTfParser/test/Addition.cpp
@@ -0,0 +1,78 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct AdditionFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ AdditionFixture()
+ {
+ m_Prototext = "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " node { \n"
+ " name: \"softmax1\" \n"
+ " op: \"Softmax\" \n"
+ " input: \"graphInput\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " }\n"
+ " node {\n"
+ " name: \"softmax2\"\n"
+ " op : \"Softmax\"\n"
+ " input: \"graphInput\"\n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " }\n"
+ " node {\n"
+ " name: \"addition\"\n"
+ " op : \"Add\"\n"
+ " input: \"softmax1\"\n"
+ " input: \"softmax2\"\n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " }\n";
+
+ SetupSingleInputSingleOutput({ 1, 7 }, "graphInput", "addition");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseAddition, AdditionFixture)
+{
+ RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 2, 0, 0, 0, 0 });
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/BiasAdd.cpp b/src/armnnTfParser/test/BiasAdd.cpp
new file mode 100644
index 0000000000..e29aeb1057
--- /dev/null
+++ b/src/armnnTfParser/test/BiasAdd.cpp
@@ -0,0 +1,104 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct BiasAddFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit BiasAddFixture(const std::string& dataFormat)
+ {
+ m_Prototext = R"(
+node {
+ name: "graphInput"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "bias"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1
+ float_val: 2
+ float_val: 3
+ }
+ }
+ }
+}
+node {
+ name: "biasAdd"
+ op : "BiasAdd"
+ input: "graphInput"
+ input: "bias"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: ")" + dataFormat + R"("
+ }
+ }
+}
+)";
+
+ SetupSingleInputSingleOutput({ 1, 3, 1, 3 }, "graphInput", "biasAdd");
+ }
+};
+
+struct BiasAddFixtureNCHW : BiasAddFixture
+{
+ BiasAddFixtureNCHW() : BiasAddFixture("NCHW") {}
+};
+
+struct BiasAddFixtureNHWC : BiasAddFixture
+{
+ BiasAddFixtureNHWC() : BiasAddFixture("NHWC") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseBiasAddNCHW, BiasAddFixtureNCHW)
+{
+ RunTest<4>(std::vector<float>(9), { 1, 1, 1, 2, 2, 2, 3, 3, 3 });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseBiasAddNHWC, BiasAddFixtureNHWC)
+{
+ RunTest<4>(std::vector<float>(9), { 1, 2, 3, 1, 2, 3, 1, 2, 3 });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/BroadcastForAdd.cpp b/src/armnnTfParser/test/BroadcastForAdd.cpp
new file mode 100644
index 0000000000..4c9731d7fc
--- /dev/null
+++ b/src/armnnTfParser/test/BroadcastForAdd.cpp
@@ -0,0 +1,149 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+// This is a special case for add, which supports broadcasting
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct BroadcastForAddFixtureSlot1 : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ BroadcastForAddFixtureSlot1()
+ {
+ m_Prototext = R"(
+ node {
+ name: "graphInput"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 4.0
+ float_val: 5.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Add"
+ op: "Add"
+ input: "graphInput"
+ input: "Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+
+ SetupSingleInputSingleOutput({ 1, 2, 2, 2 }, "graphInput", "Add");
+ }
+};
+
+struct BroadcastForAddFixtureSlot0 : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ BroadcastForAddFixtureSlot0()
+ {
+ m_Prototext = R"(
+ node {
+ name: "graphInput"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 4.0
+ float_val: 5.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Add"
+ op: "Add"
+ input: "Const_1"
+ input: "graphInput"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+
+ SetupSingleInputSingleOutput({ 1, 2, 2, 2 }, "graphInput", "Add");
+ }
+};
+
+
+BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition1, BroadcastForAddFixtureSlot1)
+{
+ RunTest<4>({ 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0 }, { 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0 });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition0, BroadcastForAddFixtureSlot0)
+{
+ RunTest<4>({ 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0 }, { 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0 });
+}
+
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Concat.cpp b/src/armnnTfParser/test/Concat.cpp
new file mode 100644
index 0000000000..a7d5ea03af
--- /dev/null
+++ b/src/armnnTfParser/test/Concat.cpp
@@ -0,0 +1,183 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct ConcatFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit ConcatFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
+ unsigned int concatDim)
+ {
+ m_Prototext = R"(
+ node {
+ name: "graphInput0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "concat/axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
+
+ m_Prototext += std::to_string(concatDim);
+
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "concat"
+ op: "ConcatV2"
+ input: "graphInput0"
+ input: "graphInput1"
+ input: "concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ )";
+
+ Setup({{"graphInput0", inputShape0 },
+ {"graphInput1", inputShape1 }}, {"concat"});
+ }
+};
+
+struct ConcatFixtureNCHW : ConcatFixture
+{
+ ConcatFixtureNCHW() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 1 ) {}
+};
+
+struct ConcatFixtureNHWC : ConcatFixture
+{
+ ConcatFixtureNHWC() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 3 ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatNCHW, ConcatFixtureNCHW)
+{
+ RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
+ {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
+ {{"concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatNHWC, ConcatFixtureNHWC)
+{
+ RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
+ {"graphInput1", {4.0, 5.0, 6.0, 7.0}}},
+ {{"concat", { 0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0 }}});
+}
+
+struct ConcatFixtureDim1 : ConcatFixture
+{
+ ConcatFixtureDim1() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 1) {}
+};
+
+struct ConcatFixtureDim3 : ConcatFixture
+{
+ ConcatFixtureDim3() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 3) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatDim1, ConcatFixtureDim1)
+{
+ RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
+ 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0 } },
+ { "graphInput1", { 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
+ 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } },
+ { { "concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0,
+ 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0,
+ 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0,
+ 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatDim3, ConcatFixtureDim3)
+{
+ RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0,
+ 4.0, 5.0, 6.0, 7.0,
+ 8.0, 9.0, 10.0, 11.0,
+ 12.0, 13.0, 14.0, 15.0,
+ 16.0, 17.0, 18.0, 19.0,
+ 20.0, 21.0, 22.0, 23.0 } },
+ { "graphInput1", { 50.0, 51.0, 52.0, 53.0,
+ 54.0, 55.0, 56.0, 57.0,
+ 58.0, 59.0, 60.0, 61.0,
+ 62.0, 63.0, 64.0, 65.0,
+ 66.0, 67.0, 68.0, 69.0,
+ 70.0, 71.0, 72.0, 73.0 } } },
+ { { "concat", { 0.0, 1.0, 2.0, 3.0,
+ 50.0, 51.0, 52.0, 53.0,
+ 4.0, 5.0, 6.0, 7.0,
+ 54.0, 55.0, 56.0, 57.0,
+ 8.0, 9.0, 10.0, 11.0,
+ 58.0, 59.0, 60.0, 61.0,
+ 12.0, 13.0, 14.0, 15.0,
+ 62.0, 63.0, 64.0, 65.0,
+ 16.0, 17.0, 18.0, 19.0,
+ 66.0, 67.0, 68.0, 69.0,
+ 20.0, 21.0, 22.0, 23.0,
+ 70.0, 71.0, 72.0, 73.0 } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file
diff --git a/src/armnnTfParser/test/ConcatOfConcats.cpp b/src/armnnTfParser/test/ConcatOfConcats.cpp
new file mode 100644
index 0000000000..7316b9f1ac
--- /dev/null
+++ b/src/armnnTfParser/test/ConcatOfConcats.cpp
@@ -0,0 +1,316 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct ConcatOfConcatsFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit ConcatOfConcatsFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
+ const armnn::TensorShape& inputShape2, const armnn::TensorShape& inputShape3,
+ unsigned int concatDim)
+ {
+ m_Prototext = R"(
+ node {
+ name: "graphInput0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput2"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput3"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "Relu"
+ op: "Relu"
+ input: "graphInput0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "Relu_1"
+ op: "Relu"
+ input: "graphInput1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "Relu_2"
+ op: "Relu"
+ input: "graphInput2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "Relu_3"
+ op: "Relu"
+ input: "graphInput3"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "concat/axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
+ m_Prototext += std::to_string(concatDim);
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "concat"
+ op: "ConcatV2"
+ input: "Relu"
+ input: "Relu_1"
+ input: "concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "concat_1/axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
+ m_Prototext += std::to_string(concatDim);
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_1"
+ op: "ConcatV2"
+ input: "Relu_2"
+ input: "Relu_3"
+ input: "concat_1/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ node {
+ name: "concat_2/axis"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
+ m_Prototext += std::to_string(concatDim);
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "concat_2"
+ op: "ConcatV2"
+ input: "concat"
+ input: "concat_1"
+ input: "concat_2/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ }
+ )";
+
+ Setup({{ "graphInput0", inputShape0 },
+ { "graphInput1", inputShape1 },
+ { "graphInput2", inputShape2 },
+ { "graphInput3", inputShape3}}, {"concat_2"});
+ }
+};
+
+struct ConcatOfConcatsFixtureNCHW : ConcatOfConcatsFixture
+{
+ ConcatOfConcatsFixtureNCHW() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
+ { 1, 1, 2, 2 }, 1 ) {}
+};
+
+struct ConcatOfConcatsFixtureNHWC : ConcatOfConcatsFixture
+{
+ ConcatOfConcatsFixtureNHWC() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
+ { 1, 1, 2, 2 }, 3 ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW, ConcatOfConcatsFixtureNCHW)
+{
+ RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
+ {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
+ {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
+ {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
+ {{"concat_2", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0 }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNHWC, ConcatOfConcatsFixtureNHWC)
+{
+ RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
+ {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
+ {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
+ {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
+ {{"concat_2", { 0.0, 1.0, 4.0, 5.0, 8.0, 9.0, 12.0, 13.0,
+ 2.0, 3.0, 6.0, 7.0, 10.0, 11.0, 14.0, 15.0 }}});
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Constant.cpp b/src/armnnTfParser/test/Constant.cpp
new file mode 100644
index 0000000000..09587fc3d5
--- /dev/null
+++ b/src/armnnTfParser/test/Constant.cpp
@@ -0,0 +1,321 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+
+#include "armnnTfParser/ITfParser.hpp"
+
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+// Tests that a Const node in Tensorflow can be converted to a ConstLayer in armnn (as opposed to most
+// Const nodes which are used as weight inputs for convolutions etc. and are therefore not converted to
+// armnn ConstLayers).
+struct ConstantFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ConstantFixture()
+ {
+ // input = tf.placeholder(tf.float32, name = "input")
+ // const = tf.constant([17], tf.float32, [1])
+ // output = tf.add(input, const, name = "output")
+ m_Prototext =
+ R"(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+}
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 17.0
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Add"
+ input: "input"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ SetupSingleInputSingleOutput({ 1 }, "input", "output");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(Constant, ConstantFixture)
+{
+ RunTest<1>({1}, {18});
+}
+
+
+// Tests that a single Const node in Tensorflow can be used twice by a dependant node. This should result in only
+// a single armnn ConstLayer being created.
+struct ConstantReusedFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ConstantReusedFixture()
+ {
+ // const = tf.constant([17], tf.float32, [1])
+ // output = tf.add(const, const, name = "output")
+ m_Prototext =
+ R"(
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 17.0
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Add"
+ input: "Const"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ Setup({}, { "output" });
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ConstantReused, ConstantReusedFixture)
+{
+ RunTest<1>({}, { { "output", { 34 } } });
+}
+
+template <int ListSize>
+struct ConstantValueListFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ConstantValueListFixture()
+ {
+ m_Prototext =
+ R"(
+node {
+ name: "output"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ })";
+
+ double value = 0.75;
+ for (int i = 0; i < ListSize; i++, value += 0.25)
+ {
+ m_Prototext += std::string("float_val : ") + std::to_string(value) + "\n";
+ }
+
+ m_Prototext +=
+ R"(
+ }
+ }
+ }
+}
+ )";
+ Setup({}, { "output" });
+ }
+};
+
+using ConstantSingleValueListFixture = ConstantValueListFixture<1>;
+using ConstantMultipleValueListFixture = ConstantValueListFixture<4>;
+using ConstantMaxValueListFixture = ConstantValueListFixture<6>;
+
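+// The constant tensor above always has 2 x 3 = 6 elements. As the expected outputs below show, when fewer
+// float_val entries are supplied than there are elements, the remaining elements take the last supplied value
+// (e.g. with the four values 0.75, 1.0, 1.25, 1.5 the final two elements are both 1.5).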
+BOOST_FIXTURE_TEST_CASE(ConstantSingleValueList, ConstantSingleValueListFixture)
+{
+ RunTest<2>({}, { { "output", { 0.75f, 0.75f, 0.75f, 0.75f, 0.75f, 0.75f } } });
+}
+BOOST_FIXTURE_TEST_CASE(ConstantMultipleValueList, ConstantMultipleValueListFixture)
+{
+ RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.5f, 1.5f, 1.5f } } });
+}
+BOOST_FIXTURE_TEST_CASE(ConstantMaxValueList, ConstantMaxValueListFixture)
+{
+ RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.50f, 1.75f, 2.f } } });
+}
+
+template <bool WithShape, bool WithContent, bool WithValueList>
+struct ConstantCreateFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ConstantCreateFixture()
+ {
+ m_Prototext =
+ R"(
+node {
+ name: "output"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ )";
+
+ if (WithShape)
+ {
+ m_Prototext +=
+ R"(
+tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 2
+ }
+}
+ )";
+ }
+ else
+ {
+ m_Prototext +=
+ R"(
+tensor_shape {
+}
+ )";
+ }
+
+ if (WithContent)
+ {
+ m_Prototext +=
+ R"(
+tensor_content: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
+ )";
+ }
+
+ if (WithValueList)
+ {
+ m_Prototext +=
+ R"(
+float_val: 1.0
+float_val: 1.0
+float_val: 1.0
+float_val: 1.0
+float_val: 1.0
+ )";
+ }
+
+ m_Prototext +=
+ R"(
+ }
+ }
+ }
+}
+ )";
+ }
+};
+
+using ConstantCreateNoValueListFixture = ConstantCreateFixture<true, false, true>;
+using ConstantCreateNoValueList2Fixture = ConstantCreateFixture<true, false, false>;
+using ConstantCreateNoContentFixture = ConstantCreateFixture<true, true, false>;
+using ConstantCreateNoContent2Fixture = ConstantCreateFixture<true, false, false>;
+using ConstantCreateNoShapeFixture = ConstantCreateFixture<false, false, false>;
+using ConstantCreateNoShape2Fixture = ConstantCreateFixture<false, true, false>;
+using ConstantCreateNoShape3Fixture = ConstantCreateFixture<false, false, true>;
+
+BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList, ConstantCreateNoValueListFixture)
+{
+ BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
+}
+BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList2, ConstantCreateNoValueList2Fixture)
+{
+ BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
+}
+BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidContent, ConstantCreateNoContentFixture)
+{
+ BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
+}
+BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidShape, ConstantCreateNoShapeFixture)
+{
+ BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
+}
+BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape2, ConstantCreateNoShape2Fixture)
+{
+ BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
+}
+BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape3, ConstantCreateNoShape3Fixture)
+{
+ Setup({}, { "output" });
+ RunTest<1>({}, { { "output", { 1.f, 1.f, 1.f, 1.f, 1.f } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Convolution2d.cpp b/src/armnnTfParser/test/Convolution2d.cpp
new file mode 100644
index 0000000000..a7c7648b81
--- /dev/null
+++ b/src/armnnTfParser/test/Convolution2d.cpp
@@ -0,0 +1,322 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+#include <string>
+#include <iostream>
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct Convolution2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit Convolution2dFixture(const char* paddingType)
+ : Convolution2dFixture(paddingType, 1)
+ {}
+
+ // dilation: 0 - dilations attribute is not included;
+ // dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
+ explicit Convolution2dFixture(const char* paddingType, int stride, int dilation = 0)
+ {
+ std::string strideString = std::to_string(stride);
+ std::string dilationString = std::to_string(dilation);
+ m_Prototext = "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " node { \n"
+ " name: \"Const_1\" \n"
+ " op: \"Const\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_FLOAT \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 3 \n"
+ " } \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " } \n"
+ " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"potato\" \n"
+ " op: \"Conv2D\" \n"
+ " input: \"graphInput\" \n"
+ " input: \"Const_1\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"data_format\" \n"
+ " value { \n"
+ " s: \"NHWC\" \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"padding\" \n"
+ " value { \n"
+ " s: \"";
+ m_Prototext.append(paddingType);
+ m_Prototext.append("\"\n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"strides\" \n"
+ " value { \n"
+ " list { \n"
+ " i: 1 \n"
+ " i: 1 \n"
+ " i: ");
+ m_Prototext.append(strideString);
+ m_Prototext.append(" \n"
+ " i: 1 \n"
+ " } \n"
+ " } \n"
+ " } \n");
+
+ if (dilation > 0)
+ {
+ m_Prototext.append(" attr { \n"
+ " key: \"dilations\" \n"
+ " value { \n"
+ " list { \n"
+ " i: 1 \n"
+ " i: ");
+ m_Prototext.append(dilationString);
+ m_Prototext.append(" \n"
+ " i: ");
+ m_Prototext.append(dilationString);
+ m_Prototext.append(" \n"
+ " i: 1 \n"
+ " } \n"
+ " } \n"
+ " } \n");
+ }
+ m_Prototext.append(" attr { \n"
+ " key: \"use_cudnn_on_gpu\" \n"
+ " value { \n"
+ " b: false \n"
+ " } \n"
+ " } \n"
+ "} \n");
+
+ // Manual height computation based on stride parameter.
+        BOOST_ASSERT_MSG(stride == 1 || stride == 2, "Add support for strides other than 1 or 2.");
+ unsigned int dims[] = {1,2,3,1};
+ if (stride == 2)
+ {
+ dims[1]=3;
+ }
+
+ SetupSingleInputSingleOutput(armnn::TensorShape(4, dims), "graphInput", "potato");
+ }
+};
+
+
+struct Convolution2dSameFixture : Convolution2dFixture
+{
+ Convolution2dSameFixture() : Convolution2dFixture("SAME", 1){}
+};
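+// Quick arithmetic check from the fixture values: the 1x3 kernel content decodes (little-endian floats) to
+// { 0.5, 1.0, 0.5 }, so with SAME padding each output is 0.5*left + 1.0*middle + 0.5*right with zeros outside
+// the row. For the first input row { 1, 2, 3 }: 0 + 1 + 1 = 2, 0.5 + 2 + 1.5 = 4, 1 + 3 + 0 = 4.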
+BOOST_FIXTURE_TEST_CASE(ParseConv2DSame, Convolution2dSameFixture)
+{
+    RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10, 8.5f});
+}
+
+struct Convolution2dValidFixture : Convolution2dFixture
+{
+ Convolution2dValidFixture() : Convolution2dFixture("VALID", 1){}
+};
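+// With VALID padding only the fully covered position in each row survives:
+// 0.5*1 + 1.0*2 + 0.5*3 = 4 and 0.5*4 + 1.0*5 + 0.5*6 = 10.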
+BOOST_FIXTURE_TEST_CASE(ParseConv2DValid, Convolution2dValidFixture)
+{
+ RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
+}
+
+
+struct Convolution2dStride2SameFixture : Convolution2dFixture
+{
+ Convolution2dStride2SameFixture() : Convolution2dFixture("SAME", 2){}
+};
+BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Same, Convolution2dStride2SameFixture)
+{
+ RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
+}
+
+
+struct Convolution2dStride2ValidFixture : Convolution2dFixture
+{
+ Convolution2dStride2ValidFixture() : Convolution2dFixture("VALID", 2){}
+};
+BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Valid, Convolution2dStride2ValidFixture)
+{
+ RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
+}
+
+
+struct Convolution2dDilation1Fixture : Convolution2dFixture
+{
+ Convolution2dDilation1Fixture() : Convolution2dFixture("SAME", 1, 1){}
+};
+BOOST_FIXTURE_TEST_CASE(ParseConv2DDilation1, Convolution2dDilation1Fixture)
+{
+    RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10, 8.5f});
+}
+
+BOOST_AUTO_TEST_CASE(ParseConv2DDilation2)
+{
+ const char* prototext = ""
+ "node {\n"
+ " name: \"graphInput\"\n"
+ " op: \"Placeholder\"\n"
+ " attr {\n"
+ " key: \"dtype\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"shape\"\n"
+ " value {\n"
+ " shape {\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n"
+ "node {\n"
+ " name: \"Const_1\"\n"
+ " op: \"Const\"\n"
+ " attr {\n"
+ " key: \"dtype\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"value\"\n"
+ " value {\n"
+ " tensor {\n"
+ " dtype: DT_FLOAT\n"
+ " tensor_shape {\n"
+ " dim {\n"
+ " size: 1\n"
+ " }\n"
+ " dim {\n"
+ " size: 3\n"
+ " }\n"
+ " dim {\n"
+ " size: 1\n"
+ " }\n"
+ " dim {\n"
+ " size: 1\n"
+ " }\n"
+ " }\n"
+ " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n"
+ "node {\n"
+ " name: \"potato\"\n"
+ " op: \"Conv2D\"\n"
+ " input: \"graphInput\"\n"
+ " input: \"Const_1\"\n"
+ " attr {\n"
+ " key: \"T\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"data_format\"\n"
+ " value {\n"
+ " s: \"NHWC\"\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"padding\"\n"
+ " value {\n"
+ " s: \"SAME\"\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"strides\"\n"
+ " value {\n"
+ " list {\n"
+ " i: 1\n"
+ " i: 1\n"
+ " i: 1\n"
+ " i: 1\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"dilations\"\n"
+ " value {\n"
+ " list {\n"
+ " i: 1\n"
+ " i: 2\n"
+ " i: 2\n"
+ " i: 1\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"use_cudnn_on_gpu\"\n"
+ " value {\n"
+ " b: false\n"
+ " }\n"
+ " }\n"
+ "}\n";
+
+ std::map<std::string, armnn::TensorShape> inputShapes;
+ armnn::TensorShape tensorShape = { 1, 3, 3, 1 };
+ inputShapes["graphInput"] = tensorShape;
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_EXCEPTION(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }),
+ armnn::ParseException,
+ [] (armnn::ParseException const& ex)->bool
+ {
+ return strcmp(ex.what(),
+ "ArmNN only supports Convolution layers with dilations [1,1,1,1]") == 0;
+ });
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/DepthwiseConvolution2d.cpp b/src/armnnTfParser/test/DepthwiseConvolution2d.cpp
new file mode 100644
index 0000000000..84e7a7e7a9
--- /dev/null
+++ b/src/armnnTfParser/test/DepthwiseConvolution2d.cpp
@@ -0,0 +1,166 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+#include <string>
+#include <iostream>
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct DepthwiseConvolution2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit DepthwiseConvolution2dFixture(const char* paddingType)
+ {
+ m_Prototext = "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_FLOAT \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 3 \n"
+ " } \n"
+ " dim { \n"
+ " size: 3 \n"
+ " } \n"
+ " } \n"
+ " tensor_content: \"\\000\\000\\200?\\000\\000\\000@\\000\\000@@\\000\\000\\200@"
+ "\\000\\000\\240@\\000\\000\\300@\\000\\000\\340@\\000\\000\\000A\\000\\000\\020A\" \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " node { \n"
+ " name: \"Const_1\" \n"
+ " op: \"Const\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_FLOAT \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 3 \n"
+ " } \n"
+ " dim { \n"
+ " size: 3 \n"
+ " } \n"
+ " dim { \n"
+ " size: 3 \n"
+ " } \n"
+ " } \n"
+ " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
+ "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
+ "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
+ "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
+ "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
+ "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
+ "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
+ "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?"
+ "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"potato\" \n"
+ " op: \"DepthwiseConv2dNative\" \n"
+ " input: \"graphInput\" \n"
+ " input: \"Const_1\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"data_format\" \n"
+ " value { \n"
+ " s: \"NHWC\" \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"padding\" \n"
+ " value { \n"
+ " s: \"";
+ m_Prototext.append(paddingType);
+ m_Prototext.append("\"\n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"strides\" \n"
+ " value { \n"
+ " list { \n"
+ " i: 1 \n"
+ " i: 1 \n"
+ " i: 1 \n"
+ " i: 1 \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"use_cudnn_on_gpu\" \n"
+ " value { \n"
+ " b: false \n"
+ " } \n"
+ " } \n"
+ "} \n");
+
+ SetupSingleInputSingleOutput({ 1, 1, 3, 3 }, "graphInput", "potato");
+ }
+};
+
+struct DepthwiseConvolution2dSameFixture : DepthwiseConvolution2dFixture
+{
+ DepthwiseConvolution2dSameFixture() : DepthwiseConvolution2dFixture("SAME") { }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DSame, DepthwiseConvolution2dSameFixture)
+{
+ RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 },
+ { 2.5f, 5.f, 2.5f, 3.5f, 7.f, 3.5f, 4.5f, 9.f, 4.5f,
+ 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f,
+ 5.5f, 11.f, 5.5f, 6.5f, 13.f, 6.5f, 7.5f, 15.f, 7.5f});
+}
+
+struct DepthwiseConvolution2dValidFixture : DepthwiseConvolution2dFixture
+{
+ DepthwiseConvolution2dValidFixture() : DepthwiseConvolution2dFixture("VALID") { }
+};
+
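+// Rough check, assuming TensorFlow's [filter_height, filter_width, in_channels, channel_multiplier] filter
+// layout: the { 1, 3, 3, 3 } filter repeats { 0.5, 1.0, 0.5 } across the multiplier for every width position and
+// channel. With VALID padding only the centre position survives, so output channel (c, m) is the sum of that
+// channel's inputs scaled by 0.5 (m = 0, 2) or 1.0 (m = 1): channel 0 sums to 12 giving 6, 12, 6; channel 1 sums
+// to 15 giving 7.5, 15, 7.5; channel 2 sums to 18 giving 9, 18, 9.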
+BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DValid, DepthwiseConvolution2dValidFixture)
+{
+ RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // input data
+               { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f }); // expected output data
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/FullyConnected.cpp b/src/armnnTfParser/test/FullyConnected.cpp
new file mode 100644
index 0000000000..2a7b4951b7
--- /dev/null
+++ b/src/armnnTfParser/test/FullyConnected.cpp
@@ -0,0 +1,579 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+#include "Runtime.hpp"
+#include "Network.hpp"
+#include "Graph.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+// In Tensorflow fully connected layers are expressed as a MatMul followed by an Add.
+// The TfParser must detect this case and convert them to a FullyConnected layer.
+struct FullyConnectedFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ FullyConnectedFixture()
+ {
+ // input = tf.placeholder(tf.float32, [1, 1], "input")
+ // weights = tf.constant([2], tf.float32, [1, 1])
+ // matmul = tf.matmul(input, weights)
+ // bias = tf.constant([1], tf.float32)
+ // output = tf.add(matmul, bias, name="output")
+ m_Prototext = R"(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+}
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "MatMul"
+ op: "MatMul"
+ input: "input"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Add"
+ input: "MatMul"
+ input: "Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ SetupSingleInputSingleOutput({ 1, 1 }, "input", "output");
+ }
+};
+
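+// Quick arithmetic check from the fixture values: y = 3 * 2 (weight) + 1 (bias) = 7.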
+BOOST_FIXTURE_TEST_CASE(FullyConnected, FullyConnectedFixture)
+{
+ RunTest<1>({ 3 }, { 7 });
+}
+
+// Similar to FullyConnectedFixture, but this time the MatMul's output is used by two Adds. This should result
+// in two FullyConnected layers being created.
+// I
+// |
+// M -- C
+// / \'
+// C-- A A -- C
+// \ /
+// A
+struct MatMulUsedInTwoFcFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MatMulUsedInTwoFcFixture()
+ {
+ m_Prototext = R"(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+}
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "MatMul"
+ op: "MatMul"
+ input: "input"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 5.0
+ }
+ }
+ }
+}
+node {
+ name: "Const_2"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 15.0
+ }
+ }
+ }
+}
+node {
+ name: "Add"
+ op: "Add"
+ input: "MatMul"
+ input: "Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "Add_1"
+ op: "Add"
+ input: "MatMul"
+ input: "Const_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Add"
+ input: "Add"
+ input: "Add_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ SetupSingleInputSingleOutput({ 1, 1 }, "input", "output");
+ }
+};
+
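+// Quick arithmetic check from the fixture values: 3 * 2 = 6, the two biased paths give 6 + 5 = 11 and
+// 6 + 15 = 21, and the final Add produces 11 + 21 = 32.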
+BOOST_FIXTURE_TEST_CASE(MatMulUsedInTwoFc, MatMulUsedInTwoFcFixture)
+{
+ RunTest<1>({ 3 }, { 32 });
+ // Ideally we would check here that the armnn network has 5 layers:
+ // Input, 2 x FullyConnected (biased), Add and Output.
+    // This would make sure the parser hasn't incorrectly added some unconnected layers corresponding to the MatMul.
+}
+
+// Similar to MatMulUsedInTwoFc, but this time the Adds are 'staggered' (see diagram), which means that only one
+// FullyConnected layer can be created (the other should just be an Add).
+// I
+// |
+// M -- C1
+// / \'
+// C2 -- A |
+// \ /
+// A
+struct MatMulUsedInTwoFcStaggeredFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MatMulUsedInTwoFcStaggeredFixture()
+ {
+ // input = tf.placeholder(tf.float32, shape=[1,1], name = "input")
+ // const1 = tf.constant([17], tf.float32, [1,1])
+ // mul = tf.matmul(input, const1)
+ // const2 = tf.constant([7], tf.float32, [1])
+ // fc = tf.add(mul, const2)
+ // output = tf.add(mul, fc, name="output")
+ m_Prototext = R"(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+}
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ float_val: 17.0
+ }
+ }
+ }
+}
+node {
+ name: "MatMul"
+ op: "MatMul"
+ input: "input"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 7.0
+ }
+ }
+ }
+}
+node {
+ name: "Add"
+ op: "Add"
+ input: "MatMul"
+ input: "Const_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Add"
+ input: "MatMul"
+ input: "Add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ SetupSingleInputSingleOutput({ 1, 1 }, "input", "output");
+ }
+};
+
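+// Quick arithmetic check from the fixture values: 2 * 17 = 34, the biased path gives 34 + 7 = 41, and the final
+// Add produces 34 + 41 = 75.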
+BOOST_FIXTURE_TEST_CASE(MatMulUsedInTwoFcStaggered, MatMulUsedInTwoFcStaggeredFixture)
+{
+ RunTest<1>({ 2 }, { 75 });
+ // Ideally we would check here that the armnn network has 5 layers:
+    // Input, FullyConnected (biased), FullyConnected (non-biased), Add and Output.
+}
+
+// A MatMul in isolation, not connected to an add. Should result in a non-biased FullyConnected layer.
+struct MatMulFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MatMulFixture()
+ {
+ // input = tf.placeholder(tf.float32, shape = [1, 1], name = "input")
+ // const = tf.constant([17], tf.float32, [1, 1])
+ // output = tf.matmul(input, const, name = "output")
+ m_Prototext = R"(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+}
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ float_val: 17.0
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "MatMul"
+ input: "input"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+ )";
+ SetupSingleInputSingleOutput({ 1, 1 }, "input", "output");
+ }
+};
+
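+// Quick arithmetic check from the fixture values: 2 * 17 = 34.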
+BOOST_FIXTURE_TEST_CASE(MatMul, MatMulFixture)
+{
+ RunTest<1>({ 2 }, { 34 });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp
new file mode 100644
index 0000000000..632d5f01f9
--- /dev/null
+++ b/src/armnnTfParser/test/FusedBatchNorm.cpp
@@ -0,0 +1,175 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct FusedBatchNormFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ FusedBatchNormFixture()
+ {
+ m_Prototext = "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"Const_1\" \n"
+ " op: \"Const\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_FLOAT \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " } \n"
+ " float_val: 1.0 \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"Const_2\" \n"
+ " op: \"Const\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_FLOAT \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " } \n"
+ " float_val: 0.0 \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"FusedBatchNormLayer/mean\" \n"
+ " op: \"Const\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_FLOAT \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " } \n"
+ " float_val: 5.0 \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"FusedBatchNormLayer/variance\" \n"
+ " op: \"Const\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_FLOAT \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " } \n"
+ " float_val: 2.0 \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"output\" \n"
+ " op: \"FusedBatchNorm\" \n"
+ " input: \"graphInput\" \n"
+ " input: \"Const_1\" \n"
+ " input: \"Const_2\" \n"
+ " input: \"FusedBatchNormLayer/mean\" \n"
+ " input: \"FusedBatchNormLayer/variance\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"data_format\" \n"
+ " value { \n"
+ " s: \"NHWC\" \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"epsilon\" \n"
+ " value { \n"
+ " f: 0.0010000000475 \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"is_training\" \n"
+ " value { \n"
+ " b: false \n"
+ " } \n"
+ " } \n"
+ "} \n";
+
+ SetupSingleInputSingleOutput({1, 3, 3, 1}, "graphInput", "output");
+ }
+};
+
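+// Rough check, assuming the usual batch normalization formula
+// y = gamma * (x - mean) / sqrt(variance + epsilon) + beta, with the fixture values gamma = 1, beta = 0,
+// mean = 5, variance = 2 and epsilon ~= 0.001: x = 1 gives -4 / sqrt(2.001) ~= -2.8277, x = 5 gives 0 and
+// x = 9 gives ~= 2.8277.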
+BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNorm, FusedBatchNormFixture)
+{
+ RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, // input data
+ {-2.8277204f, -2.12079024f, -1.4138602f,
+ -0.7069301f, 0.0f, 0.7069301f,
+ 1.4138602f, 2.12079024f, 2.8277204f}); // expected output data
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Identity.cpp b/src/armnnTfParser/test/Identity.cpp
new file mode 100644
index 0000000000..ca20de5760
--- /dev/null
+++ b/src/armnnTfParser/test/Identity.cpp
@@ -0,0 +1,161 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct IdentitySimpleFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ IdentitySimpleFixture()
+ {
+ m_Prototext = "node{ "
+ " name: \"Placeholder\""
+ " op: \"Placeholder\""
+ " attr {"
+ " key: \"dtype\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " attr {"
+ " key: \"shape\""
+ " value {"
+ " shape {"
+ " unknown_rank: true"
+ " }"
+ " }"
+ " }"
+ "}"
+ "node {"
+ " name: \"Identity\""
+ " op: \"Identity\""
+ " input: \"Placeholder\""
+ " attr {"
+ " key: \"T\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ "}";
+ SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Identity");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(IdentitySimple, IdentitySimpleFixture)
+{
+ RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 1.0f, 2.0f, 3.0f, 4.0f });
+}
+
+struct IdentityFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ IdentityFixture()
+ {
+ m_Prototext = "node{ "
+ " name: \"Placeholder\""
+ " op: \"Placeholder\""
+ " attr {"
+ " key: \"dtype\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " attr {"
+ " key: \"shape\""
+ " value {"
+ " shape {"
+ " unknown_rank: true"
+ " }"
+ " }"
+ " }"
+ "}"
+ "node {"
+ " name: \"Identity\""
+ " op: \"Identity\""
+ " input: \"Placeholder\""
+ " attr {"
+ " key: \"T\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ "}"
+ "node {"
+ " name: \"Add\""
+ " op: \"Add\""
+ " input: \"Identity\""
+ " input: \"Identity\""
+ " attr {"
+ " key: \"T\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ "}";
+ SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Add");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseIdentity, IdentityFixture)
+{
+ RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 2.0f, 4.0f, 6.0f, 8.0f });
+}
+
+struct IdentityChainFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ IdentityChainFixture()
+ {
+ m_Prototext = "node{ "
+ " name: \"Placeholder\""
+ " op: \"Placeholder\""
+ " attr {"
+ " key: \"dtype\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " attr {"
+ " key: \"shape\""
+ " value {"
+ " shape {"
+ " unknown_rank: true"
+ " }"
+ " }"
+ " }"
+ "}"
+ "node {"
+ " name: \"Identity\""
+ " op: \"Identity\""
+ " input: \"Placeholder\""
+ " attr {"
+ " key: \"T\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ "}"
+ "node {"
+ " name: \"Identity2\""
+ " op: \"Identity\""
+ " input: \"Identity\""
+ " attr {"
+ " key: \"T\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ "}";
+ SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Identity2");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(IdentityChain, IdentityChainFixture)
+{
+ RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 1.0f, 2.0f, 3.0f, 4.0f });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/LocalResponseNormalization.cpp b/src/armnnTfParser/test/LocalResponseNormalization.cpp
new file mode 100644
index 0000000000..a7c2bfe3e1
--- /dev/null
+++ b/src/armnnTfParser/test/LocalResponseNormalization.cpp
@@ -0,0 +1,121 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+
+struct LocalResponseNormalizationBaseFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit LocalResponseNormalizationBaseFixture(float alpha, float beta, float bias)
+ {
+ std::string alphaString = std::to_string(alpha);
+ std::string betaString = std::to_string(beta);
+ std::string biasString = std::to_string(bias);
+
+ m_Prototext = "node {"
+ " name: \"Placeholder\""
+ " op: \"Placeholder\""
+ " attr {"
+ " key: \"dtype\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " attr {"
+ " key: \"shape\""
+ " value {"
+ " shape {"
+ " unknown_rank: true"
+ " }"
+ " }"
+ " }"
+ "}"
+ "node {"
+ " name: \"LRN\""
+ " op: \"LRN\""
+ " input: \"Placeholder\""
+ " attr {"
+ " key: \"T\""
+ " value {"
+ " type: DT_FLOAT"
+ " }"
+ " }"
+ " attr {"
+ " key: \"alpha\""
+ " value {"
+ " f: ";
+ m_Prototext.append(alphaString);
+ m_Prototext.append("\n"
+ " }"
+ " }"
+ " attr {"
+ " key: \"beta\""
+ " value {"
+ " f: ");
+ m_Prototext.append(betaString);
+ m_Prototext.append("\n"
+ " }"
+ " }"
+ " attr {"
+ " key: \"bias\""
+ " value {"
+ " f: ");
+ m_Prototext.append(biasString);
+ m_Prototext.append("\n"
+ " }"
+ " }"
+ " attr {"
+ " key: \"depth_radius\""
+ " value {"
+ " i: 1"
+ " }"
+ " }"
+ "}");
+ }
+};
+
+
+struct LocalResponseNormalizationFixtureSimple : public LocalResponseNormalizationBaseFixture
+{
+ explicit LocalResponseNormalizationFixtureSimple()
+ : LocalResponseNormalizationBaseFixture(1.0f, 1.0f, 1.0f)
+ {
+ SetupSingleInputSingleOutput({ 2, 2, 2, 1 }, "Placeholder", "LRN");
+ }
+};
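+// Rough check, assuming TensorFlow's LRN formula
+// output = input / (bias + alpha * sum(input^2 over the depth window))^beta: the simple fixture has a single
+// channel and alpha = beta = bias = 1, so each output is x / (1 + x^2), e.g. 1 -> 0.5, 2 -> 0.4, 3 -> 0.3,
+// 4 -> 4/17 ~= 0.2353.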
+BOOST_FIXTURE_TEST_CASE(ParseSimpleLocalResponseNormalization, LocalResponseNormalizationFixtureSimple)
+{
+ RunTest<4>({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
+ { 0.5f, 0.4f, 0.3f, 0.23529412f, 0.1923077f, 0.16216217f, 0.14f, 0.12307692f });
+}
+
+
+struct LocalResponseNormalizationFixture : public LocalResponseNormalizationBaseFixture
+{
+ explicit LocalResponseNormalizationFixture()
+ : LocalResponseNormalizationBaseFixture(0.5f, 1.0f, 0.5f)
+ {
+ SetupSingleInputSingleOutput({1, 3, 3, 2}, "Placeholder", "LRN");
+ }
+};
+BOOST_FIXTURE_TEST_CASE(ParseLocalResponseNormalization, LocalResponseNormalizationFixture)
+{
+ RunTest<4>({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
+ 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
+ 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f},
+
+ {0.333333340f, 0.66666670f, 0.230769250f, 0.307692320f, 0.161290320f, 0.19354838f,
+ 0.122807020f, 0.14035088f, 0.098901100f, 0.109890110f, 0.082706770f, 0.09022556f,
+ 0.071038246f, 0.07650273f, 0.062240668f, 0.066390045f, 0.055374593f, 0.05863192f});
+}
+
+
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/MultiOutput.cpp b/src/armnnTfParser/test/MultiOutput.cpp
new file mode 100644
index 0000000000..56be33dab7
--- /dev/null
+++ b/src/armnnTfParser/test/MultiOutput.cpp
@@ -0,0 +1,144 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct MultiOutMatchFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MultiOutMatchFixture()
+ {
+ m_Prototext = R"(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "softmax1"
+ op: "Softmax"
+ input: "input:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(MultiOutMatch, MultiOutMatchFixture)
+{
+    // Note that the point of this test is to verify that parsing the explicit output slot notation ("input:0")
+    // went well. Here we make sure the softmax really has been connected to the input layer.
+ RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 });
+}
+
+struct MultiOutFailFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MultiOutFailFixture()
+ {
+ m_Prototext = R"(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "softmax1"
+ op: "Softmax"
+ input: "input:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ BOOST_CHECK_THROW(SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1"), armnn::ParseException);
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(MultiOutFail, MultiOutFailFixture)
+{
+ // Not running the graph because this is expected to throw an exception during parsing.
+}
+
+struct MultiOutInvalidFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MultiOutInvalidFixture()
+ {
+ m_Prototext = R"(
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "softmax1"
+ op: "Softmax"
+ input: "input:-1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ BOOST_CHECK_THROW(SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1"), armnn::ParseException);
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(MultiOutInvalid, MultiOutInvalidFixture)
+{
+ // Not running the graph because this is expected to throw an exception during parsing.
+}
+
+
+BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file
diff --git a/src/armnnTfParser/test/Multiplication.cpp b/src/armnnTfParser/test/Multiplication.cpp
new file mode 100644
index 0000000000..3a20fd1141
--- /dev/null
+++ b/src/armnnTfParser/test/Multiplication.cpp
@@ -0,0 +1,172 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct MultiplicationFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MultiplicationFixture()
+ {
+ m_Prototext = "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " node { \n"
+ " name: \"softmax1\" \n"
+ " op: \"Softmax\" \n"
+ " input: \"graphInput\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " }\n"
+ " node {\n"
+ " name: \"softmax2\"\n"
+ " op : \"Softmax\"\n"
+ " input: \"graphInput\"\n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " }\n"
+ " node {\n"
+ " name: \"multiplication\"\n"
+ " op : \"Mul\"\n"
+ " input: \"softmax1\"\n"
+ " input: \"softmax2\"\n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " }\n";
+
+ SetupSingleInputSingleOutput({ 1, 7 }, "graphInput", "multiplication");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseMultiplication, MultiplicationFixture)
+{
+ RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 });
+}
+
+struct MultiplicationBroadcastFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MultiplicationBroadcastFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1)
+ {
+ m_Prototext = R"(
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "input1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Mul"
+ input: "input0"
+ input: "input1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+
+ Setup({ { "input0", inputShape0 },
+ { "input1", inputShape1 } },
+ { "output" });
+ }
+};
+
+struct MultiplicationBroadcastFixture4D1D : public MultiplicationBroadcastFixture
+{
+ MultiplicationBroadcastFixture4D1D() : MultiplicationBroadcastFixture({ 1, 2, 2, 3 }, { 1 }) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast4D1D, MultiplicationBroadcastFixture4D1D)
+{
+ RunTest<4>({ { "input0", { 0.0f, 1.0f, 2.0f,
+ 3.0f, 4.0f, 5.0f,
+ 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f } },
+ { "input1", { 5.0f } } },
+ { { "output", { 0.0f, 5.0f, 10.0f,
+ 15.0f, 20.0f, 25.0f,
+ 30.0f, 35.0f, 40.0f,
+ 45.0f, 50.0f, 55.0f } } });
+}
+
+struct MultiplicationBroadcastFixture1D4D : public MultiplicationBroadcastFixture
+{
+ MultiplicationBroadcastFixture1D4D() : MultiplicationBroadcastFixture({ 1 }, { 1, 2, 2, 3 }) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast1D4D, MultiplicationBroadcastFixture1D4D)
+{
+ RunTest<4>({ { "input0", { 3.0f } },
+ { "input1", { 0.0f, 1.0f, 2.0f,
+ 3.0f, 4.0f, 5.0f,
+ 6.0f, 7.0f, 8.0f,
+ 9.0f, 10.0f, 11.0f } } },
+ { { "output", { 0.0f, 3.0f, 6.0f,
+ 9.0f, 12.0f, 15.0f,
+ 18.0f, 21.0f, 24.0f,
+ 27.0f, 30.0f, 33.0f } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/PassThru.cpp b/src/armnnTfParser/test/PassThru.cpp
new file mode 100644
index 0000000000..8462ec27cc
--- /dev/null
+++ b/src/armnnTfParser/test/PassThru.cpp
@@ -0,0 +1,52 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct PassThruFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ PassThruFixture()
+ {
+ m_Prototext = "node {\n"
+ " name: \"Placeholder\"\n"
+ " op: \"Placeholder\"\n"
+ " attr {\n"
+ " key: \"dtype\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"shape\"\n"
+ " value {\n"
+ " shape {\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n";
+ SetupSingleInputSingleOutput({ 1, 7 }, "Placeholder", "Placeholder");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ValidateOutput, PassThruFixture)
+{
+ BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetNumDimensions() == 2);
+ BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetShape()[0] == 1);
+ BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetShape()[1] == 7);
+}
+
+BOOST_FIXTURE_TEST_CASE(RunGraph, PassThruFixture)
+{
+ armnn::TensorInfo inputTensorInfo = m_Parser->GetNetworkInputBindingInfo("Placeholder").second;
+ auto input = MakeRandomTensor<float, 2>(inputTensorInfo, 378346);
+ std::vector<float> inputVec;
+ inputVec.assign(input.data(), input.data() + input.num_elements());
+ RunTest<2>(inputVec, inputVec); // The passthru network should output the same as the input
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Pooling.cpp b/src/armnnTfParser/test/Pooling.cpp
new file mode 100644
index 0000000000..36ffa47def
--- /dev/null
+++ b/src/armnnTfParser/test/Pooling.cpp
@@ -0,0 +1,112 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+
+struct Pooling2dFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit Pooling2dFixture(const char* poolingtype)
+ {
+ m_Prototext = "node {\n"
+ " name: \"Placeholder\"\n"
+ " op: \"Placeholder\"\n"
+ " attr {\n"
+ " key: \"dtype\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"value\"\n"
+ " value {\n"
+ " tensor {\n"
+ " dtype: DT_FLOAT\n"
+ " tensor_shape {\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "node {\n"
+ " name: \"";
+ m_Prototext.append(poolingtype);
+ m_Prototext.append("\"\n"
+ " op: \"");
+ m_Prototext.append(poolingtype);
+ m_Prototext.append("\"\n"
+ " input: \"Placeholder\"\n"
+ " attr {\n"
+ " key: \"T\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"data_format\"\n"
+ " value {\n"
+ " s: \"NHWC\"\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"ksize\"\n"
+ " value {\n"
+ " list {\n"
+ " i: 1\n"
+ " i: 2\n"
+ " i: 2\n"
+ " i: 1\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"padding\"\n"
+ " value {\n"
+ " s: \"VALID\"\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"strides\"\n"
+ " value {\n"
+ " list {\n"
+ " i: 1\n"
+ " i: 1\n"
+ " i: 1\n"
+ " i: 1\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n");
+
+ SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype);
+ }
+};
+
+
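+// Both fixtures below pool the whole { 1, 2, 2, 1 } input with a single 2x2 VALID window, so MaxPool returns
+// max(1, 2, 3, -4) = 3 and AvgPool returns (1 + 2 + 3 + 4) / 4 = 2.5.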
+struct MaxPoolFixture : Pooling2dFixture
+{
+ MaxPoolFixture() : Pooling2dFixture("MaxPool") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseMaxPool, MaxPoolFixture)
+{
+ RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f});
+}
+
+
+struct AvgPoolFixture : Pooling2dFixture
+{
+ AvgPoolFixture() : Pooling2dFixture("AvgPool") {}
+};
+BOOST_FIXTURE_TEST_CASE(ParseAvgPool, AvgPoolFixture)
+{
+ RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f});
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Reshape.cpp b/src/armnnTfParser/test/Reshape.cpp
new file mode 100644
index 0000000000..4eb6b12467
--- /dev/null
+++ b/src/armnnTfParser/test/Reshape.cpp
@@ -0,0 +1,86 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+
+struct ReshapeFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ReshapeFixture()
+ {
+ m_Prototext = "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "node { \n"
+ " name: \"Reshape/shape\" \n"
+ " op: \"Const\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_INT32 \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"value\" \n"
+ " value { \n"
+ " tensor { \n"
+ " dtype: DT_INT32 \n"
+ " tensor_shape { \n"
+ " dim { \n"
+ " size: 2 \n"
+ " } \n"
+ " } \n"
+ " tensor_content: \"\\002\\000\\000\\000\\002\\000\\000\\000\" \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"Reshape\" \n"
+ " op: \"Reshape\" \n"
+ " input: \"graphInput\" \n"
+ " input: \"Reshape/shape\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"Tshape\" \n"
+ " value { \n"
+ " type: DT_INT32 \n"
+ " } \n"
+ " } \n"
+ "} \n";
+
+ SetupSingleInputSingleOutput({1, 4}, "graphInput", "Reshape");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseReshape, ReshapeFixture)
+{
+ RunTest<2>({ 0.0f, 1.0f, 2.0f, 3.0f }, { 0.0f, 1.0f, 2.0f, 3.0f });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/ResizeBilinear.cpp b/src/armnnTfParser/test/ResizeBilinear.cpp
new file mode 100644
index 0000000000..30d898f5bb
--- /dev/null
+++ b/src/armnnTfParser/test/ResizeBilinear.cpp
@@ -0,0 +1,114 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct ResizeBilinearFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ResizeBilinearFixture()
+ {
+ m_Prototext = R"(
+node {
+ name: "graphInput"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 3
+ }
+ dim {
+ size: 1
+ }
+ }
+ tensor_content:
+"\000\000\000\000\000\000\200?\000\000\000@\000\000@@\000\000\200@\000\000\240@\000\000\300@\000\000\340@\000\000\000A"
+ }
+ }
+ }
+}
+node {
+ name: "resizeBilinearLayer/size"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\005\000\000\000\005\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "resizeBilinearLayer"
+ op: "ResizeBilinear"
+ input: "graphInput"
+ input: "resizeBilinearLayer/size"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "align_corners"
+ value {
+ b: false
+ }
+ }
+}
+ )";
+
+ SetupSingleInputSingleOutput({ 1, 3, 3, 1 }, "graphInput", "resizeBilinearLayer");
+ }
+};
+
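+// Rough check, assuming the scaling used when align_corners is false: resizing 3x3 to 5x5 gives a scale of
+// 3/5 = 0.6, so output (y, x) samples the input at (0.6*y, 0.6*x) with bilinear interpolation. For example
+// output (0, 1) samples x = 0.6, i.e. 0.4 * input[0][0] + 0.6 * input[0][1] = 0.6.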
+BOOST_FIXTURE_TEST_CASE(ParseResizeBilinear, ResizeBilinearFixture)
+{
+ RunTest<4>(// input data
+ { 0.0f, 1.0f, 2.0f,
+ 3.0f, 4.0f, 5.0f,
+ 6.0f, 7.0f, 8.0f },
+ // expected output data
+ { 0.0f, 0.6f, 1.2f, 1.8f, 2.0f,
+ 1.8f, 2.4f, 3.0f, 3.6f, 3.8f,
+ 3.6f, 4.2f, 4.8f, 5.4f, 5.6f,
+ 5.4f, 6.0f, 6.6f, 7.2f, 7.4f,
+ 6.0f, 6.6f, 7.2f, 7.8f, 8.0f });
+
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Shape.cpp b/src/armnnTfParser/test/Shape.cpp
new file mode 100644
index 0000000000..7b414ecfac
--- /dev/null
+++ b/src/armnnTfParser/test/Shape.cpp
@@ -0,0 +1,94 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct ShapeFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ ShapeFixture()
+ {
+ m_Prototext =
+ "node { \n"
+ " name: \"Placeholder\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 1 \n"
+ " } \n"
+ " dim { \n"
+ " size: 4 \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"shapeTest\" \n"
+ " op: \"Shape\" \n"
+ " input: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"out_type\" \n"
+ " value { \n"
+ " type: DT_INT32 \n"
+ " } \n"
+ " } \n"
+ "} \n"
+ "node { \n"
+ " name: \"Reshape\" \n"
+ " op: \"Reshape\" \n"
+ " input: \"Placeholder\" \n"
+ " input: \"shapeTest\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"Tshape\" \n"
+ " value { \n"
+ " type: DT_INT32 \n"
+ " } \n"
+ " } \n"
+ "} \n";
+
+ SetupSingleInputSingleOutput({1, 4}, "Placeholder", "Reshape");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseShape, ShapeFixture)
+{
+    // Note: the test's output cannot be an int32 const layer, because such a layer cannot exist in the graph,
+    // as Arm NN only supports u8 and float layers. For that reason a Reshape layer was added which reshapes
+    // the input back to its original dimensions, leaving the data unchanged.
+ RunTest<2>({ 0.0f, 1.0f, 2.0f, 3.0f }, { 0.0f, 1.0f, 2.0f, 3.0f });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Softmax.cpp b/src/armnnTfParser/test/Softmax.cpp
new file mode 100644
index 0000000000..1ab28ea3aa
--- /dev/null
+++ b/src/armnnTfParser/test/Softmax.cpp
@@ -0,0 +1,55 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct SoftmaxFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ SoftmaxFixture()
+ {
+ m_Prototext = "node {\n"
+ " name: \"blah\"\n"
+ " op: \"Placeholder\"\n"
+ " attr {\n"
+ " key: \"dtype\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ " attr {\n"
+ " key: \"shape\"\n"
+ " value {\n"
+ " shape {\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n"
+ "node {\n"
+ " name: \"blah2\"\n"
+ " op: \"Softmax\"\n"
+ " input: \"blah\"\n"
+ " attr {\n"
+ " key: \"T\"\n"
+ " value {\n"
+ " type: DT_FLOAT\n"
+ " }\n"
+ " }\n"
+ "}\n";
+
+ SetupSingleInputSingleOutput({ 1, 7 }, "blah", "blah2");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseSoftmax, SoftmaxFixture)
+{
+ RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 });
+}
+
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/Squeeze.cpp b/src/armnnTfParser/test/Squeeze.cpp
new file mode 100644
index 0000000000..d2d7d49494
--- /dev/null
+++ b/src/armnnTfParser/test/Squeeze.cpp
@@ -0,0 +1,108 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+
+template <bool withDimZero, bool withDimOne>
+struct SqueezeFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ SqueezeFixture()
+ {
+ m_Prototext =
+ "node { \n"
+ " name: \"graphInput\" \n"
+ " op: \"Placeholder\" \n"
+ " attr { \n"
+ " key: \"dtype\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"shape\" \n"
+ " value { \n"
+ " shape { \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ " } \n"
+ "node { \n"
+ " name: \"Squeeze\" \n"
+ " op: \"Squeeze\" \n"
+ " input: \"graphInput\" \n"
+ " attr { \n"
+ " key: \"T\" \n"
+ " value { \n"
+ " type: DT_FLOAT \n"
+ " } \n"
+ " } \n"
+ " attr { \n"
+ " key: \"squeeze_dims\" \n"
+ " value { \n"
+ " list {\n";
+
+ if (withDimZero)
+ {
+ m_Prototext += "i:0\n";
+ }
+
+ if (withDimOne)
+ {
+ m_Prototext += "i:1\n";
+ }
+
+ m_Prototext +=
+ " } \n"
+ " } \n"
+ " } \n"
+ "} \n";
+
+ SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "graphInput", "Squeeze");
+ }
+};
+
+typedef SqueezeFixture<false, false> ImpliedDimensionsSqueezeFixture;
+typedef SqueezeFixture<true, false> ExplicitDimensionZeroSqueezeFixture;
+typedef SqueezeFixture<false, true> ExplicitDimensionOneSqueezeFixture;
+typedef SqueezeFixture<true, true> ExplicitDimensionsSqueezeFixture;
+
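+// Squeeze removes dimensions of size 1 from the { 1, 1, 2, 2 } input: with no squeeze_dims both singleton
+// dimensions are removed, giving { 2, 2 }; with an explicit list only the listed dimensions are removed, so
+// specifying just dimension 0 or just dimension 1 leaves { 1, 2, 2 }, while specifying both gives { 2, 2 }.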
+BOOST_FIXTURE_TEST_CASE(ParseImplicitSqueeze, ImpliedDimensionsSqueezeFixture)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
+ armnn::TensorShape({2,2})));
+ RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
+ { 1.0f, 2.0f, 3.0f, 4.0f });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseDimensionZeroSqueeze, ExplicitDimensionZeroSqueezeFixture)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
+ armnn::TensorShape({1,2,2})));
+ RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
+ { 1.0f, 2.0f, 3.0f, 4.0f });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseDimensionOneSqueeze, ExplicitDimensionOneSqueezeFixture)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
+ armnn::TensorShape({1,2,2})));
+ RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
+ { 1.0f, 2.0f, 3.0f, 4.0f });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseExplicitDimensionsSqueeze, ExplicitDimensionsSqueezeFixture)
+{
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
+ armnn::TensorShape({2,2})));
+ RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
+ { 1.0f, 2.0f, 3.0f, 4.0f });
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/TestDependencies.cpp b/src/armnnTfParser/test/TestDependencies.cpp
new file mode 100644
index 0000000000..13ab17c5b6
--- /dev/null
+++ b/src/armnnTfParser/test/TestDependencies.cpp
@@ -0,0 +1,296 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+// Graph which tests that nodes are re-ordered in the queue when they are encountered a second time.
+// In this case R0 will be encountered first via R1 and then via R2. At that time
+// we need to make sure that R0 (and the I on which it is dependent) is moved to the front again
+// so that it is before both R1 and R2.
+// I
+// |
+// R0
+// / \'
+// R1 R2
+// \ |
+// \ R3
+// \|
+// O
+struct RediscoveredDependenciesFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ RediscoveredDependenciesFixture()
+ {
+ // input = tf.placeholder(tf.float32, 1, "input")
+ // relu0 = tf.nn.relu(input, "relu0")
+ // relu1 = tf.nn.relu(relu0, "relu1")
+ // relu2 = tf.nn.relu(relu0, "relu2")
+ // relu3 = tf.nn.relu(relu2, "relu3")
+ // output = tf.add(relu1, relu3, "output")
+ m_Prototext = R"(
+ node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "relu0"
+ op: "Relu"
+ input: "input"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu1"
+ op: "Relu"
+ input: "relu0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu2"
+ op: "Relu"
+ input: "relu0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu3"
+ op: "Relu"
+ input: "relu2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "output"
+ op: "Add"
+ input: "relu1"
+ input: "relu3"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+ SetupSingleInputSingleOutput({ 1 }, "input", "output");
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture)
+{
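+ // With an input of 1.0f every Relu passes the value through unchanged, so the expected
+ // output is relu1 + relu3 = 1.0f + 1.0f = 2.0f.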
+ RunTest<1>({1}, {2});
+}
+
+// Tests that a simple cycle in the TensorFlow graph will be detected and an exception thrown, rather than
+// the TfParser getting stuck in an infinite loop.
+BOOST_AUTO_TEST_CASE(SimpleCycle)
+{
+ const char* prototext = R"(
+node {
+ name: "r1"
+ op: "Relu"
+ input: "r2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "r2"
+ op: "Relu"
+ input: "r1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException);
+}
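+
+// The kind of check exercised above could look roughly like this (for illustration only, not
+// TfParser's actual implementation; Node, GetInputs() and the path-set type are placeholders):
+// track the nodes on the current traversal path and report a cycle as soon as one of them is
+// reached again through its own inputs. TfParser reports this as an armnn::ParseException.
+namespace
+{
+template <typename Node, typename PathSet>
+bool HasCycle(const Node& node, PathSet& nodesOnPath)
+{
+    if (!nodesOnPath.insert(&node).second)
+    {
+        // This node is already on the path being walked, so following its inputs would recurse forever.
+        return true;
+    }
+    for (const Node* input : node.GetInputs())
+    {
+        if (HasCycle(*input, nodesOnPath))
+        {
+            return true;
+        }
+    }
+    nodesOnPath.erase(&node);
+    return false;
+}
+} // anonymous namespace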
+
+// Similar to the above SimpleCycle test, but with a single node that connects to itself.
+BOOST_AUTO_TEST_CASE(SingleNodeCycle)
+{
+ const char* prototext = R"(
+node {
+ name: "r1"
+ op: "Relu"
+ input: "r1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
+}
+
+// Similar to the above SimpleCycle test, but with a more complicated graph.
+// I
+// |
+// A2---<---<-
+// / \' |
+// R1 R2 |
+// \ | |
+// \ R3 |
+// \| |
+// A1-->--->|
+//
+BOOST_AUTO_TEST_CASE(ComplexCycle)
+{
+ // input = tf.placeholder(tf.float32, 1, "input")
+ // add2 = tf.add(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined
+ // relu1 = tf.nn.relu(add2, "relu1")
+ // relu2 = tf.nn.relu(add2, "relu2")
+ // relu3 = tf.nn.relu(relu2, "relu3")
+ // add1 = tf.add(relu1, relu3, "add1")
+ const char* prototext = R"(
+ node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "add2"
+ op: "Add"
+ input: "input"
+ input: "add1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu1"
+ op: "Relu"
+ input: "add2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu2"
+ op: "Relu"
+ input: "add2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "relu3"
+ op: "Relu"
+ input: "relu2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "add1"
+ op: "Add"
+ input: "relu1"
+ input: "relu3"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException);
+}
+
+// Tests that a graph which references an input node that does not exist throws a ParseException.
+BOOST_AUTO_TEST_CASE(InvalidInput)
+{
+ const char* prototext = R"(
+node {
+ name: "r1"
+ op: "Relu"
+ input: "a-node-that-does-not-exist"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
+ BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnTfParser/test/TestMultiInputsOutputs.cpp b/src/armnnTfParser/test/TestMultiInputsOutputs.cpp
new file mode 100644
index 0000000000..5eea616ec8
--- /dev/null
+++ b/src/armnnTfParser/test/TestMultiInputsOutputs.cpp
@@ -0,0 +1,92 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct MultiInputsOutputsFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ MultiInputsOutputsFixture()
+ {
+ // input1 = tf.placeholder(tf.float32, shape=[], name = "input1")
+ // input2 = tf.placeholder(tf.float32, shape = [], name = "input2")
+ // add1 = tf.add(input1, input2, name = "add1")
+ // add2 = tf.add(input1, input2, name = "add2")
+ m_Prototext = R"(
+node {
+ name: "input1"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "input2"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+}
+node {
+ name: "add1"
+ op: "Add"
+ input: "input1"
+ input: "input2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "add2"
+ op: "Add"
+ input: "input1"
+ input: "input2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )";
+ Setup({ { "input1", { 1 } },
+ { "input2", { 1 } } },
+ { "add1", "add2" });
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(MultiInputsOutputs, MultiInputsOutputsFixture)
+{
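+ // add1 and add2 both compute input1 + input2, so both outputs are expected to be 12.0f + 13.0f = 25.0f.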
+ RunTest<1>({ { "input1", {12.0f} }, { "input2", { 13.0f } } },
+ { { "add1", { 25.0f } }, { "add2", { 25.0f } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()