From bceff2fb3fc68bb0aa88b886900c34b77340c826 Mon Sep 17 00:00:00 2001 From: surmeh01 Date: Thu, 29 Mar 2018 16:29:27 +0100 Subject: Release 18.03 --- src/armnnTfParser/README.md | 5 + src/armnnTfParser/TensorFlowSupport.md | 111 + src/armnnTfParser/TfParser.cpp | 2200 ++++++++++++++++++++ src/armnnTfParser/TfParser.hpp | 199 ++ src/armnnTfParser/test/Activations.cpp | 113 + src/armnnTfParser/test/Addition.cpp | 78 + src/armnnTfParser/test/BiasAdd.cpp | 104 + src/armnnTfParser/test/BroadcastForAdd.cpp | 149 ++ src/armnnTfParser/test/Concat.cpp | 183 ++ src/armnnTfParser/test/ConcatOfConcats.cpp | 316 +++ src/armnnTfParser/test/Constant.cpp | 321 +++ src/armnnTfParser/test/Convolution2d.cpp | 322 +++ src/armnnTfParser/test/DepthwiseConvolution2d.cpp | 166 ++ src/armnnTfParser/test/FullyConnected.cpp | 579 ++++++ src/armnnTfParser/test/FusedBatchNorm.cpp | 175 ++ src/armnnTfParser/test/Identity.cpp | 161 ++ .../test/LocalResponseNormalization.cpp | 121 ++ src/armnnTfParser/test/MultiOutput.cpp | 144 ++ src/armnnTfParser/test/Multiplication.cpp | 172 ++ src/armnnTfParser/test/PassThru.cpp | 52 + src/armnnTfParser/test/Pooling.cpp | 112 + src/armnnTfParser/test/Reshape.cpp | 86 + src/armnnTfParser/test/ResizeBilinear.cpp | 114 + src/armnnTfParser/test/Shape.cpp | 94 + src/armnnTfParser/test/Softmax.cpp | 55 + src/armnnTfParser/test/Squeeze.cpp | 108 + src/armnnTfParser/test/TestDependencies.cpp | 296 +++ src/armnnTfParser/test/TestMultiInputsOutputs.cpp | 92 + 28 files changed, 6628 insertions(+) create mode 100644 src/armnnTfParser/README.md create mode 100644 src/armnnTfParser/TensorFlowSupport.md create mode 100644 src/armnnTfParser/TfParser.cpp create mode 100644 src/armnnTfParser/TfParser.hpp create mode 100644 src/armnnTfParser/test/Activations.cpp create mode 100644 src/armnnTfParser/test/Addition.cpp create mode 100644 src/armnnTfParser/test/BiasAdd.cpp create mode 100644 src/armnnTfParser/test/BroadcastForAdd.cpp create mode 100644 src/armnnTfParser/test/Concat.cpp create mode 100644 src/armnnTfParser/test/ConcatOfConcats.cpp create mode 100644 src/armnnTfParser/test/Constant.cpp create mode 100644 src/armnnTfParser/test/Convolution2d.cpp create mode 100644 src/armnnTfParser/test/DepthwiseConvolution2d.cpp create mode 100644 src/armnnTfParser/test/FullyConnected.cpp create mode 100644 src/armnnTfParser/test/FusedBatchNorm.cpp create mode 100644 src/armnnTfParser/test/Identity.cpp create mode 100644 src/armnnTfParser/test/LocalResponseNormalization.cpp create mode 100644 src/armnnTfParser/test/MultiOutput.cpp create mode 100644 src/armnnTfParser/test/Multiplication.cpp create mode 100644 src/armnnTfParser/test/PassThru.cpp create mode 100644 src/armnnTfParser/test/Pooling.cpp create mode 100644 src/armnnTfParser/test/Reshape.cpp create mode 100644 src/armnnTfParser/test/ResizeBilinear.cpp create mode 100644 src/armnnTfParser/test/Shape.cpp create mode 100644 src/armnnTfParser/test/Softmax.cpp create mode 100644 src/armnnTfParser/test/Squeeze.cpp create mode 100644 src/armnnTfParser/test/TestDependencies.cpp create mode 100644 src/armnnTfParser/test/TestMultiInputsOutputs.cpp (limited to 'src/armnnTfParser') diff --git a/src/armnnTfParser/README.md b/src/armnnTfParser/README.md new file mode 100644 index 0000000000..fe3f2b8950 --- /dev/null +++ b/src/armnnTfParser/README.md @@ -0,0 +1,5 @@ +#The Arm NN TensorFlow parser + +`armnnTfParser` is a library for loading Neural Networks defined by TensorFlow protobuf files into the Arm NN runtime. + +For more information about the TensorFlow operators that are supported, and the networks that have been tested, see [TensorFlowSupport.md](./TensorFlowSupport.md) \ No newline at end of file diff --git a/src/armnnTfParser/TensorFlowSupport.md b/src/armnnTfParser/TensorFlowSupport.md new file mode 100644 index 0000000000..d052a70d49 --- /dev/null +++ b/src/armnnTfParser/TensorFlowSupport.md @@ -0,0 +1,111 @@ +#TensorFlow operators that the Arm NN SDK supports + +This reference guide provides a list of TensorFlow operators the Arm NN SDK currently supports. + +The Arm NN SDK TensorFlow parser currently only supports fp32 operators. + +These are the TensorFlow operators that the Arm NN SDK currently supports: + +**avg_pool** + +See the TensorFlow [avg_pool documentation](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool) for more information. + +**bias_add** + + See the TensorFlow [bias_add documentation](https://www.tensorflow.org/api_docs/python/tf/nn/bias_add) for more information. + +**conv2d** + + See the TensorFlow [conv2d documentation](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d) for more information. + +**identity** + +See the TensorFlow [identity documentation](https://www.tensorflow.org/api_docs/python/tf/identity) for more information. + +**local_response_normalization** + +See the TensorFlow [local_response_normalization documentation](https://www.tensorflow.org/api_docs/python/tf/nn/local_response_normalization) for more information. + +**max_pool** + +See the TensorFlow [max_pool documentation](https://www.tensorflow.org/api_docs/python/tf/nn/max_pool) for more information. + +**relu** + + See the TensorFlow [relu documentation](https://www.tensorflow.org/api_docs/python/tf/nn/relu) for more information. + +**relu6** + + See the TensorFlow [relu6 documentation](https://www.tensorflow.org/api_docs/python/tf/nn/relu6) for more information. + +**shape** + + See the TensorFlow [shape documentation](https://www.tensorflow.org/api_docs/python/tf/shape) for more information. + +**sigmoid** + + See the TensorFlow [sigmoid documentation](https://www.tensorflow.org/api_docs/python/tf/sigmoid) for more information. + +**softplus** + +See the TensorFlow [softplus documentation](https://www.tensorflow.org/api_docs/python/tf/nn/softplus) for more information. + +**squeeze** + +See the TensorFlow [squeeze documentation](https://www.tensorflow.org/api_docs/python/tf/squeeze) for more information. + +**tanh** + +See the TensorFlow [tanh documentation](https://www.tensorflow.org/api_docs/python/tf/tanh) for more information. + +The Arm NN SDK TensorFlow parser currently partially supports: + +**add** + +The parser does not support all forms of [broadcast composition](https://www.tensorflow.org/performance/xla/broadcasting), only broadcasting of scalars and 1D tensors. See the TensorFlow [add operator documentation](https://www.tensorflow.org/api_docs/python/tf/add) for more information. + +**depthwise_conv2D_native** + +The parser only supports a dilation rate of (1,1,1,1). See the TensorFlow [depthwise_conv2d_native documentation](https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d_native) for more information. + +**fused_batch_norm** + +The parser does not support training outputs. See the TensorFlow [fused_batch_norm documentation](https://www.tensorflow.org/api_docs/python/tf/nn/fused_batch_norm) for more information. + +**matmul** + +The parser only supports constant weights in a fully connected layer. See the TensorFlow [matmul documentation](https://www.tensorflow.org/api_docs/python/tf/matmul) for more information. + +**multiply** + +The parser does not support all forms of [broadcast composition](https://www.tensorflow.org/performance/xla/broadcasting), only broadcasting of scalars and 1D tensors. See the TensorFlow [multiply documentation](https://www.tensorflow.org/api_docs/python/tf/multiply) for more information. No broadcasting supported on the NEON backend. + +**placeholder** + + The parser only supports the NHWC data format in the input layer. See the TensorFlow [placeholder documentation](https://www.tensorflow.org/api_docs/python/tf/placeholder) for more information. + +**reshape** + +The parser does not support reshaping to or from 4D. See the TensorFlow [reshape documentation](https://www.tensorflow.org/api_docs/python/tf/reshape) for more information. + +**resize_images** + +The parser only supports `ResizeMethod.BILINEAR`. See the TensorFlow [resize_images documentation](https://www.tensorflow.org/api_docs/python/tf/image/resize_images) for more information. + +**softmax** + +The parser only supports 2D inputs and does not support selecting the `softmax` dimension. See the TensorFlow [softmax documentation](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) for more information. + + + +Arm tests these operators with the following TensorFlow fp32 neural networks: + +* Cifar10. + +* Lenet. + +* mobilenet_v1_1.0_224. The Arm NN SDK only supports the non*_quant version of the network. See the [MobileNet_v1 documentation](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md) for more information on _quant networks. + +* inception_v3. The Arm NN SDK only supports the official inception_v3 transformed model using the GPU acceleration only, but NEON acceleration is not supported at the moment. See the TensorFlow documentation on [preparing models for mobile deployment](https://www.tensorflow.org/mobile/prepare_models) for more information on how to transform the inception_v3 network. + +More machine learning operators will be supported in future releases. diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp new file mode 100644 index 0000000000..7c8e01b112 --- /dev/null +++ b/src/armnnTfParser/TfParser.cpp @@ -0,0 +1,2200 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#include "TfParser.hpp" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using namespace armnn; + +namespace armnnTfParser +{ +namespace +{ + +const PermutationVector NHWCToArmNN = { 0, 2, 3, 1 }; +const PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 }; + +IConnectableLayer* AddSwizzleLayer(INetwork& network, IOutputSlot& input, const PermutationVector& mapping, + const std::string& name) +{ + // Add swizzle layer + IConnectableLayer* const layer = network.AddPermuteLayer(mapping, name.c_str()); + + // Connect intput to swizzle layer + input.Connect(layer->GetInputSlot(0)); + + // Setup swizzled output + const TensorInfo outInfo = armnnUtils::Permuted(input.GetTensorInfo(), mapping); + layer->GetOutputSlot(0).SetTensorInfo(outInfo); + + return layer; +} + +IConnectableLayer* SwizzleInDeswizzleOut(INetwork& network, IOutputSlot& input, IConnectableLayer& layer, + const std::string& name) +{ + // Add swizzle layer + IConnectableLayer* const swizzleLayer = AddSwizzleLayer(network, input, NHWCToArmNN, "swizzle_for-" + name); + + // Connect swizzledInput to layer + swizzleLayer->GetOutputSlot(0).Connect(layer.GetInputSlot(0)); + + // Add deswizzle layer + IConnectableLayer* const deswizzleLayer = AddSwizzleLayer(network, layer.GetOutputSlot(0), ArmNNToNHWC, + "deswizzle_for-" + name); + + return deswizzleLayer; +} + +template +void ReadMandatoryNodeAttributeImpl(const tensorflow::NodeDef& nodeDef, + const std::string& attribName, + tensorflow::AttrValue::ValueCase expectedValueCase, + Callable callable) +{ + auto iter = nodeDef.attr().find(attribName); + if (iter != nodeDef.attr().end()) + { + const auto& attrValue = iter->second; + if (attrValue.value_case() == expectedValueCase) + { + callable(attrValue); + } + else + { + throw ParseException(boost::str(boost::format( + "Attribute %1% of node %2% expected to have %3% as tensorflow::AttrValue::ValueCase, " + "but found %4% instead") + % attribName + % nodeDef.name() + % static_cast(expectedValueCase) + % static_cast(attrValue.value_case()))); + } + } + else + { + throw ParseException(boost::str(boost::format("Could not find required attribute %1% in node %2%") + % attribName % nodeDef.name())); + } +} + +template +void ReadOptionalNodeAttributeImpl(const tensorflow::NodeDef& nodeDef, + const std::string& attribName, + tensorflow::AttrValue::ValueCase expectedValueCase, + Callable callable) +{ + auto iter = nodeDef.attr().find(attribName); + if (iter != nodeDef.attr().end()) + { + const auto& attrValue = iter->second; + if (attrValue.value_case() == expectedValueCase) + { + callable(attrValue); + } + else + { + throw ParseException(boost::str(boost::format( + "Attribute %1% of node %2% expected to have %3% as tensorflow::AttrValue::ValueCase, " + "but found %4% instead") + % attribName + % nodeDef.name() + % static_cast(expectedValueCase) + % static_cast(attrValue.value_case()))); + } + } +} + +float ReadMandatoryNodeFloatAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name) +{ + float attribValue = 0.0f; + ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kF, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = attrValue.f(); + }); + return attribValue; +} + +uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name) +{ + uint32_t attribValue = 0u; + ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kI, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = static_cast(attrValue.i()); + }); + return attribValue; +} + +std::string ReadMandatoryNodeStringAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name) +{ + std::string attribValue = ""; + ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kS, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = attrValue.s(); + }); + return attribValue; +} + +std::vector ReadMandatoryNodeUint32ListAttribute(const tensorflow::NodeDef& nodeDef, + const std::string& name) +{ + std::vector attriList; + ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kList, + [&attriList](const tensorflow::AttrValue& attrValue) + { + for (int attriNum = 0; attriNum < attrValue.list().i_size(); ++attriNum) + { + attriList.push_back(static_cast(attrValue.list().i(attriNum))); + } + }); + + return attriList; +} + +std::vector ReadOptionalNodeUint32ListAttribute(const tensorflow::NodeDef& nodeDef, + const std::string& name) +{ + std::vector attriList; + ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kList, + [&attriList](const tensorflow::AttrValue& attrValue) + { + for (int attriNum = 0; attriNum < attrValue.list().i_size(); ++attriNum) + { + attriList.push_back(static_cast(attrValue.list().i(attriNum))); + } + }); + + return attriList; +} + +bool ReadOptionalNodeBoolAttribute(const tensorflow::NodeDef& nodeDef, + const std::string& name, + bool defaultValue = false) +{ + bool attribValue = defaultValue; + ReadOptionalNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kB, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = attrValue.b(); + }); + return attribValue; +} + +tensorflow::DataType ReadMandatoryNodeTypeAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name) +{ + tensorflow::DataType attribValue = tensorflow::DT_INVALID; + ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kType, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = attrValue.type(); + }); + return attribValue; +} + +TensorInfo PrepareReshape(const TensorInfo& input, const std::vector& targetDims) +{ + std::vector outDims(targetDims.begin(), targetDims.end()); + const auto stretchDim = std::find(targetDims.begin(), targetDims.end(), -1); + + if (stretchDim != targetDims.end()) + { + if (std::find(std::next(stretchDim), targetDims.end(), -1) != targetDims.end()) + { + throw ParseException("At most one component of shape can be -1"); + } + + auto targetNumElements = boost::numeric_cast(std::accumulate(targetDims.begin(), targetDims.end(), + -1, std::multiplies())); + auto stretchIndex = static_cast(std::distance(targetDims.begin(), stretchDim)); + outDims[stretchIndex] = input.GetNumElements() / targetNumElements; + } + + TensorInfo reshapeInfo = input; + reshapeInfo.SetShape(TensorShape{ static_cast(outDims.size()), outDims.data() }); + + return reshapeInfo; +} + +// We need the input0Slot to guide the reshape for input1Slot +IOutputSlot* BroadcastForAddandMul(IOutputSlot* input0Slot, IOutputSlot* input1Slot, bool isNHWC, INetwork& m_Network, + const tensorflow::NodeDef& nodeDef) +{ + const TensorInfo& input1Info = input1Slot->GetTensorInfo(); + const TensorInfo inputTensorInfo = input0Slot->GetTensorInfo(); + const unsigned int matchDim = inputTensorInfo.GetNumDimensions() - (isNHWC ? 1 : 3); + std::array reshapedDimensions; + std::fill_n(reshapedDimensions.begin(), inputTensorInfo.GetNumDimensions(), 1); + reshapedDimensions[matchDim] = input1Info.GetShape()[0]; + + armnn::TensorInfo reshapedInfo = input1Info; + reshapedInfo.SetShape(TensorShape{ inputTensorInfo.GetNumDimensions(), reshapedDimensions.data() }); + + const std::string reshapeLayerName = "reshape_for-" + nodeDef.name(); + ReshapeDescriptor reshapeDesc; + reshapeDesc.m_TargetShape = reshapedInfo.GetShape(); + IConnectableLayer* const reshapeLayer = m_Network.AddReshapeLayer(reshapeDesc, reshapeLayerName.c_str()); + + input1Slot->Connect(reshapeLayer->GetInputSlot(0)); + reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedInfo); + + input1Slot = &reshapeLayer->GetOutputSlot(0); + + return input1Slot; +} + +OutputId ParseOutputId(const std::string & name) +{ + unsigned int outputNum = 0; + size_t colonPos = name.find_last_of(":"); + if (colonPos != std::string::npos) + { + int n = std::stoi(name.substr(colonPos+1)); + if (n<0 || n>100) + { + throw ParseException("Output tensor id is out of range for "+name); + } + outputNum = static_cast(n); + } + return OutputId(name.substr(0,colonPos),outputNum); +} + +} // namespace + +const std::map TfParser::ms_OperationNameToParsingFunctions = { + { "Const", &TfParser::ParseConst }, + { "Add", &TfParser::ParseAdd }, + { "BiasAdd", &TfParser::ParseBiasAdd }, + { "Identity", &TfParser::ParseIdentity }, + { "Conv2D", &TfParser::ParseConv2D }, + { "DepthwiseConv2dNative", &TfParser::ParseDepthwiseConv2D }, + { "FusedBatchNorm", &TfParser::ParseFusedBatchNorm }, + { "ConcatV2", &TfParser::ParseConcat }, + { "LRN", &TfParser::ParseLrn }, + { "MatMul", &TfParser::ParseMatMul }, + { "Mul", &TfParser::ParseMul }, + { "Placeholder", &TfParser::ParsePlaceholder }, + { "Relu", &TfParser::ParseRelu }, + { "Relu6", &TfParser::ParseRelu6 }, + { "Reshape", &TfParser::ParseReshape }, + { "ResizeBilinear", &TfParser::ParseResizeBilinear }, + { "Shape", &TfParser::ParseShape }, + { "Squeeze", &TfParser::ParseSqueeze }, + { "Sigmoid", &TfParser::ParseSigmoid }, + { "Softmax", &TfParser::ParseSoftmax }, + { "Softplus", &TfParser::ParseSoftplus }, + { "Tanh", &TfParser::ParseTanh }, + { "MaxPool", &TfParser::ParseMaxPool }, + { "AvgPool", &TfParser::ParseAvgPool }, +}; + +ITfParser* ITfParser::CreateRaw() +{ + return new TfParser(); +} + +ITfParserPtr ITfParser::Create() +{ + return ITfParserPtr(CreateRaw(), &ITfParser::Destroy); +} + +void ITfParser::Destroy(ITfParser* parser) +{ + delete parser; +} + +inline void CalculateSamePadding(uint32_t inputSize, uint32_t stride, + uint32_t filterSize, bool samePadding, + uint32_t* paddingFront, uint32_t* paddingBack) { + *paddingFront = 0; + *paddingBack = 0; + + if (samePadding) { + uint32_t outputSize = (inputSize + stride - 1) / stride; + uint32_t temp = (outputSize - 1) * stride + filterSize; + if (temp > inputSize) { + *paddingFront = (temp - inputSize) / 2; + *paddingBack = (temp - inputSize) - *paddingFront; + } + } +} + +void CalcPadding(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t& outPadHead, uint32_t& outPadTail, + bool samePadding) +{ + CalculateSamePadding(input, stride, kernel, samePadding, &outPadHead, &outPadTail); +} + +/// An Abstract base class which represents a single tensorflow operation (node) +/// that has been (potentially partially) converted to Armnn. +/// It may not yet have been fully converted into actual Armnn layers. +class ParsedTfOperation +{ +public: + ParsedTfOperation(TfParser* parser, const tensorflow::NodeDef& node) + : m_Parser(parser) + , m_Node(node) + { + } + + virtual ~ParsedTfOperation() {}; + + const tensorflow::NodeDef& GetNode() const { return m_Node; } + + /// Gets the ArmNN IOutputSlot corresponding to the given output index of the Tensorflow operation. + /// This may result in the creation of Armnn layers if this was deferred (e.g. see ParsedConstTfOperation). + virtual IOutputSlot& ResolveArmnnOutputSlot(unsigned int tfOutputIndex) = 0; + + /// If this operation is an Identity then this will follow return the 'parent' operation (recursively). + virtual ParsedTfOperation* ResolveIdentityOperations() + { + return this; + } + +protected: + TfParser* m_Parser; + const tensorflow::NodeDef& m_Node; +}; + +/// An ParsedTfOperation where the Armnn equivalent is a single layer, +/// with output slots that correspond directly to the Tf node outputs. +class SingleLayerParsedTfOperation : public ParsedTfOperation +{ +public: + SingleLayerParsedTfOperation(TfParser* parser, const tensorflow::NodeDef& node, IConnectableLayer* layer) + : ParsedTfOperation(parser, node) + , m_Layer(layer) + { + } + + IOutputSlot& ResolveArmnnOutputSlot(unsigned int tfOutputIndex) override + { + BOOST_ASSERT(m_Layer); + // Assume one-to-one mapping between Tf and armnn output slots. + unsigned int armnnOutputSlotIdx = tfOutputIndex; + if (armnnOutputSlotIdx >= m_Layer->GetNumOutputSlots()) + { + throw ParseException( + boost::str(boost::format("The requested output slot #%1% " + "for %2% does not exist") % armnnOutputSlotIdx % m_Layer->GetName())); + } + return m_Layer->GetOutputSlot(armnnOutputSlotIdx); + } + +protected: + IConnectableLayer* m_Layer; +}; + +/// A SingleLayerParsedTfOperation for deferred layer creation +class DeferredSingleLayerParsedTfOperation : public SingleLayerParsedTfOperation +{ +public: + DeferredSingleLayerParsedTfOperation(TfParser* parser, const tensorflow::NodeDef& node) + : SingleLayerParsedTfOperation(parser, node, nullptr) + { + } + + IOutputSlot& ResolveArmnnOutputSlot(unsigned int tfOutputIndex) override + { + if (!m_Layer) + { + CreateLayerDeferred(); + } + return SingleLayerParsedTfOperation::ResolveArmnnOutputSlot(tfOutputIndex); + } + +private: + virtual void CreateLayerDeferred() = 0; +}; + + +TfParser::TfParser() + : m_Network(nullptr, nullptr) +{ +} + + +const tensorflow::NodeDef* TfParser::ResolveIdentityNode(const tensorflow::NodeDef* nodeDef) +{ + if (nodeDef->op() != "Identity") + { + return nodeDef; + } + + if (nodeDef->input_size() != 1) + { + throw ParseException("Identity node does not have correct amount of inputs!"); + } + + auto it = m_NodesByName.find(nodeDef->input(0)); + if (it != m_NodesByName.end()) + { + const tensorflow::NodeDef* inputNode = it->second; + return ResolveIdentityNode(inputNode); + } + else + { + throw ParseException("Cannot find what the Identity node is linked to!"); + } +} + +std::vector +TfParser::GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const +{ + std::vector ret; + + ret.reserve(boost::numeric_cast(nodeDef.input_size())); + for (int j = 0; j < nodeDef.input_size(); ++j) + { + OutputId outputId = ParseOutputId(nodeDef.input(j)); + auto inputIt = m_NodesByName.find(outputId.m_IndexedValue); + if (inputIt == m_NodesByName.end()) + { + throw ParseException( + "Can't find node '" + nodeDef.input(j) + + "', which is listed as an input of '" + nodeDef.name() + "'"); + } + ret.push_back(OutputOfConstNodeDef(inputIt->second,outputId.m_Index)); + } + + return ret; +} + +std::vector +TfParser::GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef, + std::size_t expectedNumInputs) +{ + // Fetch the tensorflow nodes connected as inputs and validate the size. + std::vector nodes = GetTfInputNodes(nodeDef); + const std::size_t numInputs = nodes.size(); + if (numInputs != expectedNumInputs) + { + throw ParseException(boost::str(boost::format("Unexpected number of inputs for node %1%. " + "Expected %2%, found %3%") % nodeDef.name() % expectedNumInputs % numInputs)); + } + // Fetch the corresponding ParsedTfOperation operations + std::vector result; + for (auto&& node : nodes) + { + auto it = m_ParsedTfOperations.find(node.m_IndexedValue->name()); + if (it == m_ParsedTfOperations.end()) + { + throw ParseException("Node with name '" + node.m_IndexedValue->name() + "' has not been parsed"); + } + ParsedTfOperation* parsedOp = it->second.get(); + // Transparently 'skip' any Identity operations. This simplifies the logic inside the ParseXXX() functions. + parsedOp = parsedOp->ResolveIdentityOperations(); + result.push_back(OutputOfParsedTfOperation(parsedOp,node.m_Index)); + } + return result; +} + +ParsedTfOperationPtr TfParser::ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + + // If one of the inputs is a MatMul and the other is a const, then we handle both nodes together as FullyConnected + if (inputs[0].m_IndexedValue->GetNode().op() == "MatMul" && + HasParsedConstTensor(inputs[1].m_IndexedValue->GetNode().name())) + { + IConnectableLayer* layer = + AddFullyConnectedLayer(inputs[0].m_IndexedValue->GetNode(), + &nodeDef,nodeDef.name().c_str()); + return std::make_unique(this, nodeDef, layer); + } + else if (HasParsedConstTensor(inputs[0].m_IndexedValue->GetNode().name()) && + inputs[1].m_IndexedValue->GetNode().op() == "MatMul") + { + IConnectableLayer* layer = + AddFullyConnectedLayer(inputs[1].m_IndexedValue->GetNode(), + &nodeDef,nodeDef.name().c_str()); + return std::make_unique(this, nodeDef, layer); + } + else + { + // Otherwise it's just a regular addition + return AddAdditionLayer(nodeDef); + } +} + +ParsedTfOperationPtr TfParser::ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + return AddAdditionLayer(nodeDef, true); +} + +/// An ParsedTfOperation which forwards to another (used for Identity nodes). +class ParsedIdentityTfOperation : public ParsedTfOperation +{ +public: + ParsedIdentityTfOperation(TfParser* parser, const tensorflow::NodeDef& node, ParsedTfOperation* representative) + : ParsedTfOperation(parser, node) + , m_Representative(representative) + { + } + + virtual IOutputSlot& ResolveArmnnOutputSlot(unsigned int tfOutputIndex) override + { + BOOST_ASSERT(m_Representative); + return m_Representative->ResolveArmnnOutputSlot(tfOutputIndex); + } + + virtual ParsedTfOperation* ResolveIdentityOperations() override + { + return m_Representative->ResolveIdentityOperations(); + } + +private: + ParsedTfOperation* m_Representative; +}; + +ParsedTfOperationPtr TfParser::ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + // Any requests for the output slots of this node should be forwarded to the node connected as input. + return std::make_unique(this, nodeDef, inputs[0].m_IndexedValue); +} + +/// An ParsedTfOperation for a Const node. +/// Creation of the armnn ConstLayer is deferred until it is actually needed, because Const nodes are mostly used +/// for weight inputs to MatMul/Conv2D nodes and in these cases armnn doesn't need a ConstLayer. +template +class ParsedConstTfOperation : public DeferredSingleLayerParsedTfOperation +{ +public: + ParsedConstTfOperation(TfParser* parser, const tensorflow::NodeDef& node, + const T* tensorData, const TensorInfo& tensorInfo) + : DeferredSingleLayerParsedTfOperation(parser, node), + m_Storage(tensorData, tensorData + tensorInfo.GetNumElements()), + m_TensorInfo(tensorInfo) + { + BOOST_ASSERT(tensorInfo.GetDataType() == GetDataType()); + } + + void CreateLayerDeferred() override + { + BOOST_ASSERT(m_Layer == nullptr); + m_Layer = m_Parser->m_Network->AddConstantLayer(ConstTensor(m_TensorInfo, m_Storage), m_Node.name().c_str()); + m_Layer->GetOutputSlot(0).SetTensorInfo(m_TensorInfo); + } + + ConstTensor GetConstTensor(bool swizzleForConvolutionWeights, std::vector& outputTensorData) const + { + // Mappings from TensorFlow filter tensors to the ArmNN filter tensors. + // Tensorflow weights are [H, W, In, Out] + // ArmNN weights are [Out, In, H, W] + static const PermutationVector HWIOToOIHW = {2, 3, 1, 0}; + + const TensorInfo outInfo = swizzleForConvolutionWeights + ? armnnUtils::Permuted(m_TensorInfo, HWIOToOIHW) + : m_TensorInfo; + + outputTensorData.resize(m_TensorInfo.GetNumElements()); + + // Copy or swizzle from the permanent storage into the storage the caller provided. + if (swizzleForConvolutionWeights) + { + armnnUtils::Permute(outInfo.GetShape(), HWIOToOIHW, m_Storage.data(), outputTensorData.data()); + } + else + { + memcpy(outputTensorData.data(), m_Storage.data(), m_TensorInfo.GetNumBytes()); + } + // Update the result to point to the user provided storage + ConstTensor constTensor(outInfo, outputTensorData); + return constTensor; + } + +private: + ///< Manages the lifetime of the tensor data. + std::vector m_Storage; + ///< Describes the layout of the tensor and points to the data in m_Storage. + TensorInfo m_TensorInfo; +}; + +DataType ConvertTfTensorDataType(const tensorflow::DataType tfDataType) +{ + switch (tfDataType) + { + case tensorflow::DT_FLOAT: + return DataType::Float32; + break; + case tensorflow::DT_INT32: + return DataType::Signed32; + break; + default: + throw ParseException(boost::str( + boost::format("Unknown DataType %1% for node") + % tensorflow::DataType_Name(tfDataType))); + } +} + +struct ParseTfTensorValueList +{ + template + static void Parse( + const tensorflow::TensorProto& tfTensor, + unsigned int dstElements, + std::vector& outputData); + + template + static void ReadData(const void* srcData, unsigned int numSrcElements, + std::vector& dstData, unsigned int numDstElements) + { + // If there are no entries in the list, perform no action + if (numSrcElements == 0) + { + return; + } + + // If no size was provided, use the length of the value list + if (numDstElements == 0) + { + numDstElements = numSrcElements; + } + + // Allocate memory + dstData.resize(std::max(numSrcElements, numDstElements) * sizeof(DataType)); + + const DataType* srcTensor = reinterpret_cast(srcData); + DataType* dstTensor = reinterpret_cast(dstData.data()); + + // Copy the value list entries into the destination + std::copy(srcTensor, srcTensor + numSrcElements, dstTensor); + + if (numDstElements > numSrcElements) + { + // Use the last element in the list to fill the remaining entries + std::fill(dstTensor + numSrcElements, dstTensor + numDstElements, srcTensor[numSrcElements - 1]); + } + } + +}; + +template <> +void ParseTfTensorValueList::Parse(const tensorflow::TensorProto& tfTensor, + unsigned int dstElements, std::vector& outputData) +{ + ReadData(tfTensor.float_val().data(), static_cast(tfTensor.float_val_size()), + outputData, dstElements); +} + +template <> +void ParseTfTensorValueList::Parse(const tensorflow::TensorProto& tfTensor, + unsigned int dstElements, std::vector& outputData) +{ + ReadData(tfTensor.int_val().data(), static_cast(tfTensor.int_val_size()), + outputData, dstElements); +} + +template class OperatorType, typename T = int8_t> +struct MakeTfOperation +{ + template + inline static std::unique_ptr> Parse(TfParser* parser, const tensorflow::NodeDef& node, + Args&&... args) + { + return std::make_unique>(parser, node, std::forward(args)...); + } +}; + +template <> +struct MakeTfOperation +{ + template + inline static std::unique_ptr> Parse(TfParser* parser, + const tensorflow::NodeDef& node, const std::vector& tensorData, const TensorInfo& tensorInfo) + { + return std::make_unique>(parser, node, + reinterpret_cast(tensorData.data()), tensorInfo); + } +}; + +template +struct InvokeParseFunction +{ + template + inline static ResType Result(DataType dataType, Args&&... args) + { + if (dataType == DataType::Float32) + { + return FuncType::template Parse(std::forward(args)...); + } + else if (dataType == DataType::Signed32) + { + return FuncType::template Parse(std::forward(args)...); + } + + return ResType(); + } + + template + inline static void Result(DataType dataType, Args&&... args) + { + if (dataType == DataType::Float32) + { + FuncType::template Parse(std::forward(args)...); + } + else if (dataType == DataType::Signed32) + { + FuncType::template Parse(std::forward(args)...); + } + } +}; + +ParsedTfOperationPtr TfParser::ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + BOOST_ASSERT(nodeDef.op() == "Const"); + + if (nodeDef.attr().count("value") == 0) + { + throw ParseException(boost::str( + boost::format("Value not found for Const node - %1%") + % nodeDef.name())); + } + + const tensorflow::TensorProto& tfTensor = nodeDef.attr().at("value").tensor(); + const tensorflow::TensorShapeProto& tfTensorShape = tfTensor.tensor_shape(); + const tensorflow::DataType tfDataType = ReadMandatoryNodeTypeAttribute(nodeDef, "dtype"); + + const auto GetDimensionSize = [](auto& d) { return d.size(); }; + + std::vector dimensionSizes; + std::transform(tfTensorShape.dim().begin(), tfTensorShape.dim().end(), + std::back_inserter(dimensionSizes), GetDimensionSize); + + // Calculate number of elements + const DataType dataType = ConvertTfTensorDataType(tfDataType); + unsigned int numElements = 0U; + + if (!dimensionSizes.empty()) + { + numElements = std::accumulate(dimensionSizes.begin(), dimensionSizes.end(), + 1U, std::multiplies()); + } + + std::vector tensorData; + + // Get tensor data from the list of values attribute + if (tfTensor.tensor_content().empty()) + { + InvokeParseFunction::Result(dataType, tfTensor, numElements, tensorData); + + // If the tensor shape is not defined, but there is a value list, then interpret the data as a 1D + // tensor of the provided number of elements + if (numElements == 0) + { + const unsigned int tfNumElements = static_cast(tensorData.size()) / GetDataTypeSize(dataType); + dimensionSizes.push_back(tfNumElements); + } + } + // Get tensor data from tensor content attribute + else + { + tensorData.assign(tfTensor.tensor_content().begin(), tfTensor.tensor_content().end()); + + // Check if a tensor shape is defined for the tensor content + if (numElements == 0) + { + throw ParseException(boost::str( + boost::format("No tensor shape found for Const node - %1%") + % nodeDef.name())); + } + } + + // Const node requires at least a list of values or a content attribute + if (tensorData.empty()) + { + throw ParseException(boost::str( + boost::format("No tensor data found for Const node - %1%") + % nodeDef.name())); + } + + const TensorInfo tensorInfo(static_cast(dimensionSizes.size()), dimensionSizes.data(), dataType); + + // If we have a list of values, then the length of the list must be + // less than or equal to the number of elements implied by the shape argument + if (tensorData.size() > tensorInfo.GetNumBytes()) + { + throw ParseException(boost::str( + boost::format("Number of elements (%1%) should be less than or equal \ + to the number of elements implied by the shape argument (%2%) for Const node - %3%") + % (tensorData.size() / GetDataTypeSize(dataType)) + % tensorInfo.GetNumElements() + % nodeDef.name())); + } + + return InvokeParseFunction>::Result( + dataType, this, nodeDef, tensorData, tensorInfo); +} + +template +bool TfParser::HasParsedConstTensor(const std::string & nodeName) const +{ + auto it = m_ParsedTfOperations.find(nodeName); + if (it == m_ParsedTfOperations.end() || + dynamic_cast*>(it->second.get()) == nullptr) + { + return false; + } + else + { + return true; + } +} + +ParsedTfOperationPtr TfParser::ParseConv2D(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + + if (!HasParsedConstTensor(inputs[1].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports Convolution layers with constant weights"); + } + ParsedConstTfOperation* weightNode = + boost::polymorphic_downcast *>(inputs[1].m_IndexedValue); + + std::string paddingString = ReadMandatoryNodeStringAttribute(nodeDef, "padding"); + std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); + std::vector strides = ReadMandatoryNodeUint32ListAttribute(nodeDef, "strides"); + + // read the dilations, if present - only [1,1,1,1] (the default) is supported + std::vector dilations = ReadOptionalNodeUint32ListAttribute(nodeDef, "dilations"); + if (!dilations.empty()) + { + for (auto dilation : dilations) + { + if (dilation != 1u) + { + throw ParseException("ArmNN only supports Convolution layers with dilations [1,1,1,1]"); + } + } + } + + Convolution2dDescriptor desc; + desc.m_BiasEnabled = false; + + if (dataFormat == "NHWC") + { + desc.m_StrideX = strides[2]; + desc.m_StrideY = strides[1]; + // Swizzle input to supported memory layout + inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); + } + else if (dataFormat == "NCHW") + { + desc.m_StrideX = strides[3]; + desc.m_StrideY = strides[2]; + } + else + { + throw ParseException("Unsupported data format passed for Conv2D. Only NHWC and NCHW supported"); + } + + uint32_t inputHeight = inputTensorInfo.GetShape()[2]; + uint32_t inputWidth = inputTensorInfo.GetShape()[3]; + + std::vector outputTensorData; + + ConstTensor weightTensor = weightNode->GetConstTensor(true, outputTensorData); + + uint32_t weightHeight = weightTensor.GetShape()[2]; + uint32_t weightWidth = weightTensor.GetShape()[3]; + + bool padding = false; + TensorInfo outputInfo; + if (paddingString == "SAME") + { + padding = true; + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + weightTensor.GetShape()[0], + static_cast(ceil( + static_cast(inputHeight) / + static_cast(desc.m_StrideY))), + static_cast(ceil( + static_cast(inputWidth) / + static_cast(desc.m_StrideX))) + }, DataType::Float32); + } + else if (paddingString == "VALID") + { + padding = false; + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + weightTensor.GetShape()[0], + static_cast(ceil( + static_cast(inputHeight - weightHeight + 1) / + static_cast(desc.m_StrideY))), + static_cast(ceil( + static_cast(inputWidth - weightWidth + 1) / + static_cast(desc.m_StrideX))) + }, DataType::Float32); + } + else + { + throw ParseException("Only 'SAME' and 'VALID' padding supported"); + } + + CalcPadding(inputHeight, weightHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, padding); + CalcPadding(inputWidth, weightWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, padding); + + IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc, weightTensor, nodeDef.name().c_str()); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + if (dataFormat == "NHWC") + { + layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); + } + else + { + inputSlot.Connect(layer->GetInputSlot(0)); + } + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + + if (!HasParsedConstTensor(inputs[1].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports Depthwise Convolution layers with constant weights"); + } + ParsedConstTfOperation* weightNode = + boost::polymorphic_downcast *>(inputs[1].m_IndexedValue); + + + std::string paddingString = ReadMandatoryNodeStringAttribute(nodeDef, "padding"); + std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); + std::vector strides = ReadMandatoryNodeUint32ListAttribute(nodeDef, "strides"); + + DepthwiseConvolution2dDescriptor desc; + desc.m_BiasEnabled = false; + + if (dataFormat == "NHWC") + { + desc.m_StrideX = strides[2]; + desc.m_StrideY = strides[1]; + // Swizzle input to supported memory layout + inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); + } + else if (dataFormat == "NCHW") + { + desc.m_StrideX = strides[3]; + desc.m_StrideY = strides[2]; + } + else + { + throw ParseException("Unsupported data format passed for DepthwiseConv2dNative. Only NHWC and NCHW supported"); + } + + uint32_t inputHeight = inputTensorInfo.GetShape()[2]; + uint32_t inputWidth = inputTensorInfo.GetShape()[3]; + + std::vector outputTensorData; + + ConstTensor weightTensor = weightNode->GetConstTensor(true, outputTensorData); + + uint32_t weightHeight = weightTensor.GetShape()[2]; + uint32_t weightWidth = weightTensor.GetShape()[3]; + + bool padding = false; + TensorInfo outputInfo; + if (paddingString == "SAME") + { + padding = true; + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + weightTensor.GetShape()[0] * weightTensor.GetShape()[1], + static_cast(ceil( + static_cast(inputHeight) / + static_cast(desc.m_StrideY))), + static_cast(ceil( + static_cast(inputWidth) / + static_cast(desc.m_StrideX))) + }, DataType::Float32); + } + else if (paddingString == "VALID") + { + padding = false; + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + weightTensor.GetShape()[0] * weightTensor.GetShape()[1], + static_cast(ceil( + static_cast(inputHeight - weightHeight + 1) / + static_cast(desc.m_StrideY))), + static_cast(ceil( + static_cast(inputWidth - weightWidth + 1) / + static_cast(desc.m_StrideX))) + }, DataType::Float32); + } + else + { + throw ParseException("Only 'SAME' and 'VALID' padding supported"); + } + + CalcPadding(inputHeight, weightHeight, desc.m_StrideY, desc.m_PadTop, desc.m_PadBottom, padding); + CalcPadding(inputWidth, weightWidth, desc.m_StrideX, desc.m_PadLeft, desc.m_PadRight, padding); + + IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(desc, weightTensor, nodeDef.name().c_str()); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + if (dataFormat == "NHWC") + { + layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); + } + else + { + inputSlot.Connect(layer->GetInputSlot(0)); + } + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 5); + + if (!HasParsedConstTensor(inputs[1].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports FusedBatchNormalization layers with constant scale"); + } + ParsedConstTfOperation* scaleNode = + boost::polymorphic_downcast *>(inputs[1].m_IndexedValue); + + if (!HasParsedConstTensor(inputs[2].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports FusedBatchNormalization layers with constant offset"); + } + ParsedConstTfOperation* offsetNode = + boost::polymorphic_downcast *>(inputs[2].m_IndexedValue); + + if (!HasParsedConstTensor(inputs[3].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports FusedBatchNormalization layers with constant mean"); + } + ParsedConstTfOperation* meanNode = + boost::polymorphic_downcast *>(inputs[3].m_IndexedValue); + + if (!HasParsedConstTensor(inputs[4].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports FusedBatchNormalization layers with constant variance"); + } + ParsedConstTfOperation* varianceNode = + boost::polymorphic_downcast *>(inputs[4].m_IndexedValue); + + // The descriptor only has the epsilon attribute + BatchNormalizationDescriptor desc; + desc.m_Eps = ReadMandatoryNodeFloatAttribute(nodeDef, "epsilon"); + + // data for the parsed tensor args (scale, offset, mean, variance) must be stored locally until the layer is added + std::vector scaleTensorData; + ConstTensor scaleTensor = scaleNode->GetConstTensor(false, scaleTensorData); + + std::vector offsetTensorData; + ConstTensor offsetTensor = offsetNode->GetConstTensor(false, offsetTensorData); + + std::vector meanTensorData; + ConstTensor meanTensor = meanNode->GetConstTensor(false, meanTensorData); + + std::vector varianceTensorData; + ConstTensor varianceTensor = varianceNode->GetConstTensor(false, varianceTensorData); + + IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc, + meanTensor, + varianceTensor, + offsetTensor, + scaleTensor, + nodeDef.name().c_str()); + + IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + + const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); + + if (dataFormat == "NHWC") + { + const TensorInfo outputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); + } + else + { + layer->GetOutputSlot(0).SetTensorInfo(inputSlot.GetTensorInfo()); + inputSlot.Connect(layer->GetInputSlot(0)); + } + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::vector nodes = GetTfInputNodes(nodeDef); + // In tensorflow, we have the last input of the Concat layer as the axis for concatenation + unsigned int numInputs = static_cast(nodes.size()); + unsigned int numConcatView = numInputs - 1; + + OriginsDescriptor concatDescriptor(static_cast(numConcatView), MaxNumOfTensorDimensions); + std::vectormergeDimSizes(MaxNumOfTensorDimensions, 0u); + + unsigned int mergeDim = 0; + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); + + // The last input is the axis for concatenation + if (!HasParsedConstTensor(inputs[numInputs - 1].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports Concat with constant axis"); + } + ParsedConstTfOperation* shapeNode = + boost::polymorphic_downcast*>(inputs[numInputs - 1].m_IndexedValue); + + std::vector axisTensorData; + ConstTensor axisTensor = shapeNode->GetConstTensor(false, axisTensorData); + + // This concatDim indicates the data format: 3 is the NHWC, 1 is the NCHW + const unsigned int concatDimInput = static_cast(axisTensorData[0]); + + // Armnn supports concatenation along the channel dimension for data format NHWC and NCHW + if (concatDimInput == 0 || concatDimInput == 2) + { + throw ParseException("The dimension for concatenation is not supported by Armnn"); + } + + // This is the only concatDim we support in Armnn + const unsigned int concatDim = 1; + for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex) + { + // need to double check whether it should be + IOutputSlot& inputSlot = + inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index); + TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + + if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) + { + throw ParseException("The number of dimensions for input tensors of the concatenation op should be 4"); + } + + if (concatDimInput == 3) + { + inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN); + } + + for (unsigned int dim = 0; dim < MaxNumOfTensorDimensions; ++dim) + { + mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim]; + } + + for (unsigned int j = 0; j < concatDim; ++j) + { + concatDescriptor.SetViewOriginCoord(viewIndex, j, 0); + } + + concatDescriptor.SetViewOriginCoord(viewIndex, concatDim, mergeDim); + mergeDim += mergeDimSizes[concatDim]; + + for (unsigned int j = concatDim+1; j < MaxNumOfTensorDimensions; ++j) + { + concatDescriptor.SetViewOriginCoord(viewIndex, j, 0); + } + } + + mergeDimSizes[concatDim] = mergeDim; + armnn::IConnectableLayer *layer = m_Network->AddMergerLayer(concatDescriptor, nodeDef.name().c_str()); + + layer->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo(MaxNumOfTensorDimensions, mergeDimSizes.data(), + DataType::Float32)); + + for (unsigned int v = 0; v < numConcatView; ++v) + { + IOutputSlot& inputSlot = inputs[v].m_IndexedValue->ResolveArmnnOutputSlot(inputs[v].m_Index); + if (concatDimInput == 3) + { + IConnectableLayer* const swizzleLayer = AddSwizzleLayer(*m_Network, inputSlot, NHWCToArmNN, + "swizzle_for-" + nodeDef.name()); + swizzleLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(v)); + } + else + { + inputSlot.Connect(layer->GetInputSlot(v)); + } + } + + if (concatDimInput == 3) + { + IConnectableLayer* const deswizzleLayer = AddSwizzleLayer(*m_Network, layer->GetOutputSlot(0), ArmNNToNHWC, + "deswizzle_for-" + nodeDef.name()); + layer = deswizzleLayer; + } + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseShape(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + // Note: The Shape layer is handled in a special way, because: + // 1. ARMNN doesn't support int32 tensors which it outputs + // 2. ARMNN works with statically shaped tensors which are known at parse time + // 3. because of 1. and 2. we treat the output of Shape as a temporary const int32 + // tensor which may be used as an input to other ops, most likely a Reshape + + const tensorflow::DataType tfDataType = ReadMandatoryNodeTypeAttribute(nodeDef, "out_type"); + if (tfDataType != tensorflow::DT_INT32) + { + throw ParseException("Armnn only supports DT_INT32 as out_type"); + } + + const std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + const TensorInfo& prevLayerTensorInfo = prevLayerOutputSlot.GetTensorInfo(); + unsigned int prevLayerDimensions = prevLayerTensorInfo.GetNumDimensions(); + + std::vector shapeTensorData; + shapeTensorData.reserve(prevLayerDimensions); + + for (unsigned int i=0; i(prevLayerTensorInfo.GetShape()[i])); + } + + TensorInfo shapeTensorInfo(1, &prevLayerDimensions, DataType::Signed32); + + return std::make_unique>(this, + nodeDef, + &shapeTensorData[0], + shapeTensorInfo); +} + +ParsedTfOperationPtr TfParser::ParseReshape(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + ParsedTfOperation* inputNode = inputs[0].m_IndexedValue; + + if (!HasParsedConstTensor(inputs[1].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports Reshape layers with constant shapes"); + } + ParsedConstTfOperation* shapeNode = + boost::polymorphic_downcast*>(inputs[1].m_IndexedValue); + + armnn::IOutputSlot& prevLayerOutputSlot = inputNode->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo(); + + std::vector shapeTensorData; + ConstTensor shapeTensor = shapeNode->GetConstTensor(false, shapeTensorData); + const TensorInfo outputTensorInfo = PrepareReshape(inputTensorInfo, shapeTensorData); + + TensorShape targetShape = outputTensorInfo.GetShape(); + ReshapeDescriptor reshapeDesc; + reshapeDesc.m_TargetShape = targetShape; + + IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, nodeDef.name().c_str()); + prevLayerOutputSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + + if (!HasParsedConstTensor(inputs[1].m_IndexedValue->GetNode().name())) + { + throw ParseException("ArmNN only supports ResizeBilinear layers with constant sizes"); + } + ParsedConstTfOperation* sizeNode = + boost::polymorphic_downcast*>(inputs[1].m_IndexedValue); + + // Check the align_corners attribute is not set + if (ReadOptionalNodeBoolAttribute(nodeDef, "align_corners", false)) + { + throw ParseException("ArmNN only supports ResizeBilinear layers with align_corners set to false"); + } + + // data for the parsed tensor args (size) must be stored locally + std::vector sizeTensorData; + ConstTensor sizeTensor = sizeNode->GetConstTensor(false, sizeTensorData); + + // The descriptor only has target height and width attributes, which we get from the size tensor + ResizeBilinearDescriptor desc; + desc.m_TargetHeight = static_cast (sizeTensorData[0]); + desc.m_TargetWidth = static_cast (sizeTensorData[1]); + + IConnectableLayer* layer = m_Network->AddResizeBilinearLayer(desc, nodeDef.name().c_str()); + + IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + // the input shape is always in BHWC format, this will be swizzled below; for now, + // get the batch and channels to make up the ArmNN output shape with the target size + unsigned int outBatch = inputTensorInfo.GetShape()[0]; + unsigned int outChannels = inputTensorInfo.GetShape()[3]; + unsigned int outHeight = desc.m_TargetHeight; + unsigned int outWidth = desc.m_TargetWidth; + TensorShape outShape({outBatch, outChannels, outHeight, outWidth}); + // The output DataType is always Float32, regardless of the input DataType + const TensorInfo outputTensorInfo(outShape, armnn::DataType::Float32); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + // TensorFlow ResizeBilinear input is always in BHWC format, so add swizzle and deswizzle layers + layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); + + return std::make_unique(this, nodeDef, layer); +} + +TensorInfo OutputShapeOfSqueeze(const tensorflow::NodeDef& nodeDef, TensorInfo inputTensorInfo) +{ + BOOST_ASSERT(nodeDef.op() == "Squeeze"); + tensorflow::DataType tfDataType = ReadMandatoryNodeTypeAttribute(nodeDef, "T"); + + DataType type; + if (tfDataType == tensorflow::DT_FLOAT) + { + type = DataType::Float32; + } + else if (tfDataType == tensorflow::DT_INT32) + { + type = DataType::Signed32; + } + else + { + throw ParseException(boost::str( + boost::format("Unsupported DataType %1% for Squeeze operation") + % tensorflow::DataType_Name(tfDataType))); + } + + std::vector squeezeDims = ReadOptionalNodeUint32ListAttribute(nodeDef, "squeeze_dims"); + if (squeezeDims.empty()) + { + for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + if (inputTensorInfo.GetShape()[i] == 1) + { + squeezeDims.push_back(i); + } + } + } + + std::vector outputDims; + for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + bool includeDimension = (std::find(squeezeDims.begin(), squeezeDims.end(), i) == squeezeDims.end()); + if (includeDimension) + { + outputDims.push_back(inputTensorInfo.GetShape()[i]); + } + } + + if (outputDims.size() > 4) + { + throw ParseException("Unsupported shape for Squeeze"); + } + + TensorInfo outTensorInfo = TensorInfo(boost::numeric_cast(outputDims.size()), + outputDims.data(), + type); + + return outTensorInfo; +} + +ParsedTfOperationPtr TfParser::ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + + IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = prevLayerOutputSlot.GetTensorInfo(); + + TensorInfo outputInfo; + outputInfo = OutputShapeOfSqueeze(nodeDef, inputTensorInfo); + + ReshapeDescriptor reshapeDesc; + reshapeDesc.m_TargetShape = outputInfo.GetShape(); + IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, nodeDef.name().c_str()); + prevLayerOutputSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + + NormalizationDescriptor normalizationDescriptor; + normalizationDescriptor.m_NormMethodType = NormalizationAlgorithmMethod::LocalBrightness; + normalizationDescriptor.m_NormChannelType = NormalizationAlgorithmChannel::Across; + normalizationDescriptor.m_Alpha = ReadMandatoryNodeFloatAttribute(nodeDef, "alpha"); + normalizationDescriptor.m_Beta = ReadMandatoryNodeFloatAttribute(nodeDef, "beta"); + normalizationDescriptor.m_K = ReadMandatoryNodeFloatAttribute(nodeDef, "bias"); + normalizationDescriptor.m_NormSize = ReadMandatoryNodeUint32Attribute(nodeDef, "depth_radius"); + + // The window size must be an odd value. For a window size of (2 * n + 1), TensorFlow defines depth_radius = n. + normalizationDescriptor.m_NormSize = normalizationDescriptor.m_NormSize * 2 + 1; + + IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + + IConnectableLayer* layer = m_Network->AddNormalizationLayer(normalizationDescriptor, + nodeDef.name().c_str()); + + const TensorInfo permutedInfo = armnnUtils::Permuted(prevLayerOutputSlot.GetTensorInfo(), NHWCToArmNN); + layer->GetOutputSlot(0).SetTensorInfo(permutedInfo); + + layer = SwizzleInDeswizzleOut(*m_Network, prevLayerOutputSlot, *layer, nodeDef.name()); + + return std::make_unique(this, nodeDef, layer); +} + +/// An ParsedTfOperation for a MatMul node. +/// Creation of the armnn FullyConnected layer is deferred until it is actually needed, because MatMul nodes are +/// often used for the first part of a biased FullyConnected (MatMul followed by Add) and in these cases armnn doesn't +/// need a separate layer for the MatMul. +class ParsedMatMulTfOperation : public DeferredSingleLayerParsedTfOperation +{ +public: + ParsedMatMulTfOperation(TfParser* parser, const tensorflow::NodeDef& node) + : DeferredSingleLayerParsedTfOperation(parser, node) + { + } + + void CreateLayerDeferred() override + { + BOOST_ASSERT(m_Layer == nullptr); + m_Layer = m_Parser->AddFullyConnectedLayer(m_Node, nullptr, m_Node.name().c_str()); + } +}; + +ParsedTfOperationPtr TfParser::ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + // Defer the creation of the layer (see ParsedMatMulTfOperation). + return std::make_unique(this, nodeDef); +} + +ParsedTfOperationPtr TfParser::ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + + IConnectableLayer* const layer = m_Network->AddMultiplicationLayer(nodeDef.name().c_str()); + IOutputSlot* input0Slot = &inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + IOutputSlot* input1Slot = &inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index); + + auto const input0NumDims = input0Slot->GetTensorInfo().GetNumDimensions(); + auto const input1NumDims = input1Slot->GetTensorInfo().GetNumDimensions(); + + if (input0NumDims < input1NumDims) + { + const bool isNHWC = true; + input0Slot = BroadcastForAddandMul(input1Slot, input0Slot, isNHWC, *m_Network, nodeDef); + } + if (input1NumDims < input0NumDims) + { + const bool isNHWC = true; + input1Slot = BroadcastForAddandMul(input0Slot, input1Slot, isNHWC, *m_Network, nodeDef); + } + + input0Slot->Connect(layer->GetInputSlot(0)); + input1Slot->Connect(layer->GetInputSlot(1)); + + if (input0NumDims < input1NumDims) + { + layer->GetOutputSlot(0).SetTensorInfo(input1Slot->GetTensorInfo()); + } + else + { + layer->GetOutputSlot(0).SetTensorInfo(input0Slot->GetTensorInfo()); + } + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParsePlaceholder(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 0); + + const LayerBindingId layerId = boost::numeric_cast(m_NetworkInputsBindingInfo.size()); + + auto it = m_InputShapes.find(nodeDef.name()); + if (it == m_InputShapes.end()) + { + throw ParseException("Missing input shape for Placeholder '" + nodeDef.name() + "'"); + } + TensorInfo tensorInfo(it->second, DataType::Float32); + + IConnectableLayer* const layer = m_Network->AddInputLayer(layerId, nodeDef.name().c_str()); + + layer->GetOutputSlot(0).SetTensorInfo(tensorInfo); + + TrackInputBinding(layer, layerId, tensorInfo); + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseRelu(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + ActivationDescriptor activationDesc; + activationDesc.m_Function = ActivationFunction::ReLu; + return AddActivationLayer(nodeDef, activationDesc); +} + +ParsedTfOperationPtr TfParser::ParseRelu6(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + ActivationDescriptor activationDesc; + activationDesc.m_Function = ActivationFunction::BoundedReLu; + activationDesc.m_A = 6.0f; + activationDesc.m_B = 0.0f; + + return AddActivationLayer(nodeDef, activationDesc); +} + +ParsedTfOperationPtr TfParser::ParseSigmoid(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + ActivationDescriptor activationDesc; + activationDesc.m_Function = ActivationFunction::Sigmoid; + + return AddActivationLayer(nodeDef, activationDesc); +} + +ParsedTfOperationPtr TfParser::ParseSoftmax(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + + SoftmaxDescriptor softmaxDescriptor; + IConnectableLayer* const layer = m_Network->AddSoftmaxLayer(softmaxDescriptor, nodeDef.name().c_str()); + + IOutputSlot& prevLayerSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + prevLayerSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(prevLayerSlot.GetTensorInfo()); + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + ActivationDescriptor activationDesc; + activationDesc.m_Function = ActivationFunction::SoftReLu; + + return AddActivationLayer(nodeDef, activationDesc); +} + +ParsedTfOperationPtr TfParser::ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + ActivationDescriptor activationDesc; + activationDesc.m_Function = ActivationFunction::TanH; + activationDesc.m_A = 1.0f; + activationDesc.m_B = 1.0f; + + return AddActivationLayer(nodeDef, activationDesc); +} + +ParsedTfOperationPtr TfParser::AddActivationLayer(const tensorflow::NodeDef& nodeDef, + ActivationDescriptor& activationDesc) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + + IConnectableLayer* const layer = m_Network->AddActivationLayer(activationDesc, nodeDef.name().c_str()); + + IOutputSlot& prevLayerOutputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + prevLayerOutputSlot.Connect(layer->GetInputSlot(0)); + layer->GetOutputSlot(0).SetTensorInfo(prevLayerOutputSlot.GetTensorInfo()); + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::ParseMaxPool(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + return ParsePooling2d(nodeDef, graphDef, PoolingAlgorithm::Max); +} + +ParsedTfOperationPtr TfParser::ParseAvgPool(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + return ParsePooling2d(nodeDef, graphDef, PoolingAlgorithm::Average); +} + +ParsedTfOperationPtr TfParser::ParsePooling2d(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef, PoolingAlgorithm pooltype) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 1); + IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + + if (inputs.size() != 1) + { + throw ParseException("2D Pooling expects one input!"); + } + + std::string paddingString = ReadMandatoryNodeStringAttribute(nodeDef, "padding"); + std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); + std::vector strides = ReadMandatoryNodeUint32ListAttribute(nodeDef, "strides"); + std::vector ksize = ReadMandatoryNodeUint32ListAttribute(nodeDef, "ksize"); // size of pool windows + + Pooling2dDescriptor pooling2dDescriptor; + pooling2dDescriptor.m_PoolType = pooltype; + pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude; + pooling2dDescriptor.m_OutputShapeRounding = OutputShapeRounding::Floor; + + if (dataFormat == "NHWC") + { + pooling2dDescriptor.m_StrideX = strides[2]; + pooling2dDescriptor.m_StrideY = strides[1]; + pooling2dDescriptor.m_PoolWidth = ksize[2]; + pooling2dDescriptor.m_PoolHeight = ksize[1]; + // Swizzle input to supported memory layout + inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); + } + else if (dataFormat == "NCHW") + { + pooling2dDescriptor.m_StrideX = strides[3]; + pooling2dDescriptor.m_StrideY = strides[2]; + pooling2dDescriptor.m_PoolWidth = ksize[3]; + pooling2dDescriptor.m_PoolHeight = ksize[2]; + } + else + { + throw ParseException("Only NHWC or NCHW supported for Pooling2d"); + } + + uint32_t inputHeight = inputTensorInfo.GetShape()[2]; + uint32_t inputWidth = inputTensorInfo.GetShape()[3]; + + bool padding = false; + TensorInfo outputInfo; + if (paddingString == "SAME") + { + padding = true; + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + inputTensorInfo.GetShape()[1], + static_cast(ceil( + static_cast(inputHeight) / + static_cast(pooling2dDescriptor.m_StrideY))), + static_cast(ceil( + static_cast(inputWidth) / + static_cast(pooling2dDescriptor.m_StrideX))) + }, DataType::Float32); + } + else if (paddingString == "VALID") + { + padding = false; + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + inputTensorInfo.GetShape()[1], + static_cast(ceil( + static_cast(inputHeight - pooling2dDescriptor.m_PoolHeight + 1) / + static_cast(pooling2dDescriptor.m_StrideY))), + static_cast(ceil( + static_cast(inputWidth - pooling2dDescriptor.m_PoolWidth + 1) / + static_cast(pooling2dDescriptor.m_StrideX))) + }, DataType::Float32); + } + else + { + throw ParseException("Only 'SAME' and 'VALID' padding supported"); + } + + CalcPadding(inputWidth, pooling2dDescriptor.m_PoolWidth, pooling2dDescriptor.m_StrideX, + pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding); + CalcPadding(inputHeight, pooling2dDescriptor.m_PoolHeight, pooling2dDescriptor.m_StrideY, + pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding); + + + IConnectableLayer* layer = m_Network->AddPooling2dLayer(pooling2dDescriptor, nodeDef.name().c_str()); + if (layer == nullptr) + { + throw ParseException("Failed to add pooling2d layer"); + } + + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + if (dataFormat == "NHWC") + { + layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); + } + else + { + inputSlot.Connect(layer->GetInputSlot(0)); + } + + return std::make_unique(this, nodeDef, layer); +} + +ParsedTfOperationPtr TfParser::AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + + IOutputSlot* input0Slot = &inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + IOutputSlot* input1Slot = &inputs[1].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1].m_Index); + + const TensorInfo& input0Info = input0Slot->GetTensorInfo(); + const TensorInfo& input1Info = input1Slot->GetTensorInfo(); + + if (isBiasAdd) + { + // BiasAdd takes bias as a 1D tensor. We need to add a reshape layer to create a 4D tensor + // with the same data in the correct dimension for broadcast in addition. + if(input1Info.GetNumDimensions() != 1) + { + throw ParseException("Unsupported bias for BiasAdd. It should be a 1D vector."); + } + + const std::string dataFormat = ReadMandatoryNodeStringAttribute(nodeDef, "data_format"); + const bool isNHWC = (dataFormat == "NHWC"); + const bool isNCHW = (dataFormat == "NCHW"); + + if (!isNHWC && ! isNCHW) + { + throw ParseException("Only NHWC or NCHW supported for BiasAdd"); + } + + input1Slot = BroadcastForAddandMul(input0Slot, input1Slot, isNHWC, *m_Network, nodeDef); + } + else + { + if (input0Info.GetNumDimensions() == 1) + { + const bool isNHWC = true; + input0Slot = BroadcastForAddandMul(input1Slot, input0Slot, isNHWC, *m_Network, nodeDef); + } + + if (input1Info.GetNumDimensions() == 1) + { + const bool isNHWC = true; + input1Slot = BroadcastForAddandMul(input0Slot, input1Slot, isNHWC, *m_Network, nodeDef); + } + } + + IConnectableLayer* const layer = m_Network->AddAdditionLayer(nodeDef.name().c_str()); + + input0Slot->Connect(layer->GetInputSlot(0)); + input1Slot->Connect(layer->GetInputSlot(1)); + + if (input0Info.GetNumDimensions() == 1 && isBiasAdd == false) + { + layer->GetOutputSlot(0).SetTensorInfo(input1Slot->GetTensorInfo()); + } + else + { + layer->GetOutputSlot(0).SetTensorInfo(input0Slot->GetTensorInfo()); + } + + return std::make_unique(this, nodeDef, layer); +} + +IConnectableLayer* TfParser::AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef, + const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName) +{ + // find bias const (if applicable) + ParsedConstTfOperation* biasNode = nullptr; + if (addNodeDef != nullptr) + { + std::vector addInputs = GetInputParsedTfOperationsChecked(*addNodeDef, 2); + // find our inputs + if (HasParsedConstTensor(addInputs[0].m_IndexedValue->GetNode().name())) + { + biasNode = boost::polymorphic_downcast*>(addInputs[0].m_IndexedValue); + } + else if (HasParsedConstTensor(addInputs[1].m_IndexedValue->GetNode().name())) + { + biasNode = boost::polymorphic_downcast*>(addInputs[1].m_IndexedValue); + } + else + { + throw ParseException("ArmNN only supports fully connected layers with constant bias"); + } + } + + // find matmul inputs + ParsedConstTfOperation* weightNode = nullptr; + ParsedTfOperation* inputNode = nullptr; + unsigned int inputIdx = 0; + std::vector mulInputs = GetInputParsedTfOperationsChecked(matMulNodeDef, 2); + if (HasParsedConstTensor(mulInputs[0].m_IndexedValue->GetNode().name())) + { + weightNode = boost::polymorphic_downcast*>(mulInputs[0].m_IndexedValue); + inputNode = mulInputs[1].m_IndexedValue; + inputIdx = mulInputs[1].m_Index; + } + else if (HasParsedConstTensor(mulInputs[1].m_IndexedValue->GetNode().name())) + { + weightNode = boost::polymorphic_downcast*>(mulInputs[1].m_IndexedValue); + inputNode = mulInputs[0].m_IndexedValue; + inputIdx = mulInputs[0].m_Index; + } + else + { + throw ParseException("ArmNN only supports fully connected layers with constant weights"); + } + + std::vector weightTensorData; + // handle weight + ConstTensor weights = weightNode->GetConstTensor(false, weightTensorData); + + FullyConnectedDescriptor desc; + desc.m_BiasEnabled = addNodeDef != nullptr; + + IConnectableLayer* layer = nullptr; + // make the layer + if (addNodeDef != nullptr) + { + std::vector biasTensorData; + ConstTensor biases = biasNode->GetConstTensor(false, biasTensorData); + + if (weights.GetShape()[1] != biases.GetShape()[0]) + { + throw ParseException("shape of matmul and bias do not match"); + } + + layer = m_Network->AddFullyConnectedLayer(desc, weights, biases, armnnLayerName); + } + else + { + layer = m_Network->AddFullyConnectedLayer(desc, weights, armnnLayerName); + } + + BOOST_ASSERT(layer != nullptr); + + inputNode->ResolveArmnnOutputSlot(inputIdx).Connect(layer->GetInputSlot(0)); + unsigned int batches = inputNode->ResolveArmnnOutputSlot(inputIdx).GetTensorInfo().GetShape()[0]; + + // handle output + TensorInfo outputInfo({ batches, weights.GetShape()[1] }, DataType::Float32); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + return layer; +} + +void TfParser::LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + // get the type of the node (assume float) + tensorflow::DataType type = tensorflow::DT_FLOAT; + if (nodeDef.attr().count("T") != 0) + { + auto attr = nodeDef.attr().at("T"); + type = attr.type(); + } + else if (nodeDef.attr().count("dtype") != 0) + { + auto attr = nodeDef.attr().at("dtype"); + type = attr.type(); + } + + if (type != tensorflow::DT_FLOAT && nodeDef.op() != "Const") + { + throw ParseException("Currently only FLOAT is supported for tensorflow nodes (apart from Const)"); + } + + const std::string& operation = nodeDef.op(); + auto it = ms_OperationNameToParsingFunctions.find(operation); + if (it != ms_OperationNameToParsingFunctions.end()) + { + auto func = it->second; + ParsedTfOperationPtr parsedTfOperation = (this->*func)(nodeDef, graphDef); + ParsedTfOperation* parsedTfOperationRaw = parsedTfOperation.get(); + + // Store the parsed operation so that dependent layers can connect to it + auto it = m_ParsedTfOperations.find(nodeDef.name()); + if (it != m_ParsedTfOperations.end()) + { + throw ParseException(boost::str(boost::format("Name %1% used by more than one node") % nodeDef.name())); + } + m_ParsedTfOperations[nodeDef.name()] = std::move(parsedTfOperation); + + // If this node was requested as an output from the network then add an ArmNN output layer + if (std::find(m_RequestedOutputs.begin(), m_RequestedOutputs.end(), nodeDef.name()) != + m_RequestedOutputs.end()) + { + auto outId = ParseOutputId(nodeDef.name()); + const LayerBindingId layerId = boost::numeric_cast(m_NetworkOutputsBindingInfo.size()); + IOutputSlot& prevSlot = parsedTfOperationRaw->ResolveArmnnOutputSlot(outId.m_Index); + + TensorInfo tensorInfo = prevSlot.GetTensorInfo(); + + IConnectableLayer* outputLayer = m_Network->AddOutputLayer(layerId, nodeDef.name().c_str()); + + prevSlot.Connect(outputLayer->GetInputSlot(0)); + + TrackOutputBinding(outputLayer, layerId, tensorInfo); + } + } + else + { + throw ParseException(boost::str( + boost::format("Unsupported operation %1% in tensorflow::GraphDef") % operation)); + } +} + +void TfParser::LoadGraphDef(const tensorflow::GraphDef& graphDef) +{ + // add all nodes to our map + m_NodesByName.clear(); + m_NetworkInputsBindingInfo.clear(); + m_NetworkOutputsBindingInfo.clear(); + + for (int i = 0; i < graphDef.node_size(); ++i) + { + const tensorflow::NodeDef& node = graphDef.node(i); + m_NodesByName[node.name()] = &node; + } + + // Find the output nodes the user requested + std::vector targetNodes; + for (const std::string& requestedOutputName : m_RequestedOutputs) + { + auto nodeIt = m_NodesByName.find(requestedOutputName); + if (nodeIt == m_NodesByName.end()) + { + throw ParseException("Couldn't find requested output node '" + requestedOutputName + "' in graph"); + } + targetNodes.push_back(nodeIt->second); + } + + // Sort them into a linear ordering such that all inputs of a node are before the node itself + std::vector sortedNodes; + if (!armnnUtils::GraphTopologicalSort( + targetNodes, + [this](const tensorflow::NodeDef* node) + { + auto outputs = GetTfInputNodes(*node); + std::vector nodesOnly; + for (const auto & o : outputs) { + nodesOnly.push_back(o.m_IndexedValue); + } + return nodesOnly; + }, + sortedNodes)) + { + throw ParseException("Cycle detected in graph"); + } + + // Parse each node in order, knowing that all inputs of a node will be processed before the node itself + for (const auto& it : sortedNodes) + { + const tensorflow::NodeDef& currentNode = *it; + LoadNodeDef(currentNode, graphDef); + } +} + +INetworkPtr TfParser::CreateNetworkFromTextFile(const char* graphFile, + const std::map& inputShapes, + const std::vector& requestedOutputs) +{ + FILE* fd = fopen(graphFile, "r"); + + if (fd == nullptr) + { + std::stringstream error; + error << "Graph file " << graphFile << " failed to open"; + throw FileNotFoundException(error.str()); + } + + // Parse the file into a message + tensorflow::GraphDef graphDef; + auto input = new google::protobuf::io::FileInputStream(fileno(fd)); + bool success = google::protobuf::TextFormat::Parse(input, &graphDef); + delete input; + fclose(fd); + + if (!success) + { + std::stringstream error; + error << "Failed to parse graph file"; + throw ParseException(error.str()); + } + + return CreateNetworkFromGraphDef(graphDef, inputShapes, requestedOutputs); +} + +INetworkPtr TfParser::CreateNetworkFromString(const char* protoText, + const std::map& inputShapes, + const std::vector& requestedOutputs) +{ + // Parse the string into a message + tensorflow::GraphDef graphDef; + bool success = google::protobuf::TextFormat::ParseFromString(protoText, &graphDef); + + if (!success) + { + std::stringstream error; + error << "Failed to parse graph file"; + throw ParseException(error.str()); + } + + return CreateNetworkFromGraphDef(graphDef, inputShapes, requestedOutputs); +} + +INetworkPtr TfParser::CreateNetworkFromBinaryFile(const char* graphFile, + const std::map& inputShapes, + const std::vector& requestedOutputs) +{ + FILE* fd = fopen(graphFile, "rb"); + + if (fd == nullptr) + { + std::stringstream error; + error << "Graph file " << graphFile << " failed to open"; + throw FileNotFoundException(error.str()); + } + + // Parse the file into a message + tensorflow::GraphDef graphDef; + + google::protobuf::io::FileInputStream inStream(fileno(fd)); + google::protobuf::io::CodedInputStream codedStream(&inStream); + codedStream.SetTotalBytesLimit(INT_MAX, INT_MAX); + bool success = graphDef.ParseFromCodedStream(&codedStream); + fclose(fd); + + if (!success) + { + std::stringstream error; + error << "Failed to parse protobuf file" << graphFile; + throw ParseException(error.str()); + } + + return CreateNetworkFromGraphDef(graphDef, inputShapes, requestedOutputs); +} + +INetworkPtr TfParser::CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef, + const std::map& inputShapes, + const std::vector& requestedOutputs) +{ + m_Network = INetwork::Create(); + + m_InputShapes = inputShapes; + if (requestedOutputs.size() == 0) + { + throw ParseException("requestedOutputs must have at least one entry"); + } + m_RequestedOutputs = requestedOutputs; + + try + { + LoadGraphDef(graphDef); + } + catch (const ParseException& e) + { + Cleanup(); + throw e; + } + + Cleanup(); + + return std::move(m_Network); +} + +void TfParser::Cleanup() +{ + // cleanup, in case we reuse this parser + m_InputShapes.clear(); + m_RequestedOutputs.clear(); + m_NodesByName.clear(); + m_ParsedTfOperations.clear(); +} + +BindingPointInfo TfParser::GetNetworkInputBindingInfo(const std::string& name) const +{ + return GetBindingInfo(name, "input", m_NetworkInputsBindingInfo); +} + +BindingPointInfo TfParser::GetNetworkOutputBindingInfo(const std::string& name) const +{ + return GetBindingInfo(name, "output", m_NetworkOutputsBindingInfo); +} + +std::pair TfParser::GetBindingInfo(const std::string& layerName, + const char* bindingPointDesc, + const std::unordered_map& nameToBindingInfo) +{ + auto it = nameToBindingInfo.find(layerName); + if (it == nameToBindingInfo.end()) + { + throw InvalidArgumentException(boost::str(boost::format("Unknown %1% '%2%'") % bindingPointDesc % layerName)); + } + return it->second; +} + +void TfParser::TrackInputBinding(IConnectableLayer* layer, LayerBindingId id, const TensorInfo& tensorInfo) +{ + return TrackBindingPoint(layer, id, tensorInfo, "input", m_NetworkInputsBindingInfo); +} + +void TfParser::TrackOutputBinding(IConnectableLayer* layer, LayerBindingId id, const TensorInfo& tensorInfo) +{ + return TrackBindingPoint(layer, id, tensorInfo, "output", m_NetworkOutputsBindingInfo); +} + +void TfParser::TrackBindingPoint(IConnectableLayer* layer, + LayerBindingId id, + const TensorInfo& tensorInfo, + const char* bindingPointDesc, + std::unordered_map& nameToBindingInfo) +{ + const std::string layerName = layer->GetName(); + auto it = nameToBindingInfo.find(layerName); + if (it == nameToBindingInfo.end()) + { + nameToBindingInfo[layerName] = std::make_pair(id, tensorInfo); + } + else + { + throw ParseException(boost::str( + boost::format("Id %1% used by more than one %2% layer") % id % bindingPointDesc)); + } +} + +} // namespace armnnTfParser diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp new file mode 100644 index 0000000000..c5b4bce8ac --- /dev/null +++ b/src/armnnTfParser/TfParser.hpp @@ -0,0 +1,199 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "armnnTfParser/ITfParser.hpp" + +#include "armnn/Types.hpp" +#include "armnn/Tensor.hpp" +#include "armnn/INetwork.hpp" + +#include +#include +#include +#include + +namespace armnn +{ +class TensorInfo; +} + +namespace tensorflow +{ +class GraphDef; +class NodeDef; +} + +namespace armnnTfParser +{ + +using BindingPointInfo = std::pair; + +class ParsedTfOperation; +using ParsedTfOperationPtr = std::unique_ptr; + +/// +/// WithOutputTensorIndex wraps a value and an index. The purpose of +/// this template is to signify that in Tensorflow the input name of +/// a layer has the convention of 'inputTensorName:#index' where the +/// #index can be omitted and it implicitly means the 0. output of +/// the referenced layer. By supporting this notation we can handle +/// layers with multiple outputs, such as Split. +/// +template +struct WithOutputTensorIndex +{ + T m_IndexedValue; + unsigned int m_Index; + + WithOutputTensorIndex(const T & value, unsigned int index) + : m_IndexedValue{value} + , m_Index{index} {} + + WithOutputTensorIndex(T && value, unsigned int index) + : m_IndexedValue{value} + , m_Index{index} {} +}; + +using OutputOfParsedTfOperation = WithOutputTensorIndex; +using OutputOfConstNodeDef = WithOutputTensorIndex; +using OutputId = WithOutputTensorIndex; + +class TfParser : public ITfParser +{ +public: + /// Create the network from a protobuf text file on disk + virtual armnn::INetworkPtr CreateNetworkFromTextFile( + const char* graphFile, + const std::map& inputShapes, + const std::vector& requestedOutputs) override; + + /// Create the network from a protobuf binary file on disk + virtual armnn::INetworkPtr CreateNetworkFromBinaryFile( + const char* graphFile, + const std::map& inputShapes, + const std::vector& requestedOutputs) override; + + /// Create the network directly from protobuf text in a string. Useful for debugging/testing + virtual armnn::INetworkPtr CreateNetworkFromString( + const char* protoText, + const std::map& inputShapes, + const std::vector& requestedOutputs) override; + + /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name + virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override; + + /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name + virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override; + +public: + TfParser(); + +private: + template + friend class ParsedConstTfOperation; + friend class ParsedMatMulTfOperation; + + /// Parses a GraphDef loaded into memory from one of the other CreateNetwork* + armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef, + const std::map& inputShapes, + const std::vector& requestedOutputs); + + /// sets up variables and then performs BFS to parse all nodes + void LoadGraphDef(const tensorflow::GraphDef& graphDef); + + /// parses a given node, assuming nodes before it in graph have been done + void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + + /// Handling identity layers as the input for Conv2D layer + const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef); + /// Finds the nodes connected as inputs of the given node in the graph. + std::vector GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const; + /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph, + /// and throws an exception if the number of inputs does not match the expected one. + /// This will automatically resolve any identity nodes. The result vector contains the parsed operation + /// together with the output tensor index to make the connection unambiguous. + std::vector GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef, + std::size_t expectedNumInputs); + + ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + + /// Checks if there is a pre-parsed const tensor is available with the given name and Type + template + bool HasParsedConstTensor(const std::string & nodeName) const; + + ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef, + armnn::PoolingAlgorithm pooltype); + ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc); + ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false); + armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef, + const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName); + + static std::pair GetBindingInfo(const std::string& layerName, + const char* bindingPointDesc, + const std::unordered_map& nameToBindingInfo); + + void TrackInputBinding(armnn::IConnectableLayer* layer, + armnn::LayerBindingId id, + const armnn::TensorInfo& tensorInfo); + + void TrackOutputBinding(armnn::IConnectableLayer* layer, + armnn::LayerBindingId id, + const armnn::TensorInfo& tensorInfo); + + static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id, + const armnn::TensorInfo& tensorInfo, + const char* bindingPointDesc, + std::unordered_map& nameToBindingInfo); + + void Cleanup(); + + /// The network we're building. Gets cleared after it is passed to the user + armnn::INetworkPtr m_Network; + + using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef); + + /// map of TensorFlow operation names to parsing member functions + static const std::map ms_OperationNameToParsingFunctions; + + std::map m_InputShapes; + std::vector m_RequestedOutputs; + + /// map of nodes extracted from the GraphDef to speed up parsing + std::unordered_map m_NodesByName; + + std::unordered_map m_ParsedTfOperations; + + /// maps input layer names to their corresponding ids and tensor infos + std::unordered_map m_NetworkInputsBindingInfo; + + /// maps output layer names to their corresponding ids and tensor infos + std::unordered_map m_NetworkOutputsBindingInfo; +}; +} diff --git a/src/armnnTfParser/test/Activations.cpp b/src/armnnTfParser/test/Activations.cpp new file mode 100644 index 0000000000..72ed64d653 --- /dev/null +++ b/src/armnnTfParser/test/Activations.cpp @@ -0,0 +1,113 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + + +struct ActivationFixture : public ParserPrototxtFixture +{ + explicit ActivationFixture(const char* activationFunction) + { + m_Prototext = "node {\n" + " name: \"Placeholder\"\n" + " op: \"Placeholder\"\n" + " attr {\n" + " key: \"dtype\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"shape\"\n" + " value {\n" + " shape {\n" + " unknown_rank: true\n" + " }\n" + " }\n" + " }\n" + "}\n" + "node {\n" + " name: \""; + m_Prototext.append(activationFunction); + m_Prototext.append("\"\n" + " op: \""); + m_Prototext.append(activationFunction); + m_Prototext.append("\"\n" + " input: \"Placeholder\"\n" + " attr {\n" + " key: \"T\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + "}\n"); + + SetupSingleInputSingleOutput({ 1, 7 }, "Placeholder", activationFunction); + } +}; + + +struct ReLuFixture : ActivationFixture +{ + ReLuFixture() : ActivationFixture("Relu") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseReLu, ReLuFixture) +{ + RunTest<2>({ -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f }, + { 0.0f, 0.0f, 1.25f, 0.0f, 0.0f, 0.5f, 0.0f }); +} + + +struct ReLu6Fixture : ActivationFixture +{ + ReLu6Fixture() : ActivationFixture("Relu6") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseReLu6, ReLu6Fixture) +{ + RunTest<2>({ -1.0f, -0.5f, 7.25f, -3.0f, 0.0f, 0.5f, -0.75f }, + { 0.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.5f, 0.0f }); +} + + +struct SigmoidFixture : ActivationFixture +{ + SigmoidFixture() : ActivationFixture("Sigmoid") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseSigmoid, SigmoidFixture) +{ + RunTest<2>({ -0.1f, -0.2f, -0.3f, -0.4f, 0.1f, 0.2f, 0.3f }, + { 0.4750208f, 0.45016602f, 0.42555749f, 0.40131235f, 0.52497917f, 0.54983395f, 0.57444251f }); +} + + +struct SoftplusFixture : ActivationFixture +{ + SoftplusFixture() : ActivationFixture("Softplus") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseSoftplus, SoftplusFixture) +{ + RunTest<2>({ -0.1f, -0.2f, -0.3f, -0.4f, 0.1f, 0.2f, 0.3f }, + { 0.64439666f, 0.59813893f, 0.55435526f, 0.51301527f, 0.74439669f, 0.7981388f, 0.85435522f }); +} + + +struct TanhFixture : ActivationFixture +{ + TanhFixture() : ActivationFixture("Tanh") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseTanh, TanhFixture) +{ + RunTest<2>({ -0.1f, -0.2f, -0.3f, -0.4f, 0.1f, 0.2f, 0.3f }, + { -0.09966799f, -0.19737528f, -0.29131261f, -0.379949f, 0.09966799f, 0.19737528f, 0.29131261f }); +} + + + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Addition.cpp b/src/armnnTfParser/test/Addition.cpp new file mode 100644 index 0000000000..c9e69268c6 --- /dev/null +++ b/src/armnnTfParser/test/Addition.cpp @@ -0,0 +1,78 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct AdditionFixture : public ParserPrototxtFixture +{ + AdditionFixture() + { + m_Prototext = "node { \n" + " name: \"graphInput\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"shape\" \n" + " value { \n" + " shape { \n" + " } \n" + " } \n" + " } \n" + " } \n" + " node { \n" + " name: \"softmax1\" \n" + " op: \"Softmax\" \n" + " input: \"graphInput\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " }\n" + " node {\n" + " name: \"softmax2\"\n" + " op : \"Softmax\"\n" + " input: \"graphInput\"\n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " }\n" + " node {\n" + " name: \"addition\"\n" + " op : \"Add\"\n" + " input: \"softmax1\"\n" + " input: \"softmax2\"\n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " }\n"; + + SetupSingleInputSingleOutput({ 1, 7 }, "graphInput", "addition"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseAddition, AdditionFixture) +{ + RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 2, 0, 0, 0, 0 }); +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/BiasAdd.cpp b/src/armnnTfParser/test/BiasAdd.cpp new file mode 100644 index 0000000000..e29aeb1057 --- /dev/null +++ b/src/armnnTfParser/test/BiasAdd.cpp @@ -0,0 +1,104 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct BiasAddFixture : public ParserPrototxtFixture +{ + explicit BiasAddFixture(const std::string& dataFormat) + { + m_Prototext = R"( +node { + name: "graphInput" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "bias" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 3 + } + } + float_val: 1 + float_val: 2 + float_val: 3 + } + } + } +} +node { + name: "biasAdd" + op : "BiasAdd" + input: "graphInput" + input: "bias" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "data_format" + value { + s: ")" + dataFormat + R"(" + } + } +} +)"; + + SetupSingleInputSingleOutput({ 1, 3, 1, 3 }, "graphInput", "biasAdd"); + } +}; + +struct BiasAddFixtureNCHW : BiasAddFixture +{ + BiasAddFixtureNCHW() : BiasAddFixture("NCHW") {} +}; + +struct BiasAddFixtureNHWC : BiasAddFixture +{ + BiasAddFixtureNHWC() : BiasAddFixture("NHWC") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseBiasAddNCHW, BiasAddFixtureNCHW) +{ + RunTest<4>(std::vector(9), { 1, 1, 1, 2, 2, 2, 3, 3, 3 }); +} + +BOOST_FIXTURE_TEST_CASE(ParseBiasAddNHWC, BiasAddFixtureNHWC) +{ + RunTest<4>(std::vector(9), { 1, 2, 3, 1, 2, 3, 1, 2, 3 }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/BroadcastForAdd.cpp b/src/armnnTfParser/test/BroadcastForAdd.cpp new file mode 100644 index 0000000000..4c9731d7fc --- /dev/null +++ b/src/armnnTfParser/test/BroadcastForAdd.cpp @@ -0,0 +1,149 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" +// This is a special case for add, which supports broadcasting +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct BroadcastForAddFixtureSlot1 : public ParserPrototxtFixture +{ + BroadcastForAddFixtureSlot1() + { + m_Prototext = R"( + node { + name: "graphInput" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 4.0 + float_val: 5.0 + } + } + } + } + node { + name: "Add" + op: "Add" + input: "graphInput" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + )"; + + SetupSingleInputSingleOutput({ 1, 2, 2, 2 }, "graphInput", "Add"); + } +}; + +struct BroadcastForAddFixtureSlot0 : public ParserPrototxtFixture +{ + BroadcastForAddFixtureSlot0() + { + m_Prototext = R"( + node { + name: "graphInput" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 4.0 + float_val: 5.0 + } + } + } + } + node { + name: "Add" + op: "Add" + input: "Const_1" + input: "graphInput" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + )"; + + SetupSingleInputSingleOutput({ 1, 2, 2, 2 }, "graphInput", "Add"); + } +}; + + +BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition1, BroadcastForAddFixtureSlot1) +{ + RunTest<4>({ 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0 }, { 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0 }); +} + +BOOST_FIXTURE_TEST_CASE(ParseBroadcastForAddition0, BroadcastForAddFixtureSlot0) +{ + RunTest<4>({ 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0 }, { 5.0, 6.0, 6.0, 7.0, 7.0, 8.0, 8.0, 9.0 }); +} + + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Concat.cpp b/src/armnnTfParser/test/Concat.cpp new file mode 100644 index 0000000000..a7d5ea03af --- /dev/null +++ b/src/armnnTfParser/test/Concat.cpp @@ -0,0 +1,183 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct ConcatFixture : public ParserPrototxtFixture +{ + explicit ConcatFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1, + unsigned int concatDim) + { + m_Prototext = R"( + node { + name: "graphInput0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "graphInput1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "concat/axis" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: )"; + + m_Prototext += std::to_string(concatDim); + + m_Prototext += R"( + } + } + } + } + node { + name: "concat" + op: "ConcatV2" + input: "graphInput0" + input: "graphInput1" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_FLOAT + } + } + } + )"; + + Setup({{"graphInput0", inputShape0 }, + {"graphInput1", inputShape1 }}, {"concat"}); + } +}; + +struct ConcatFixtureNCHW : ConcatFixture +{ + ConcatFixtureNCHW() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 1 ) {} +}; + +struct ConcatFixtureNHWC : ConcatFixture +{ + ConcatFixtureNHWC() : ConcatFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, 3 ) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseConcatNCHW, ConcatFixtureNCHW) +{ + RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}}, + {"graphInput1", {4.0, 5.0, 6.0, 7.0}}}, + {{"concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 }}}); +} + +BOOST_FIXTURE_TEST_CASE(ParseConcatNHWC, ConcatFixtureNHWC) +{ + RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}}, + {"graphInput1", {4.0, 5.0, 6.0, 7.0}}}, + {{"concat", { 0.0, 1.0, 4.0, 5.0, 2.0, 3.0, 6.0, 7.0 }}}); +} + +struct ConcatFixtureDim1 : ConcatFixture +{ + ConcatFixtureDim1() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 1) {} +}; + +struct ConcatFixtureDim3 : ConcatFixture +{ + ConcatFixtureDim3() : ConcatFixture({ 1, 2, 3, 4 }, { 1, 2, 3, 4 }, 3) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseConcatDim1, ConcatFixtureDim1) +{ + RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0 } }, + { "graphInput1", { 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, + 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } }, + { { "concat", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, + 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, + 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0 } } }); +} + +BOOST_FIXTURE_TEST_CASE(ParseConcatDim3, ConcatFixtureDim3) +{ + RunTest<4>({ { "graphInput0", { 0.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, + 16.0, 17.0, 18.0, 19.0, + 20.0, 21.0, 22.0, 23.0 } }, + { "graphInput1", { 50.0, 51.0, 52.0, 53.0, + 54.0, 55.0, 56.0, 57.0, + 58.0, 59.0, 60.0, 61.0, + 62.0, 63.0, 64.0, 65.0, + 66.0, 67.0, 68.0, 69.0, + 70.0, 71.0, 72.0, 73.0 } } }, + { { "concat", { 0.0, 1.0, 2.0, 3.0, + 50.0, 51.0, 52.0, 53.0, + 4.0, 5.0, 6.0, 7.0, + 54.0, 55.0, 56.0, 57.0, + 8.0, 9.0, 10.0, 11.0, + 58.0, 59.0, 60.0, 61.0, + 12.0, 13.0, 14.0, 15.0, + 62.0, 63.0, 64.0, 65.0, + 16.0, 17.0, 18.0, 19.0, + 66.0, 67.0, 68.0, 69.0, + 20.0, 21.0, 22.0, 23.0, + 70.0, 71.0, 72.0, 73.0 } } }); +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/src/armnnTfParser/test/ConcatOfConcats.cpp b/src/armnnTfParser/test/ConcatOfConcats.cpp new file mode 100644 index 0000000000..7316b9f1ac --- /dev/null +++ b/src/armnnTfParser/test/ConcatOfConcats.cpp @@ -0,0 +1,316 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct ConcatOfConcatsFixture : public ParserPrototxtFixture +{ + explicit ConcatOfConcatsFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1, + const armnn::TensorShape& inputShape2, const armnn::TensorShape& inputShape3, + unsigned int concatDim) + { + m_Prototext = R"( + node { + name: "graphInput0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "graphInput1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "graphInput2" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "graphInput3" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "Relu" + op: "Relu" + input: "graphInput0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "Relu_1" + op: "Relu" + input: "graphInput1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "Relu_2" + op: "Relu" + input: "graphInput2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "Relu_3" + op: "Relu" + input: "graphInput3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "concat/axis" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: )"; + m_Prototext += std::to_string(concatDim); + m_Prototext += R"( + } + } + } + } + node { + name: "concat" + op: "ConcatV2" + input: "Relu" + input: "Relu_1" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + } + node { + name: "concat_1/axis" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: )"; + m_Prototext += std::to_string(concatDim); + m_Prototext += R"( + } + } + } + } + node { + name: "concat_1" + op: "ConcatV2" + input: "Relu_2" + input: "Relu_3" + input: "concat_1/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + } + node { + name: "concat_2/axis" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: )"; + m_Prototext += std::to_string(concatDim); + m_Prototext += R"( + } + } + } + } + node { + name: "concat_2" + op: "ConcatV2" + input: "concat" + input: "concat_1" + input: "concat_2/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + } + )"; + + Setup({{ "graphInput0", inputShape0 }, + { "graphInput1", inputShape1 }, + { "graphInput2", inputShape2 }, + { "graphInput3", inputShape3}}, {"concat_2"}); + } +}; + +struct ConcatOfConcatsFixtureNCHW : ConcatOfConcatsFixture +{ + ConcatOfConcatsFixtureNCHW() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, + { 1, 1, 2, 2 }, 1 ) {} +}; + +struct ConcatOfConcatsFixtureNHWC : ConcatOfConcatsFixture +{ + ConcatOfConcatsFixtureNHWC() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 }, + { 1, 1, 2, 2 }, 3 ) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW, ConcatOfConcatsFixtureNCHW) +{ + RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}}, + {"graphInput1", {4.0, 5.0, 6.0, 7.0}}, + {"graphInput2", {8.0, 9.0, 10.0, 11.0}}, + {"graphInput3", {12.0, 13.0, 14.0, 15.0}}}, + {{"concat_2", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0 }}}); +} + +BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNHWC, ConcatOfConcatsFixtureNHWC) +{ + RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}}, + {"graphInput1", {4.0, 5.0, 6.0, 7.0}}, + {"graphInput2", {8.0, 9.0, 10.0, 11.0}}, + {"graphInput3", {12.0, 13.0, 14.0, 15.0}}}, + {{"concat_2", { 0.0, 1.0, 4.0, 5.0, 8.0, 9.0, 12.0, 13.0, + 2.0, 3.0, 6.0, 7.0, 10.0, 11.0, 14.0, 15.0 }}}); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Constant.cpp b/src/armnnTfParser/test/Constant.cpp new file mode 100644 index 0000000000..09587fc3d5 --- /dev/null +++ b/src/armnnTfParser/test/Constant.cpp @@ -0,0 +1,321 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include + +#include "armnnTfParser/ITfParser.hpp" + +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +// Tests that a Const node in Tensorflow can be converted to a ConstLayer in armnn (as opposed to most +// Const nodes which are used as weight inputs for convolutions etc. and are therefore not converted to +// armnn ConstLayers). +struct ConstantFixture : public ParserPrototxtFixture +{ + ConstantFixture() + { + // input = tf.placeholder(tf.float32, name = "input") + // const = tf.constant([17], tf.float32, [1]) + // output = tf.add(input, const, name = "output") + m_Prototext = + R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + unknown_rank: true + } + } + } +} +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 17.0 + } + } + } +} +node { + name: "output" + op: "Add" + input: "input" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + SetupSingleInputSingleOutput({ 1 }, "input", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(Constant, ConstantFixture) +{ + RunTest<1>({1}, {18}); +} + + +// Tests that a single Const node in Tensorflow can be used twice by a dependant node. This should result in only +// a single armnn ConstLayer being created. +struct ConstantReusedFixture : public ParserPrototxtFixture +{ + ConstantReusedFixture() + { + // const = tf.constant([17], tf.float32, [1]) + // output = tf.add(const, const, name = "output") + m_Prototext = + R"( +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 17.0 + } + } + } +} +node { + name: "output" + op: "Add" + input: "Const" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + Setup({}, { "output" }); + } +}; + +BOOST_FIXTURE_TEST_CASE(ConstantReused, ConstantReusedFixture) +{ + RunTest<1>({}, { { "output", { 34 } } }); +} + +template +struct ConstantValueListFixture : public ParserPrototxtFixture +{ + ConstantValueListFixture() + { + m_Prototext = + R"( +node { + name: "output" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + dim { + size: 3 + } + })"; + + double value = 0.75; + for (int i = 0; i < ListSize; i++, value += 0.25) + { + m_Prototext += std::string("float_val : ") + std::to_string(value) + "\n"; + } + + m_Prototext += + R"( + } + } + } +} + )"; + Setup({}, { "output" }); + } +}; + +using ConstantSingleValueListFixture = ConstantValueListFixture<1>; +using ConstantMultipleValueListFixture = ConstantValueListFixture<4>; +using ConstantMaxValueListFixture = ConstantValueListFixture<6>; + +BOOST_FIXTURE_TEST_CASE(ConstantSingleValueList, ConstantSingleValueListFixture) +{ + RunTest<2>({}, { { "output", { 0.75f, 0.75f, 0.75f, 0.75f, 0.75f, 0.75f } } }); +} +BOOST_FIXTURE_TEST_CASE(ConstantMultipleValueList, ConstantMultipleValueListFixture) +{ + RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.5f, 1.5f, 1.5f } } }); +} +BOOST_FIXTURE_TEST_CASE(ConstantMaxValueList, ConstantMaxValueListFixture) +{ + RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.50f, 1.75f, 2.f } } }); +} + +template +struct ConstantCreateFixture : public ParserPrototxtFixture +{ + ConstantCreateFixture() + { + m_Prototext = + R"( +node { + name: "output" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + )"; + + if (WithShape) + { + m_Prototext += + R"( +tensor_shape { + dim { + size: 2 + } + dim { + size: 2 + } +} + )"; + } + else + { + m_Prototext += + R"( +tensor_shape { +} + )"; + } + + if (WithContent) + { + m_Prototext += + R"( +tensor_content: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?" + )"; + } + + if (WithValueList) + { + m_Prototext += + R"( +float_val: 1.0 +float_val: 1.0 +float_val: 1.0 +float_val: 1.0 +float_val: 1.0 + )"; + } + + m_Prototext += + R"( + } + } + } +} + )"; + } +}; + +using ConstantCreateNoValueListFixture = ConstantCreateFixture; +using ConstantCreateNoValueList2Fixture = ConstantCreateFixture; +using ConstantCreateNoContentFixture = ConstantCreateFixture; +using ConstantCreateNoContent2Fixture = ConstantCreateFixture; +using ConstantCreateNoShapeFixture = ConstantCreateFixture; +using ConstantCreateNoShape2Fixture = ConstantCreateFixture; +using ConstantCreateNoShape3Fixture = ConstantCreateFixture; + +BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList, ConstantCreateNoValueListFixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList2, ConstantCreateNoValueList2Fixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidContent, ConstantCreateNoContentFixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidShape, ConstantCreateNoShapeFixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape2, ConstantCreateNoShape2Fixture) +{ + BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException); +} +BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape3, ConstantCreateNoShape3Fixture) +{ + Setup({}, { "output" }); + RunTest<1>({}, { { "output", { 1.f, 1.f, 1.f, 1.f, 1.f } } }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Convolution2d.cpp b/src/armnnTfParser/test/Convolution2d.cpp new file mode 100644 index 0000000000..a7c7648b81 --- /dev/null +++ b/src/armnnTfParser/test/Convolution2d.cpp @@ -0,0 +1,322 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" +#include +#include + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct Convolution2dFixture : public ParserPrototxtFixture +{ + explicit Convolution2dFixture(const char* paddingType) + : Convolution2dFixture(paddingType, 1) + {} + + // dilation: 0 - dilations attribute is not included; + // dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg + explicit Convolution2dFixture(const char* paddingType, int stride, int dilation = 0) + { + std::string strideString = std::to_string(stride); + std::string dilationString = std::to_string(dilation); + m_Prototext = "node { \n" + " name: \"graphInput\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"shape\" \n" + " value { \n" + " shape { \n" + " } \n" + " } \n" + " } \n" + " } \n" + " node { \n" + " name: \"Const_1\" \n" + " op: \"Const\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"value\" \n" + " value { \n" + " tensor { \n" + " dtype: DT_FLOAT \n" + " tensor_shape { \n" + " dim { \n" + " size: 1 \n" + " } \n" + " dim { \n" + " size: 3 \n" + " } \n" + " dim { \n" + " size: 1 \n" + " } \n" + " dim { \n" + " size: 1 \n" + " } \n" + " } \n" + " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"potato\" \n" + " op: \"Conv2D\" \n" + " input: \"graphInput\" \n" + " input: \"Const_1\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"data_format\" \n" + " value { \n" + " s: \"NHWC\" \n" + " } \n" + " } \n" + " attr { \n" + " key: \"padding\" \n" + " value { \n" + " s: \""; + m_Prototext.append(paddingType); + m_Prototext.append("\"\n" + " } \n" + " } \n" + " attr { \n" + " key: \"strides\" \n" + " value { \n" + " list { \n" + " i: 1 \n" + " i: 1 \n" + " i: "); + m_Prototext.append(strideString); + m_Prototext.append(" \n" + " i: 1 \n" + " } \n" + " } \n" + " } \n"); + + if (dilation > 0) + { + m_Prototext.append(" attr { \n" + " key: \"dilations\" \n" + " value { \n" + " list { \n" + " i: 1 \n" + " i: "); + m_Prototext.append(dilationString); + m_Prototext.append(" \n" + " i: "); + m_Prototext.append(dilationString); + m_Prototext.append(" \n" + " i: 1 \n" + " } \n" + " } \n" + " } \n"); + } + m_Prototext.append(" attr { \n" + " key: \"use_cudnn_on_gpu\" \n" + " value { \n" + " b: false \n" + " } \n" + " } \n" + "} \n"); + + // Manual height computation based on stride parameter. + BOOST_ASSERT_MSG(stride == 1 || stride==2, "Add support for strides other than 1 or 2."); + unsigned int dims[] = {1,2,3,1}; + if (stride == 2) + { + dims[1]=3; + } + + SetupSingleInputSingleOutput(armnn::TensorShape(4, dims), "graphInput", "potato"); + } +}; + + +struct Convolution2dSameFixture : Convolution2dFixture +{ + Convolution2dSameFixture() : Convolution2dFixture("SAME", 1){} +}; +BOOST_FIXTURE_TEST_CASE(ParseConv2DSame, Convolution2dSameFixture) +{ + RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f}); +} + +struct Convolution2dValidFixture : Convolution2dFixture +{ + Convolution2dValidFixture() : Convolution2dFixture("VALID", 1){} +}; +BOOST_FIXTURE_TEST_CASE(ParseConv2DValid, Convolution2dValidFixture) +{ + RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10}); +} + + +struct Convolution2dStride2SameFixture : Convolution2dFixture +{ + Convolution2dStride2SameFixture() : Convolution2dFixture("SAME", 2){} +}; +BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Same, Convolution2dStride2SameFixture) +{ + RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13}); +} + + +struct Convolution2dStride2ValidFixture : Convolution2dFixture +{ + Convolution2dStride2ValidFixture() : Convolution2dFixture("VALID", 2){} +}; +BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Valid, Convolution2dStride2ValidFixture) +{ + RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16}); +} + + +struct Convolution2dDilation1Fixture : Convolution2dFixture +{ + Convolution2dDilation1Fixture() : Convolution2dFixture("SAME", 1, 1){} +}; +BOOST_FIXTURE_TEST_CASE(ParseConv2DDilation1, Convolution2dDilation1Fixture) +{ + RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f}); +} + +BOOST_AUTO_TEST_CASE(ParseConv2DDilation2) +{ + const char* prototext = "" + "node {\n" + " name: \"graphInput\"\n" + " op: \"Placeholder\"\n" + " attr {\n" + " key: \"dtype\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"shape\"\n" + " value {\n" + " shape {\n" + " }\n" + " }\n" + " }\n" + "}\n" + "node {\n" + " name: \"Const_1\"\n" + " op: \"Const\"\n" + " attr {\n" + " key: \"dtype\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"value\"\n" + " value {\n" + " tensor {\n" + " dtype: DT_FLOAT\n" + " tensor_shape {\n" + " dim {\n" + " size: 1\n" + " }\n" + " dim {\n" + " size: 3\n" + " }\n" + " dim {\n" + " size: 1\n" + " }\n" + " dim {\n" + " size: 1\n" + " }\n" + " }\n" + " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n" + " }\n" + " }\n" + " }\n" + "}\n" + "node {\n" + " name: \"potato\"\n" + " op: \"Conv2D\"\n" + " input: \"graphInput\"\n" + " input: \"Const_1\"\n" + " attr {\n" + " key: \"T\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"data_format\"\n" + " value {\n" + " s: \"NHWC\"\n" + " }\n" + " }\n" + " attr {\n" + " key: \"padding\"\n" + " value {\n" + " s: \"SAME\"\n" + " }\n" + " }\n" + " attr {\n" + " key: \"strides\"\n" + " value {\n" + " list {\n" + " i: 1\n" + " i: 1\n" + " i: 1\n" + " i: 1\n" + " }\n" + " }\n" + " }\n" + " attr {\n" + " key: \"dilations\"\n" + " value {\n" + " list {\n" + " i: 1\n" + " i: 2\n" + " i: 2\n" + " i: 1\n" + " }\n" + " }\n" + " }\n" + " attr {\n" + " key: \"use_cudnn_on_gpu\"\n" + " value {\n" + " b: false\n" + " }\n" + " }\n" + "}\n"; + + std::map inputShapes; + armnn::TensorShape tensorShape = { 1, 3, 3, 1 }; + inputShapes["graphInput"] = tensorShape; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_EXCEPTION(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }), + armnn::ParseException, + [] (armnn::ParseException const& ex)->bool + { + return strcmp(ex.what(), + "ArmNN only supports Convolution layers with dilations [1,1,1,1]") == 0; + }); +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/DepthwiseConvolution2d.cpp b/src/armnnTfParser/test/DepthwiseConvolution2d.cpp new file mode 100644 index 0000000000..84e7a7e7a9 --- /dev/null +++ b/src/armnnTfParser/test/DepthwiseConvolution2d.cpp @@ -0,0 +1,166 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" +#include +#include + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct DepthwiseConvolution2dFixture : public ParserPrototxtFixture +{ + explicit DepthwiseConvolution2dFixture(const char* paddingType) + { + m_Prototext = "node { \n" + " name: \"graphInput\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"value\" \n" + " value { \n" + " tensor { \n" + " dtype: DT_FLOAT \n" + " tensor_shape { \n" + " dim { \n" + " size: 1 \n" + " } \n" + " dim { \n" + " size: 1 \n" + " } \n" + " dim { \n" + " size: 3 \n" + " } \n" + " dim { \n" + " size: 3 \n" + " } \n" + " } \n" + " tensor_content: \"\\000\\000\\200?\\000\\000\\000@\\000\\000@@\\000\\000\\200@" + "\\000\\000\\240@\\000\\000\\300@\\000\\000\\340@\\000\\000\\000A\\000\\000\\020A\" \n" + " } \n" + " } \n" + " } \n" + " } \n" + " node { \n" + " name: \"Const_1\" \n" + " op: \"Const\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"value\" \n" + " value { \n" + " tensor { \n" + " dtype: DT_FLOAT \n" + " tensor_shape { \n" + " dim { \n" + " size: 1 \n" + " } \n" + " dim { \n" + " size: 3 \n" + " } \n" + " dim { \n" + " size: 3 \n" + " } \n" + " dim { \n" + " size: 3 \n" + " } \n" + " } \n" + " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?" + "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?" + "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?" + "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?" + "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?" + "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?" + "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?" + "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?" + "\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"potato\" \n" + " op: \"DepthwiseConv2dNative\" \n" + " input: \"graphInput\" \n" + " input: \"Const_1\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"data_format\" \n" + " value { \n" + " s: \"NHWC\" \n" + " } \n" + " } \n" + " attr { \n" + " key: \"padding\" \n" + " value { \n" + " s: \""; + m_Prototext.append(paddingType); + m_Prototext.append("\"\n" + " } \n" + " } \n" + " attr { \n" + " key: \"strides\" \n" + " value { \n" + " list { \n" + " i: 1 \n" + " i: 1 \n" + " i: 1 \n" + " i: 1 \n" + " } \n" + " } \n" + " } \n" + " attr { \n" + " key: \"use_cudnn_on_gpu\" \n" + " value { \n" + " b: false \n" + " } \n" + " } \n" + "} \n"); + + SetupSingleInputSingleOutput({ 1, 1, 3, 3 }, "graphInput", "potato"); + } +}; + +struct DepthwiseConvolution2dSameFixture : DepthwiseConvolution2dFixture +{ + DepthwiseConvolution2dSameFixture() : DepthwiseConvolution2dFixture("SAME") { } +}; + +BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DSame, DepthwiseConvolution2dSameFixture) +{ + RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, + { 2.5f, 5.f, 2.5f, 3.5f, 7.f, 3.5f, 4.5f, 9.f, 4.5f, + 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f, + 5.5f, 11.f, 5.5f, 6.5f, 13.f, 6.5f, 7.5f, 15.f, 7.5f}); +} + +struct DepthwiseConvolution2dValidFixture : DepthwiseConvolution2dFixture +{ + DepthwiseConvolution2dValidFixture() : DepthwiseConvolution2dFixture("VALID") { } +}; + +BOOST_FIXTURE_TEST_CASE(ParseDepthwiseConv2DValid, DepthwiseConvolution2dValidFixture) +{ + RunTest<4>({ 1, 2, 3, 4, 5, 6, 7, 8, 9 }, // input data + { 6.f, 12.f, 6.f, 7.5f, 15.f, 7.5f, 9.f, 18.f, 9.f }); // output expected data +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/FullyConnected.cpp b/src/armnnTfParser/test/FullyConnected.cpp new file mode 100644 index 0000000000..2a7b4951b7 --- /dev/null +++ b/src/armnnTfParser/test/FullyConnected.cpp @@ -0,0 +1,579 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" +#include "Runtime.hpp" +#include "Network.hpp" +#include "Graph.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +// In Tensorflow fully connected layers are expressed as a MatMul followed by an Add. +// The TfParser must detect this case and convert them to a FullyConnected layer. +struct FullyConnectedFixture : public ParserPrototxtFixture +{ + FullyConnectedFixture() + { + // input = tf.placeholder(tf.float32, [1, 1], "input") + // weights = tf.constant([2], tf.float32, [1, 1]) + // matmul = tf.matmul(input, weights) + // bias = tf.constant([1], tf.float32) + // output = tf.add(matmul, bias, name="output") + m_Prototext = R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + } + } +} +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 2.0 + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "input" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "output" + op: "Add" + input: "MatMul" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + SetupSingleInputSingleOutput({ 1, 1 }, "input", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(FullyConnected, FullyConnectedFixture) +{ + RunTest<1>({ 3 }, { 7 }); +} + +// Similar to FullyConnectedFixture, but this time the MatMul's output is used by two Adds. This should result +// in two FullyConnected layers being created. +// I +// | +// M -- C +// / \' +// C-- A A -- C +// \ / +// A +struct MatMulUsedInTwoFcFixture : public ParserPrototxtFixture +{ + MatMulUsedInTwoFcFixture() + { + m_Prototext = R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + } + } +} +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 2.0 + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "input" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 5.0 + } + } + } +} +node { + name: "Const_2" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 15.0 + } + } + } +} +node { + name: "Add" + op: "Add" + input: "MatMul" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "Add_1" + op: "Add" + input: "MatMul" + input: "Const_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "output" + op: "Add" + input: "Add" + input: "Add_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + SetupSingleInputSingleOutput({ 1, 1 }, "input", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(MatMulUsedInTwoFc, MatMulUsedInTwoFcFixture) +{ + RunTest<1>({ 3 }, { 32 }); + // Ideally we would check here that the armnn network has 5 layers: + // Input, 2 x FullyConnected (biased), Add and Output. + // This would make sure the parser hasn't incorrectly added some unconnected layers corresponding to the MatMul +} + +// Similar to MatMulUsedInTwoFc, but this time the Adds are 'staggered' (see diagram), which means that only one +// FullyConnected layer can be created (the other should just be an Add). +// I +// | +// M -- C1 +// / \' +// C2 -- A | +// \ / +// A +struct MatMulUsedInTwoFcStaggeredFixture : public ParserPrototxtFixture +{ + MatMulUsedInTwoFcStaggeredFixture() + { + // input = tf.placeholder(tf.float32, shape=[1,1], name = "input") + // const1 = tf.constant([17], tf.float32, [1,1]) + // mul = tf.matmul(input, const1) + // const2 = tf.constant([7], tf.float32, [1]) + // fc = tf.add(mul, const2) + // output = tf.add(mul, fc, name="output") + m_Prototext = R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + } + } +} +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 17.0 + } + } + } +} +node { + name: "MatMul" + op: "MatMul" + input: "input" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} +node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 7.0 + } + } + } +} +node { + name: "Add" + op: "Add" + input: "MatMul" + input: "Const_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "output" + op: "Add" + input: "MatMul" + input: "Add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + SetupSingleInputSingleOutput({ 1, 1 }, "input", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(MatMulUsedInTwoFcStaggered, MatMulUsedInTwoFcStaggeredFixture) +{ + RunTest<1>({ 2 }, { 75 }); + // Ideally we would check here that the armnn network has 5 layers: + // Input, FullyConnected (biased), FullyConnected (non biased), Add and Output. +} + +// A MatMul in isolation, not connected to an add. Should result in a non-biased FullyConnected layer. +struct MatMulFixture : public ParserPrototxtFixture +{ + MatMulFixture() + { + // input = tf.placeholder(tf.float32, shape = [1, 1], name = "input") + // const = tf.constant([17], tf.float32, [1, 1]) + // output = tf.matmul(input, const, name = "output") + m_Prototext = R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + } + } +} +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 17.0 + } + } + } +} +node { + name: "output" + op: "MatMul" + input: "input" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } +} + )"; + SetupSingleInputSingleOutput({ 1, 1 }, "input", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(MatMul, MatMulFixture) +{ + RunTest<1>({ 2 }, { 34 }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/FusedBatchNorm.cpp b/src/armnnTfParser/test/FusedBatchNorm.cpp new file mode 100644 index 0000000000..632d5f01f9 --- /dev/null +++ b/src/armnnTfParser/test/FusedBatchNorm.cpp @@ -0,0 +1,175 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct FusedBatchNormFixture : public ParserPrototxtFixture +{ + FusedBatchNormFixture() + { + m_Prototext = "node { \n" + " name: \"graphInput\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"shape\" \n" + " value { \n" + " shape { \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"Const_1\" \n" + " op: \"Const\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"value\" \n" + " value { \n" + " tensor { \n" + " dtype: DT_FLOAT \n" + " tensor_shape { \n" + " dim { \n" + " size: 1 \n" + " } \n" + " } \n" + " float_val: 1.0 \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"Const_2\" \n" + " op: \"Const\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"value\" \n" + " value { \n" + " tensor { \n" + " dtype: DT_FLOAT \n" + " tensor_shape { \n" + " dim { \n" + " size: 1 \n" + " } \n" + " } \n" + " float_val: 0.0 \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"FusedBatchNormLayer/mean\" \n" + " op: \"Const\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"value\" \n" + " value { \n" + " tensor { \n" + " dtype: DT_FLOAT \n" + " tensor_shape { \n" + " dim { \n" + " size: 1 \n" + " } \n" + " } \n" + " float_val: 5.0 \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"FusedBatchNormLayer/variance\" \n" + " op: \"Const\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"value\" \n" + " value { \n" + " tensor { \n" + " dtype: DT_FLOAT \n" + " tensor_shape { \n" + " dim { \n" + " size: 1 \n" + " } \n" + " } \n" + " float_val: 2.0 \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"output\" \n" + " op: \"FusedBatchNorm\" \n" + " input: \"graphInput\" \n" + " input: \"Const_1\" \n" + " input: \"Const_2\" \n" + " input: \"FusedBatchNormLayer/mean\" \n" + " input: \"FusedBatchNormLayer/variance\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"data_format\" \n" + " value { \n" + " s: \"NHWC\" \n" + " } \n" + " } \n" + " attr { \n" + " key: \"epsilon\" \n" + " value { \n" + " f: 0.0010000000475 \n" + " } \n" + " } \n" + " attr { \n" + " key: \"is_training\" \n" + " value { \n" + " b: false \n" + " } \n" + " } \n" + "} \n"; + + SetupSingleInputSingleOutput({1, 3, 3, 1}, "graphInput", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseFusedBatchNorm, FusedBatchNormFixture) +{ + RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, // input data + {-2.8277204f, -2.12079024f, -1.4138602f, + -0.7069301f, 0.0f, 0.7069301f, + 1.4138602f, 2.12079024f, 2.8277204f}); // expected output data +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Identity.cpp b/src/armnnTfParser/test/Identity.cpp new file mode 100644 index 0000000000..ca20de5760 --- /dev/null +++ b/src/armnnTfParser/test/Identity.cpp @@ -0,0 +1,161 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct IdentitySimpleFixture : public ParserPrototxtFixture +{ + IdentitySimpleFixture() + { + m_Prototext = "node{ " + " name: \"Placeholder\"" + " op: \"Placeholder\"" + " attr {" + " key: \"dtype\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: \"shape\"" + " value {" + " shape {" + " unknown_rank: true" + " }" + " }" + " }" + "}" + "node {" + " name: \"Identity\"" + " op: \"Identity\"" + " input: \"Placeholder\"" + " attr {" + " key: \"T\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + "}"; + SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Identity"); + } +}; + +BOOST_FIXTURE_TEST_CASE(IdentitySimple, IdentitySimpleFixture) +{ + RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +struct IdentityFixture : public ParserPrototxtFixture +{ + IdentityFixture() + { + m_Prototext = "node{ " + " name: \"Placeholder\"" + " op: \"Placeholder\"" + " attr {" + " key: \"dtype\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: \"shape\"" + " value {" + " shape {" + " unknown_rank: true" + " }" + " }" + " }" + "}" + "node {" + " name: \"Identity\"" + " op: \"Identity\"" + " input: \"Placeholder\"" + " attr {" + " key: \"T\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + "}" + "node {" + " name: \"Add\"" + " op: \"Add\"" + " input: \"Identity\"" + " input: \"Identity\"" + " attr {" + " key: \"T\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + "}"; + SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Add"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseIdentity, IdentityFixture) +{ + RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 2.0f, 4.0f, 6.0f, 8.0f }); +} + +struct IdentityChainFixture : public ParserPrototxtFixture +{ + IdentityChainFixture() + { + m_Prototext = "node{ " + " name: \"Placeholder\"" + " op: \"Placeholder\"" + " attr {" + " key: \"dtype\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: \"shape\"" + " value {" + " shape {" + " unknown_rank: true" + " }" + " }" + " }" + "}" + "node {" + " name: \"Identity\"" + " op: \"Identity\"" + " input: \"Placeholder\"" + " attr {" + " key: \"T\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + "}" + "node {" + " name: \"Identity2\"" + " op: \"Identity\"" + " input: \"Identity\"" + " attr {" + " key: \"T\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + "}"; + SetupSingleInputSingleOutput({ 4 }, "Placeholder", "Identity2"); + } +}; + +BOOST_FIXTURE_TEST_CASE(IdentityChain, IdentityChainFixture) +{ + RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/LocalResponseNormalization.cpp b/src/armnnTfParser/test/LocalResponseNormalization.cpp new file mode 100644 index 0000000000..a7c2bfe3e1 --- /dev/null +++ b/src/armnnTfParser/test/LocalResponseNormalization.cpp @@ -0,0 +1,121 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + + +struct LocalResponseNormalizationBaseFixture : public ParserPrototxtFixture +{ + explicit LocalResponseNormalizationBaseFixture(float alpha, float beta, float bias) + { + std::string alphaString = std::to_string(alpha); + std::string betaString = std::to_string(beta); + std::string biasString = std::to_string(bias); + + m_Prototext = "node {" + " name: \"Placeholder\"" + " op: \"Placeholder\"" + " attr {" + " key: \"dtype\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: \"shape\"" + " value {" + " shape {" + " unknown_rank: true" + " }" + " }" + " }" + "}" + "node {" + " name: \"LRN\"" + " op: \"LRN\"" + " input: \"Placeholder\"" + " attr {" + " key: \"T\"" + " value {" + " type: DT_FLOAT" + " }" + " }" + " attr {" + " key: \"alpha\"" + " value {" + " f: "; + m_Prototext.append(alphaString); + m_Prototext.append("\n" + " }" + " }" + " attr {" + " key: \"beta\"" + " value {" + " f: "); + m_Prototext.append(betaString); + m_Prototext.append("\n" + " }" + " }" + " attr {" + " key: \"bias\"" + " value {" + " f: "); + m_Prototext.append(biasString); + m_Prototext.append("\n" + " }" + " }" + " attr {" + " key: \"depth_radius\"" + " value {" + " i: 1" + " }" + " }" + "}"); + } +}; + + +struct LocalResponseNormalizationFixtureSimple : public LocalResponseNormalizationBaseFixture +{ + explicit LocalResponseNormalizationFixtureSimple() + : LocalResponseNormalizationBaseFixture(1.0f, 1.0f, 1.0f) + { + SetupSingleInputSingleOutput({ 2, 2, 2, 1 }, "Placeholder", "LRN"); + } +}; +BOOST_FIXTURE_TEST_CASE(ParseSimpleLocalResponseNormalization, LocalResponseNormalizationFixtureSimple) +{ + RunTest<4>({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, + { 0.5f, 0.4f, 0.3f, 0.23529412f, 0.1923077f, 0.16216217f, 0.14f, 0.12307692f }); +} + + +struct LocalResponseNormalizationFixture : public LocalResponseNormalizationBaseFixture +{ + explicit LocalResponseNormalizationFixture() + : LocalResponseNormalizationBaseFixture(0.5f, 1.0f, 0.5f) + { + SetupSingleInputSingleOutput({1, 3, 3, 2}, "Placeholder", "LRN"); + } +}; +BOOST_FIXTURE_TEST_CASE(ParseLocalResponseNormalization, LocalResponseNormalizationFixture) +{ + RunTest<4>({ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, + 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f}, + + {0.333333340f, 0.66666670f, 0.230769250f, 0.307692320f, 0.161290320f, 0.19354838f, + 0.122807020f, 0.14035088f, 0.098901100f, 0.109890110f, 0.082706770f, 0.09022556f, + 0.071038246f, 0.07650273f, 0.062240668f, 0.066390045f, 0.055374593f, 0.05863192f}); +} + + + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/MultiOutput.cpp b/src/armnnTfParser/test/MultiOutput.cpp new file mode 100644 index 0000000000..56be33dab7 --- /dev/null +++ b/src/armnnTfParser/test/MultiOutput.cpp @@ -0,0 +1,144 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct MultiOutMatchFixture : public ParserPrototxtFixture +{ + MultiOutMatchFixture() + { + m_Prototext = R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "softmax1" + op: "Softmax" + input: "input:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1"); + } +}; + +BOOST_FIXTURE_TEST_CASE(MultiOutMatch, MultiOutMatchFixture) +{ + // Note that the point of this test is to verify the parsing went well. + // Here we make sure the softmax has really connected to the input layer. + RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 }); +} + +struct MultiOutFailFixture : public ParserPrototxtFixture +{ + MultiOutFailFixture() + { + m_Prototext = R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "softmax1" + op: "Softmax" + input: "input:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + BOOST_CHECK_THROW(SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1"), armnn::ParseException); + } +}; + +BOOST_FIXTURE_TEST_CASE(MultiOutFail, MultiOutFailFixture) +{ + // Not running the graph because this is expected to throw an exception during parsing. +} + +struct MultiOutInvalidFixture : public ParserPrototxtFixture +{ + MultiOutInvalidFixture() + { + m_Prototext = R"( +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "softmax1" + op: "Softmax" + input: "input:-1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + BOOST_CHECK_THROW(SetupSingleInputSingleOutput({ 1, 7 }, "input", "softmax1"), armnn::ParseException); + } +}; + +BOOST_FIXTURE_TEST_CASE(MultiOutInvalid, MultiOutInvalidFixture) +{ + // Not running the graph because this is expected to throw an exception during parsing. +} + + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/src/armnnTfParser/test/Multiplication.cpp b/src/armnnTfParser/test/Multiplication.cpp new file mode 100644 index 0000000000..3a20fd1141 --- /dev/null +++ b/src/armnnTfParser/test/Multiplication.cpp @@ -0,0 +1,172 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct MultiplicationFixture : public ParserPrototxtFixture +{ + MultiplicationFixture() + { + m_Prototext = "node { \n" + " name: \"graphInput\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"shape\" \n" + " value { \n" + " shape { \n" + " } \n" + " } \n" + " } \n" + " } \n" + " node { \n" + " name: \"softmax1\" \n" + " op: \"Softmax\" \n" + " input: \"graphInput\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " }\n" + " node {\n" + " name: \"softmax2\"\n" + " op : \"Softmax\"\n" + " input: \"graphInput\"\n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " }\n" + " node {\n" + " name: \"multiplication\"\n" + " op : \"Mul\"\n" + " input: \"softmax1\"\n" + " input: \"softmax2\"\n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " }\n"; + + SetupSingleInputSingleOutput({ 1, 7 }, "graphInput", "multiplication"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseMultiplication, MultiplicationFixture) +{ + RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 }); +} + +struct MultiplicationBroadcastFixture : public ParserPrototxtFixture +{ + MultiplicationBroadcastFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1) + { + m_Prototext = R"( +node { + name: "input0" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "output" + op: "Mul" + input: "input0" + input: "input1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + + Setup({ { "input0", inputShape0 }, + { "input1", inputShape1 } }, + { "output" }); + } +}; + +struct MultiplicationBroadcastFixture4D1D : public MultiplicationBroadcastFixture +{ + MultiplicationBroadcastFixture4D1D() : MultiplicationBroadcastFixture({ 1, 2, 2, 3 }, { 1 }) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast4D1D, MultiplicationBroadcastFixture4D1D) +{ + RunTest<4>({ { "input0", { 0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f } }, + { "input1", { 5.0f } } }, + { { "output", { 0.0f, 5.0f, 10.0f, + 15.0f, 20.0f, 25.0f, + 30.0f, 35.0f, 40.0f, + 45.0f, 50.0f, 55.0f } } }); +} + +struct MultiplicationBroadcastFixture1D4D : public MultiplicationBroadcastFixture +{ + MultiplicationBroadcastFixture1D4D() : MultiplicationBroadcastFixture({ 1 }, { 1, 2, 2, 3 }) {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseMultiplicationBroadcast1D4D, MultiplicationBroadcastFixture1D4D) +{ + RunTest<4>({ { "input0", { 3.0f } }, + { "input1", { 0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f, + 9.0f, 10.0f, 11.0f } } }, + { { "output", { 0.0f, 3.0f, 6.0f, + 9.0f, 12.0f, 15.0f, + 18.0f, 21.0f, 24.0f, + 27.0f, 30.0f, 33.0f } } }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/PassThru.cpp b/src/armnnTfParser/test/PassThru.cpp new file mode 100644 index 0000000000..8462ec27cc --- /dev/null +++ b/src/armnnTfParser/test/PassThru.cpp @@ -0,0 +1,52 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct PassThruFixture : public ParserPrototxtFixture +{ + PassThruFixture() + { + m_Prototext = "node {\n" + " name: \"Placeholder\"\n" + " op: \"Placeholder\"\n" + " attr {\n" + " key: \"dtype\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"shape\"\n" + " value {\n" + " shape {\n" + " }\n" + " }\n" + " }\n" + "}\n"; + SetupSingleInputSingleOutput({ 1, 7 }, "Placeholder", "Placeholder"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ValidateOutput, PassThruFixture) +{ + BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetNumDimensions() == 2); + BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetShape()[0] == 1); + BOOST_TEST(m_Parser->GetNetworkOutputBindingInfo("Placeholder").second.GetShape()[1] == 7); +} + +BOOST_FIXTURE_TEST_CASE(RunGraph, PassThruFixture) +{ + armnn::TensorInfo inputTensorInfo = m_Parser->GetNetworkInputBindingInfo("Placeholder").second; + auto input = MakeRandomTensor(inputTensorInfo, 378346); + std::vector inputVec; + inputVec.assign(input.data(), input.data() + input.num_elements()); + RunTest<2>(inputVec, inputVec); // The passthru network should output the same as the input +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Pooling.cpp b/src/armnnTfParser/test/Pooling.cpp new file mode 100644 index 0000000000..36ffa47def --- /dev/null +++ b/src/armnnTfParser/test/Pooling.cpp @@ -0,0 +1,112 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + + +struct Pooling2dFixture : public ParserPrototxtFixture +{ + explicit Pooling2dFixture(const char* poolingtype) + { + m_Prototext = "node {\n" + " name: \"Placeholder\"\n" + " op: \"Placeholder\"\n" + " attr {\n" + " key: \"dtype\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"value\"\n" + " value {\n" + " tensor {\n" + " dtype: DT_FLOAT\n" + " tensor_shape {\n" + " }\n" + " }\n" + " }\n" + " }\n" + " }\n" + "node {\n" + " name: \""; + m_Prototext.append(poolingtype); + m_Prototext.append("\"\n" + " op: \""); + m_Prototext.append(poolingtype); + m_Prototext.append("\"\n" + " input: \"Placeholder\"\n" + " attr {\n" + " key: \"T\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"data_format\"\n" + " value {\n" + " s: \"NHWC\"\n" + " }\n" + " }\n" + " attr {\n" + " key: \"ksize\"\n" + " value {\n" + " list {\n" + " i: 1\n" + " i: 2\n" + " i: 2\n" + " i: 1\n" + " }\n" + " }\n" + " }\n" + " attr {\n" + " key: \"padding\"\n" + " value {\n" + " s: \"VALID\"\n" + " }\n" + " }\n" + " attr {\n" + " key: \"strides\"\n" + " value {\n" + " list {\n" + " i: 1\n" + " i: 1\n" + " i: 1\n" + " i: 1\n" + " }\n" + " }\n" + " }\n" + "}\n"); + + SetupSingleInputSingleOutput({ 1, 2, 2, 1 }, "Placeholder", poolingtype); + } +}; + + +struct MaxPoolFixture : Pooling2dFixture +{ + MaxPoolFixture() : Pooling2dFixture("MaxPool") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseMaxPool, MaxPoolFixture) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, -4.0f}, {3.0f}); +} + + +struct AvgPoolFixture : Pooling2dFixture +{ + AvgPoolFixture() : Pooling2dFixture("AvgPool") {} +}; +BOOST_FIXTURE_TEST_CASE(ParseAvgPool, AvgPoolFixture) +{ + RunTest<4>({1.0f, 2.0f, 3.0f, 4.0f}, {2.5f}); +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Reshape.cpp b/src/armnnTfParser/test/Reshape.cpp new file mode 100644 index 0000000000..4eb6b12467 --- /dev/null +++ b/src/armnnTfParser/test/Reshape.cpp @@ -0,0 +1,86 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + + +struct ReshapeFixture : public ParserPrototxtFixture +{ + ReshapeFixture() + { + m_Prototext = "node { \n" + " name: \"graphInput\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"shape\" \n" + " value { \n" + " shape { \n" + " } \n" + " } \n" + " } \n" + " } \n" + "node { \n" + " name: \"Reshape/shape\" \n" + " op: \"Const\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_INT32 \n" + " } \n" + " } \n" + " attr { \n" + " key: \"value\" \n" + " value { \n" + " tensor { \n" + " dtype: DT_INT32 \n" + " tensor_shape { \n" + " dim { \n" + " size: 2 \n" + " } \n" + " } \n" + " tensor_content: \"\\002\\000\\000\\000\\002\\000\\000\\000\" \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"Reshape\" \n" + " op: \"Reshape\" \n" + " input: \"graphInput\" \n" + " input: \"Reshape/shape\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"Tshape\" \n" + " value { \n" + " type: DT_INT32 \n" + " } \n" + " } \n" + "} \n"; + + SetupSingleInputSingleOutput({1, 4}, "graphInput", "Reshape"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseReshape, ReshapeFixture) +{ + RunTest<2>({ 0.0f, 1.0f, 2.0f, 3.0f }, { 0.0f, 1.0f, 2.0f, 3.0f }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/ResizeBilinear.cpp b/src/armnnTfParser/test/ResizeBilinear.cpp new file mode 100644 index 0000000000..30d898f5bb --- /dev/null +++ b/src/armnnTfParser/test/ResizeBilinear.cpp @@ -0,0 +1,114 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct ResizeBilinearFixture : public ParserPrototxtFixture +{ + ResizeBilinearFixture() + { + m_Prototext = R"( +node { + name: "graphInput" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 1 + } + } + tensor_content: +"\000\000\000\000\000\000\200?\000\000\000@\000\000@@\000\000\200@\000\000\240@\000\000\300@\000\000\340@\000\000\000A" + } + } + } +} +node { + name: "resizeBilinearLayer/size" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\005\000\000\000\005\000\000\000" + } + } + } +} +node { + name: "resizeBilinearLayer" + op: "ResizeBilinear" + input: "graphInput" + input: "resizeBilinearLayer/size" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "align_corners" + value { + b: false + } + } +} + )"; + + SetupSingleInputSingleOutput({ 1, 3, 3, 1 }, "graphInput", "resizeBilinearLayer"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseResizeBilinear, ResizeBilinearFixture) +{ + RunTest<4>(// input data + { 0.0f, 1.0f, 2.0f, + 3.0f, 4.0f, 5.0f, + 6.0f, 7.0f, 8.0f }, + // expected output data + { 0.0f, 0.6f, 1.2f, 1.8f, 2.0f, + 1.8f, 2.4f, 3.0f, 3.6f, 3.8f, + 3.6f, 4.2f, 4.8f, 5.4f, 5.6f, + 5.4f, 6.0f, 6.6f, 7.2f, 7.4f, + 6.0f, 6.6f, 7.2f, 7.8f, 8.0f }); + +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Shape.cpp b/src/armnnTfParser/test/Shape.cpp new file mode 100644 index 0000000000..7b414ecfac --- /dev/null +++ b/src/armnnTfParser/test/Shape.cpp @@ -0,0 +1,94 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct ShapeFixture : public ParserPrototxtFixture +{ + ShapeFixture() + { + m_Prototext = + "node { \n" + " name: \"Placeholder\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"shape\" \n" + " value { \n" + " shape { \n" + " dim { \n" + " size: 1 \n" + " } \n" + " dim { \n" + " size: 1 \n" + " } \n" + " dim { \n" + " size: 1 \n" + " } \n" + " dim { \n" + " size: 4 \n" + " } \n" + " } \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"shapeTest\" \n" + " op: \"Shape\" \n" + " input: \"Placeholder\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"out_type\" \n" + " value { \n" + " type: DT_INT32 \n" + " } \n" + " } \n" + "} \n" + "node { \n" + " name: \"Reshape\" \n" + " op: \"Reshape\" \n" + " input: \"Placeholder\" \n" + " input: \"shapeTest\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"Tshape\" \n" + " value { \n" + " type: DT_INT32 \n" + " } \n" + " } \n" + "} \n"; + + SetupSingleInputSingleOutput({1, 4}, "Placeholder", "Reshape"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseShape, ShapeFixture) +{ + // Note: the test's output cannot be an int32 const layer, because that cannot exist in the + // as ARMNN only supports u8 and float layers. For that reason I added a reshape layer + // which reshapes the input to its original dimensions, which is not changing it. + RunTest<2>({ 0.0f, 1.0f, 2.0f, 3.0f }, { 0.0f, 1.0f, 2.0f, 3.0f }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Softmax.cpp b/src/armnnTfParser/test/Softmax.cpp new file mode 100644 index 0000000000..1ab28ea3aa --- /dev/null +++ b/src/armnnTfParser/test/Softmax.cpp @@ -0,0 +1,55 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct SoftmaxFixture : public ParserPrototxtFixture +{ + SoftmaxFixture() + { + m_Prototext = "node {\n" + " name: \"blah\"\n" + " op: \"Placeholder\"\n" + " attr {\n" + " key: \"dtype\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + " attr {\n" + " key: \"shape\"\n" + " value {\n" + " shape {\n" + " }\n" + " }\n" + " }\n" + "}\n" + "node {\n" + " name: \"blah2\"\n" + " op: \"Softmax\"\n" + " input: \"blah\"\n" + " attr {\n" + " key: \"T\"\n" + " value {\n" + " type: DT_FLOAT\n" + " }\n" + " }\n" + "}\n"; + + SetupSingleInputSingleOutput({ 1, 7 }, "blah", "blah2"); + } +}; + +BOOST_FIXTURE_TEST_CASE(ParseSoftmax, SoftmaxFixture) +{ + RunTest<2>({ 0, 0, 10000, 0, 0, 0, 0 }, { 0, 0, 1, 0, 0, 0, 0 }); +} + + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/Squeeze.cpp b/src/armnnTfParser/test/Squeeze.cpp new file mode 100644 index 0000000000..d2d7d49494 --- /dev/null +++ b/src/armnnTfParser/test/Squeeze.cpp @@ -0,0 +1,108 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + + +template +struct SqueezeFixture : public ParserPrototxtFixture +{ + SqueezeFixture() + { + m_Prototext = + "node { \n" + " name: \"graphInput\" \n" + " op: \"Placeholder\" \n" + " attr { \n" + " key: \"dtype\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"shape\" \n" + " value { \n" + " shape { \n" + " } \n" + " } \n" + " } \n" + " } \n" + "node { \n" + " name: \"Squeeze\" \n" + " op: \"Squeeze\" \n" + " input: \"graphInput\" \n" + " attr { \n" + " key: \"T\" \n" + " value { \n" + " type: DT_FLOAT \n" + " } \n" + " } \n" + " attr { \n" + " key: \"squeeze_dims\" \n" + " value { \n" + " list {\n"; + + if (withDimZero) + { + m_Prototext += "i:0\n"; + } + + if (withDimOne) + { + m_Prototext += "i:1\n"; + } + + m_Prototext += + " } \n" + " } \n" + " } \n" + "} \n"; + + SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "graphInput", "Squeeze"); + } +}; + +typedef SqueezeFixture ImpliedDimensionsSqueezeFixture; +typedef SqueezeFixture ExplicitDimensionZeroSqueezeFixture; +typedef SqueezeFixture ExplicitDimensionOneSqueezeFixture; +typedef SqueezeFixture ExplicitDimensionsSqueezeFixture; + +BOOST_FIXTURE_TEST_CASE(ParseImplicitSqueeze, ImpliedDimensionsSqueezeFixture) +{ + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() == + armnn::TensorShape({2,2}))); + RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f }, + { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_FIXTURE_TEST_CASE(ParseDimensionZeroSqueeze, ExplicitDimensionZeroSqueezeFixture) +{ + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() == + armnn::TensorShape({1,2,2}))); + RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f }, + { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_FIXTURE_TEST_CASE(ParseDimensionOneSqueeze, ExplicitDimensionOneSqueezeFixture) +{ + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() == + armnn::TensorShape({1,2,2}))); + RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f }, + { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_FIXTURE_TEST_CASE(ParseExplicitDimensionsSqueeze, ExplicitDimensionsSqueezeFixture) +{ + BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() == + armnn::TensorShape({2,2}))); + RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f }, + { 1.0f, 2.0f, 3.0f, 4.0f }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/TestDependencies.cpp b/src/armnnTfParser/test/TestDependencies.cpp new file mode 100644 index 0000000000..13ab17c5b6 --- /dev/null +++ b/src/armnnTfParser/test/TestDependencies.cpp @@ -0,0 +1,296 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +// Graph which tests that nodes are re-ordered in the queue when they are encountered a second time. +// In this case R0 will be encountered first via R1 and then via R2. At that time +// we need to make sure that R0 (and the I on which it is dependent) is moved to the front again +// so that it is before both R1 and R2. +// I +// | +// R0 +// / \' +// R1 R2 +// \ | +// \ R3 +// \| +// O +struct RediscoveredDependenciesFixture : public ParserPrototxtFixture +{ + RediscoveredDependenciesFixture() + { + // input = tf.placeholder(tf.float32, 1, "input") + // relu0 = tf.nn.relu(input, "relu0") + // relu1 = tf.nn.relu(relu0, "relu1") + // relu2 = tf.nn.relu(relu0, "relu2") + // relu3 = tf.nn.relu(relu2, "relu3") + // output = tf.add(relu1, relu3, "output") + m_Prototext = R"( + node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + } + } + } + } + node { + name: "relu0" + op: "Relu" + input: "input" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu1" + op: "Relu" + input: "relu0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu2" + op: "Relu" + input: "relu0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu3" + op: "Relu" + input: "relu2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "output" + op: "Add" + input: "relu1" + input: "relu3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + )"; + SetupSingleInputSingleOutput({ 1 }, "input", "output"); + } +}; + +BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture) +{ + RunTest<1>({1}, {2}); +} + +// Tests that a simple cycle in the tensorflow graph will be detected and an exception thrown, rather than the TfParser +// getting stuck in an infinite loop. +BOOST_AUTO_TEST_CASE(SimpleCycle) +{ + const char* prototext = R"( +node { + name: "r1" + op: "Relu" + input: "r2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "r2" + op: "Relu" + input: "r1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException); +} + +// Similar to the above SimpleCycle test, but has a single node which connects to itself. +BOOST_AUTO_TEST_CASE(SingleNodeCycle) +{ + const char* prototext = R"( +node { + name: "r1" + op: "Relu" + input: "r1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException); +} + +// Similar to the above SimpleCycle test, but with a more complicated graph. +// I +// | +// A2---<---<- +// / \' | +// R1 R2 | +// \ | | +// \ R3 | +// \| | +// A1-->--->| +// +BOOST_AUTO_TEST_CASE(ComplexCycle) +{ + // input = tf.placeholder(tf.float32, 1, "input") + // add2 = tf.nn.relu(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined + // relu1 = tf.nn.relu(relu0, "relu1") + // relu2 = tf.nn.relu(relu0, "relu2") + // relu3 = tf.nn.relu(relu2, "relu3") + // add1 = tf.add(relu1, relu3, "add1") + const char* prototext = R"( + node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 1 + } + } + } + } + } + node { + name: "add2" + op: "Add" + input: "input" + input: "add1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu1" + op: "Relu" + input: "add2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu2" + op: "Relu" + input: "add2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "relu3" + op: "Relu" + input: "relu2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node { + name: "add1" + op: "Add" + input: "relu1" + input: "relu3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + )"; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException); +} + +// Tests that a graph with an input that is not present throws a ParseException. +BOOST_AUTO_TEST_CASE(InvalidInput) +{ + const char* prototext = R"( +node { + name: "r1" + op: "Relu" + input: "a-node-that-does-not-exist" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create(); + BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnTfParser/test/TestMultiInputsOutputs.cpp b/src/armnnTfParser/test/TestMultiInputsOutputs.cpp new file mode 100644 index 0000000000..5eea616ec8 --- /dev/null +++ b/src/armnnTfParser/test/TestMultiInputsOutputs.cpp @@ -0,0 +1,92 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct MultiInputsOutputsFixture : public ParserPrototxtFixture +{ + MultiInputsOutputsFixture() + { + // input1 = tf.placeholder(tf.float32, shape=[], name = "input1") + // input2 = tf.placeholder(tf.float32, shape = [], name = "input2") + // add1 = tf.add(input1, input2, name = "add1") + // add2 = tf.add(input1, input2, name = "add2") + m_Prototext = R"( +node { + name: "input1" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "input2" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "add1" + op: "Add" + input: "input1" + input: "input2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "add2" + op: "Add" + input: "input1" + input: "input2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} + )"; + Setup({ { "input1", { 1 } }, + { "input2", { 1 } } }, + { "add1", "add2" }); + } +}; + +BOOST_FIXTURE_TEST_CASE(MultiInputsOutputs, MultiInputsOutputsFixture) +{ + RunTest<1>({ { "input1", {12.0f} }, { "input2", { 13.0f } } }, + { { "add1", { 25.0f } }, { "add2", { 25.0f } } }); +} + +BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1