diff options
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 6 | ||||
-rw-r--r-- | src/armnnTfLiteParser/test/Pad.cpp | 59 |
2 files changed, 55 insertions, 10 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index d1d45f5583..c3d56b13d3 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1845,6 +1845,8 @@ void TfLiteParser::ParsePad(size_t subgraphIndex, size_t operatorIndex) TfLiteParser::TensorRawPtrVector outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); + armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]); + armnn::TensorInfo padTensorInfo = ToTensorInfo(inputs[1]); BufferRawPtr bufferPtr = GetBuffer(m_Model, inputs[1]->buffer); @@ -1853,6 +1855,10 @@ void TfLiteParser::ParsePad(size_t subgraphIndex, size_t operatorIndex) size_t step = 2; armnn::PadDescriptor desc; + if (inputTensorInfo.IsQuantized()) + { + desc.m_PadValue = static_cast<float>(inputTensorInfo.GetQuantizationOffset()); + } for (unsigned int i = 0; i < padTensorInfo.GetNumElements() / step; ++i) { desc.m_PadList.emplace_back(padBuffer[i * step], padBuffer[i * step + 1]); diff --git a/src/armnnTfLiteParser/test/Pad.cpp b/src/armnnTfLiteParser/test/Pad.cpp index bdc8478ca2..aab1536628 100644 --- a/src/armnnTfLiteParser/test/Pad.cpp +++ b/src/armnnTfLiteParser/test/Pad.cpp @@ -14,10 +14,13 @@ BOOST_AUTO_TEST_SUITE(TensorflowLiteParser) struct PadFixture : public ParserFlatbuffersFixture { - explicit PadFixture(const std::string & inputShape, - const std::string & outputShape, - const std::string & padListShape, - const std::string & padListData) + explicit PadFixture(const std::string& inputShape, + const std::string& outputShape, + const std::string& padListShape, + const std::string& padListData, + const std::string& dataType = "FLOAT32", + const std::string& scale = "1.0", + const std::string& offset = "0") { m_JsonString = R"( { @@ -27,26 +30,26 @@ struct PadFixture : public ParserFlatbuffersFixture "tensors": [ { "shape": )" + inputShape + R"(, - "type": "FLOAT32", + "type": )" + dataType + R"(, "buffer": 0, "name": "inputTensor", "quantization": { "min": [ 0.0 ], "max": [ 255.0 ], - "scale": [ 1.0 ], - "zero_point": [ 0 ], + "scale": [ )" + scale + R"( ], + "zero_point": [ )" + offset + R"( ], } }, { "shape": )" + outputShape + R"(, - "type": "FLOAT32", + "type": )" + dataType + R"(, "buffer": 1, "name": "outputTensor", "quantization": { "min": [ 0.0 ], "max": [ 255.0 ], - "scale": [ 1.0 ], - "zero_point": [ 0 ], + "scale": [ )" + scale + R"( ], + "zero_point": [ )" + offset + R"( ], } }, { @@ -101,4 +104,40 @@ BOOST_FIXTURE_TEST_CASE(ParsePad, SimplePadFixture) 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f }}}); } +struct Uint8PadFixture : public PadFixture +{ + Uint8PadFixture() : PadFixture("[ 2, 3 ]", "[ 4, 7 ]", "[ 2, 2 ]", + "[ 1,0,0,0, 1,0,0,0, 2,0,0,0, 2,0,0,0 ]", + "UINT8", "-2.0", "3") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParsePadUint8, Uint8PadFixture) +{ + RunTest<2, armnn::DataType::QAsymmU8> + (0, + {{ "inputTensor", { 1, 2, 3, 4, 5, 6 }}}, + {{ "outputTensor", { 3, 3, 3, 3, 3, 3, 3, + 3, 3, 1, 2, 3, 3, 3, + 3, 3, 4, 5, 6, 3, 3, + 3, 3, 3, 3, 3, 3, 3 }}}); +} + +struct Int8PadFixture : public PadFixture +{ + Int8PadFixture() : PadFixture("[ 2, 3 ]", "[ 4, 7 ]", "[ 2, 2 ]", + "[ 1,0,0,0, 1,0,0,0, 2,0,0,0, 2,0,0,0 ]", + "INT8", "-2.0", "3") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParsePadInt8, Int8PadFixture) +{ + RunTest<2, armnn::DataType::QAsymmS8> + (0, + {{ "inputTensor", { 1, -2, 3, 4, 5, -6 }}}, + {{ "outputTensor", { 3, 3, 3, 3, 3, 3, 3, + 3, 3, 1, -2, 3, 3, 3, + 3, 3, 4, 5, -6, 3, 3, + 3, 3, 3, 3, 3, 3, 3 }}}); +} + BOOST_AUTO_TEST_SUITE_END() |