diff options
Diffstat (limited to 'src/armnnTfLiteParser/test/Conv2D.cpp')
-rw-r--r-- | src/armnnTfLiteParser/test/Conv2D.cpp | 72 |
1 files changed, 56 insertions, 16 deletions
diff --git a/src/armnnTfLiteParser/test/Conv2D.cpp b/src/armnnTfLiteParser/test/Conv2D.cpp index c25e62bb00..45c4a43519 100644 --- a/src/armnnTfLiteParser/test/Conv2D.cpp +++ b/src/armnnTfLiteParser/test/Conv2D.cpp @@ -104,18 +104,21 @@ TEST_CASE_FIXTURE(SimpleConv2DFixture, "ParseSimpleConv2D") struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture { - explicit Conv2DWithBiasesFixture(const std::string & inputShape, - const std::string & outputShape, - const std::string & filterShape, - const std::string & filterData, - const std::string & biasShape, - const std::string & biasData, - const std::string & strides, - const std::string & activation="NONE", - const std::string & filterScale="1.0", - const std::string & filterZeroPoint="0", - const std::string & outputScale="2.0", - const std::string & outputZeroPoint="0") + explicit Conv2DWithBiasesFixture(const std::string& inputShape, + const std::string& outputShape, + const std::string& filterShape, + const std::string& filterData, + const std::string& biasShape, + const std::string& biasData, + const std::string& strides, + const std::string& activation="NONE", + const std::string& filterScale="1.0", + const std::string& filterZeroPoint="0", + const std::string& outputScale="2.0", + const std::string& outputZeroPoint="0", + const std::string& dataType = "UINT8", + const std::string& filterDataType = "UINT8", + const std::string& biasDataType = "INT32") { m_JsonString = R"( { @@ -125,7 +128,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture "tensors": [ { "shape": )" + inputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 0, "name": "inputTensor", "quantization": { @@ -137,7 +140,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + outputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 1, "name": "outputTensor", "quantization": { @@ -149,7 +152,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + filterShape + R"( , - "type": "UINT8", + "type": )" + filterDataType + R"(, "buffer": 2, "name": "filterTensor", "quantization": { @@ -161,7 +164,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + biasShape + R"( , - "type": "INT32", + "type": )" + biasDataType + R"(, "buffer": 3, "name": "biasTensor", "quantization": { @@ -662,4 +665,41 @@ TEST_CASE_FIXTURE(PerChannelConv2DFixture, "ParsePerChannelConv2D") }); } +struct Conv2FloatWithInt8WeightsAndBiasesFixture : Conv2DWithBiasesFixture +{ + Conv2FloatWithInt8WeightsAndBiasesFixture() + : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape + "[ 1, 2, 2, 1 ]", // outputShape + "[ 1, 2, 2, 1 ]", // filterShape + "[ 2,1, 0,6 ]", // filterData + "[ 1 ]", // biasShape + "[ 10, 0, 0, 0 ]", // biasData + "1", // stride w and h + "NONE", // activation + "1.0", // filterScale + "0", // filterZeroPoint + "2.0", // outputScale + "0", // outputZeroPoint + "FLOAT32", // dataType + "INT8", // filterDataType + "INT8") // biasDataType + {} +}; + +TEST_CASE_FIXTURE(Conv2FloatWithInt8WeightsAndBiasesFixture, "ParseConv2FloatWithInt8WeightsAndBiasesFixture") +{ + RunTest<4, armnn::DataType::Float32>( + 0, + { + 1, 2, + 3, 4, + }, + { + (1*2 + 2*1 + 3*0 + 4*6 + 10), + (2*2 + 0*1 + 4*0 + 0*6 + 10), + (3*2 + 4*1 + 0*0 + 0*6 + 10), + (4*2 + 0*1 + 0*0 + 0*6 + 10) + }); +} + } |