diff options
Diffstat (limited to 'src/armnnTfLiteParser/test/TransposeConv.cpp')
-rw-r--r-- | src/armnnTfLiteParser/test/TransposeConv.cpp | 139 |
1 files changed, 139 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/test/TransposeConv.cpp b/src/armnnTfLiteParser/test/TransposeConv.cpp index 084a286dbd..94e42438e1 100644 --- a/src/armnnTfLiteParser/test/TransposeConv.cpp +++ b/src/armnnTfLiteParser/test/TransposeConv.cpp @@ -131,4 +131,143 @@ BOOST_FIXTURE_TEST_CASE( ParseSimpleTransposeConv, SimpleTransposeConvFixture ) }); } +struct TransposeConvFixtureWithBias : public ParserFlatbuffersFixture +{ + explicit TransposeConvFixtureWithBias(const std::string& inputShape, + const std::string& outputShape, + const std::string& filterShape, + const std::string& filterData, + const std::string& strideX, + const std::string& strideY, + const std::string& dataType, + const std::string& biasShape, + const std::string& biasData) + { + m_JsonString = R"( + { + "version": 3, + "operator_codes": [ { "builtin_code": "TRANSPOSE_CONV" } ], + "subgraphs": [ { + "tensors": [ + { + "shape": [ 4 ], + "type": "UINT8", + "buffer": 0, + "name": "outputShapeTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": )" + filterShape + R"(, + "type": ")" + dataType + R"(", + "buffer": 1, + "name": "filterTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": )" + inputShape + R"(, + "type": ")" + dataType + R"(", + "buffer": 2, + "name": "inputTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": )" + biasShape + R"( , + "type": "INT32", + "buffer": 3, + "name": "biasTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + }, + { + "shape": )" + outputShape + R"(, + "type": ")" + dataType + R"(", + "buffer": 4, + "name": "outputTensor", + "quantization": { + "min": [ 0.0 ], + "max": [ 255.0 ], + "scale": [ 1.0 ], + "zero_point": [ 0 ], + } + } + ], + "inputs": [ 2 ], + "outputs": [ 4 ], + "operators": [ + { + "opcode_index": 0, + "inputs": [ 0, 1, 2, 3], + "outputs": [ 4 ], + "builtin_options_type": "TransposeConvOptions", + "builtin_options": { + "padding": "VALID", + "stride_w": )" + strideX + R"(, + "stride_h": )" + strideY + R"( + }, + "custom_options_format": "FLEXBUFFERS" + } + ], + } ], + "buffers" : [ + { "data": )" + outputShape + R"( }, + { "data": )" + filterData + R"( }, + { }, + { "data": )" + biasData + R"( }, + { } + ] + } + )"; + SetupSingleInputSingleOutput("inputTensor", "outputTensor"); + } +}; + +struct SimpleTransposeConvFixtureWithBias : TransposeConvFixtureWithBias +{ + SimpleTransposeConvFixtureWithBias() + : TransposeConvFixtureWithBias("[ 1, 2, 2, 1 ]", // inputShape + "[ 1, 3, 3, 1 ]", // outputShape + "[ 1, 2, 2, 1 ]", // filterShape + "[ 0, 1, 2, 4 ]", // filterData + "1", // strideX + "1", // strideY + "UINT8", // dataType + "[ 1 ]", // bias shape + "[ 10, 0, 0, 0 ]") // bias data + {} +}; + +BOOST_FIXTURE_TEST_CASE( ParseSimpleTransposeConvWithBias, SimpleTransposeConvFixtureWithBias ) +{ + RunTest<4, armnn::DataType::QAsymmU8>( + 0, + { + 1, 2, + 3, 4 + }, + { + 10, 11, 12, + 12, 21, 22, + 16, 30, 26 + }); +} + BOOST_AUTO_TEST_SUITE_END() |