From 616838074de39b46c3e91ac07a7d0956c7cd40e8 Mon Sep 17 00:00:00 2001 From: David Monahan Date: Tue, 12 Jan 2021 09:11:07 +0000 Subject: IVGCVSW-5424 TFLite parser not parsing new TransposeConv * Added support for bias vector in TransposeConv to TfLite parser * Added UnitTest Signed-off-by: David Monahan Change-Id: I6483986a9da9216084b4b885b7fd980fc3580fa9 --- src/armnnTfLiteParser/TfLiteParser.cpp | 31 +++++- src/armnnTfLiteParser/test/TransposeConv.cpp | 139 +++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 5 deletions(-) (limited to 'src') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 9d9f4fa14b..8e0fae68d1 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1083,7 +1083,14 @@ void TfLiteParser::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex desc.m_DataLayout = armnn::DataLayout::NHWC; auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); - CHECK_VALID_SIZE(inputs.size(), 3); + if (inputs.size() == 4) + { + desc.m_BiasEnabled = true; + } + else + { + CHECK_VALID_SIZE(inputs.size(), 3); + } auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); @@ -1143,10 +1150,24 @@ void TfLiteParser::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex armnn::IConnectableLayer* layer = nullptr; auto layerName = fmt::format("TransposeConv:{}:{}", subgraphIndex, operatorIndex); - layer = m_Network->AddTransposeConvolution2dLayer(desc, - filterTensorAndData.first, - EmptyOptional(), - layerName.c_str()); + if (desc.m_BiasEnabled) + { + auto biasTensorInfo = ToTensorInfo(inputs[3]); + auto biasConstTensor = CreateConstTensor(inputs[3], + biasTensorInfo, + armnn::Optional()); + layer = m_Network->AddTransposeConvolution2dLayer(desc, + filterTensorAndData.first, + biasConstTensor.first, + layerName.c_str()); + } + else + { + layer = m_Network->AddTransposeConvolution2dLayer(desc, + filterTensorAndData.first, + EmptyOptional(), + layerName.c_str()); + } ARMNN_ASSERT(layer != nullptr); diff --git a/src/armnnTfLiteParser/test/TransposeConv.cpp b/src/armnnTfLiteParser/test/TransposeConv.cpp index 084a286dbd..94e42438e1 100644 --- a/src/armnnTfLiteParser/test/TransposeConv.cpp +++ b/src/armnnTfLiteParser/test/TransposeConv.cpp @@ -131,4 +131,143 @@ BOOST_FIXTURE_TEST_CASE( ParseSimpleTransposeConv, SimpleTransposeConvFixture ) }); } +struct TransposeConvFixtureWithBias : public ParserFlatbuffersFixture +{ + explicit TransposeConvFixtureWithBias(const std::string& inputShape, + const std::string& outputShape, + const std::string& filterShape, + const std::string& filterData, + const std::string& strideX, + const std::string& strideY, + const std::string& dataType, + const std::string& biasShape, + const std::string& biasData) + { + m_JsonString = R"( + { + "version": 3, + "operator_codes": [ { "builtin_code": "TRANSPOSE_CONV" } ], + "subgraphs": [ { + "tensors": [ + { + "shape": [ 4 ], + "type": "UINT8", + "buffer": 0, + "name": "outputShapeTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": )" + filterShape + R"(, + "type": ")" + dataType + R"(", + "buffer": 1, + "name": "filterTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": )" + inputShape + R"(, + "type": ")" + dataType + R"(", + "buffer": 2, + "name": "inputTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": )" + biasShape + R"( , + "type": "INT32", + "buffer": 3, + "name": "biasTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": )" + outputShape + R"(, + "type": ")" + dataType + R"(", + "buffer": 4, + "name": "outputTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + } + ], + "inputs": [ 2 ], + "outputs": [ 4 ], + "operators": [ + { + "opcode_index": 0, + "inputs": [ 0, 1, 2, 3], + "outputs": [ 4 ], + "builtin_options_type": "TransposeConvOptions", + "builtin_options": { + "padding": "VALID", + "stride_w": )" + strideX + R"(, + "stride_h": )" + strideY + R"( + }, + "custom_options_format": "FLEXBUFFERS" + } + ], + } ], + "buffers" : [ + { "data": )" + outputShape + R"( }, + { "data": )" + filterData + R"( }, + { }, + { "data": )" + biasData + R"( }, + { } + ] + } + )"; + SetupSingleInputSingleOutput("inputTensor", "outputTensor"); + } +}; + +struct SimpleTransposeConvFixtureWithBias : TransposeConvFixtureWithBias +{ + SimpleTransposeConvFixtureWithBias() + : TransposeConvFixtureWithBias("[ 1, 2, 2, 1 ]", // inputShape + "[ 1, 3, 3, 1 ]", // outputShape + "[ 1, 2, 2, 1 ]", // filterShape + "[ 0, 1, 2, 4 ]", // filterData + "1", // strideX + "1", // strideY + "UINT8", // dataType + "[ 1 ]", // bias shape + "[ 10, 0, 0, 0 ]") // bias data + {} +}; + +BOOST_FIXTURE_TEST_CASE( ParseSimpleTransposeConvWithBias, SimpleTransposeConvFixtureWithBias ) +{ + RunTest<4, armnn::DataType::QAsymmU8>( + 0, + { + 1, 2, + 3, 4 + }, + { + 10, 11, 12, + 12, 21, 22, + 16, 30, 26 + }); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1