diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-01-28 16:18:54 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-04-22 15:46:50 +0000 |
commit | 5880b911bf4b7fd8308c93e299d77ac78f282c19 (patch) | |
tree | b256346d6fc78e78735cc50ec822286f809dd37f /src/armnnTfLiteParser/test/Conv2D.cpp | |
parent | 4dae5794644b44be8c93bc6db553a205551bc077 (diff) | |
download | armnn-5880b911bf4b7fd8308c93e299d77ac78f282c19.tar.gz |
MLCE-604 Add Unidirectional Sequence Lstm support to TFLite
* Added Unidirectional Sequence Lstm support to TFLite Parser
* Added support for float operations with int8 weights to TFLite Parser
* Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected,
TransposeConv and UnidirectionalSequenceLstm
* Renamed subgraphIndex to subgraph to fix name-shadowing warning.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283
Diffstat (limited to 'src/armnnTfLiteParser/test/Conv2D.cpp')
-rw-r--r-- | src/armnnTfLiteParser/test/Conv2D.cpp | 72 |
1 files changed, 56 insertions, 16 deletions
diff --git a/src/armnnTfLiteParser/test/Conv2D.cpp b/src/armnnTfLiteParser/test/Conv2D.cpp index c25e62bb00..45c4a43519 100644 --- a/src/armnnTfLiteParser/test/Conv2D.cpp +++ b/src/armnnTfLiteParser/test/Conv2D.cpp @@ -104,18 +104,21 @@ TEST_CASE_FIXTURE(SimpleConv2DFixture, "ParseSimpleConv2D") struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture { - explicit Conv2DWithBiasesFixture(const std::string & inputShape, - const std::string & outputShape, - const std::string & filterShape, - const std::string & filterData, - const std::string & biasShape, - const std::string & biasData, - const std::string & strides, - const std::string & activation="NONE", - const std::string & filterScale="1.0", - const std::string & filterZeroPoint="0", - const std::string & outputScale="2.0", - const std::string & outputZeroPoint="0") + explicit Conv2DWithBiasesFixture(const std::string& inputShape, + const std::string& outputShape, + const std::string& filterShape, + const std::string& filterData, + const std::string& biasShape, + const std::string& biasData, + const std::string& strides, + const std::string& activation="NONE", + const std::string& filterScale="1.0", + const std::string& filterZeroPoint="0", + const std::string& outputScale="2.0", + const std::string& outputZeroPoint="0", + const std::string& dataType = "UINT8", + const std::string& filterDataType = "UINT8", + const std::string& biasDataType = "INT32") { m_JsonString = R"( { @@ -125,7 +128,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture "tensors": [ { "shape": )" + inputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 0, "name": "inputTensor", "quantization": { @@ -137,7 +140,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + outputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 1, "name": "outputTensor", "quantization": { @@ -149,7 +152,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + filterShape + R"( , - "type": "UINT8", + "type": )" + filterDataType + R"(, "buffer": 2, "name": "filterTensor", "quantization": { @@ -161,7 +164,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + biasShape + R"( , - "type": "INT32", + "type": )" + biasDataType + R"(, "buffer": 3, "name": "biasTensor", "quantization": { @@ -662,4 +665,41 @@ TEST_CASE_FIXTURE(PerChannelConv2DFixture, "ParsePerChannelConv2D") }); } +struct Conv2FloatWithInt8WeightsAndBiasesFixture : Conv2DWithBiasesFixture +{ + Conv2FloatWithInt8WeightsAndBiasesFixture() + : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape + "[ 1, 2, 2, 1 ]", // outputShape + "[ 1, 2, 2, 1 ]", // filterShape + "[ 2,1, 0,6 ]", // filterData + "[ 1 ]", // biasShape + "[ 10, 0, 0, 0 ]", // biasData + "1", // stride w and h + "NONE", // activation + "1.0", // filterScale + "0", // filterZeroPoint + "2.0", // outputScale + "0", // outputZeroPoint + "FLOAT32", // dataType + "INT8", // filterDataType + "INT8") // biasDataType + {} +}; + +TEST_CASE_FIXTURE(Conv2FloatWithInt8WeightsAndBiasesFixture, "ParseConv2FloatWithInt8WeightsAndBiasesFixture") +{ + RunTest<4, armnn::DataType::Float32>( + 0, + { + 1, 2, + 3, 4, + }, + { + (1*2 + 2*1 + 3*0 + 4*6 + 10), + (2*2 + 0*1 + 4*0 + 0*6 + 10), + (3*2 + 4*1 + 0*0 + 0*6 + 10), + (4*2 + 0*1 + 0*0 + 0*6 + 10) + }); +} + } |