From 5880b911bf4b7fd8308c93e299d77ac78f282c19 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Fri, 28 Jan 2022 16:18:54 +0000 Subject: MLCE-604 Add Unidirectional Sequence Lstm support to TFLite * Added Unidirectional Sequence Lstm support to TFLite Parser * Added support for float operations with int8 weights to TFLite Parser * Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected, TransposeConv and UnidirectionalSequenceLstm * Renamed subgraphIndex to subgraph to fix name-shadowing warning. Signed-off-by: Mike Kelly Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283 --- src/armnnTfLiteParser/test/Conv2D.cpp | 72 +++++++++++++++++++++------ src/armnnTfLiteParser/test/FullyConnected.cpp | 36 ++++++++++++-- 2 files changed, 87 insertions(+), 21 deletions(-) (limited to 'src/armnnTfLiteParser/test') diff --git a/src/armnnTfLiteParser/test/Conv2D.cpp b/src/armnnTfLiteParser/test/Conv2D.cpp index c25e62bb00..45c4a43519 100644 --- a/src/armnnTfLiteParser/test/Conv2D.cpp +++ b/src/armnnTfLiteParser/test/Conv2D.cpp @@ -104,18 +104,21 @@ TEST_CASE_FIXTURE(SimpleConv2DFixture, "ParseSimpleConv2D") struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture { - explicit Conv2DWithBiasesFixture(const std::string & inputShape, - const std::string & outputShape, - const std::string & filterShape, - const std::string & filterData, - const std::string & biasShape, - const std::string & biasData, - const std::string & strides, - const std::string & activation="NONE", - const std::string & filterScale="1.0", - const std::string & filterZeroPoint="0", - const std::string & outputScale="2.0", - const std::string & outputZeroPoint="0") + explicit Conv2DWithBiasesFixture(const std::string& inputShape, + const std::string& outputShape, + const std::string& filterShape, + const std::string& filterData, + const std::string& biasShape, + const std::string& biasData, + const std::string& strides, + const std::string& activation="NONE", + const std::string& filterScale="1.0", + const std::string& filterZeroPoint="0", + const std::string& outputScale="2.0", + const std::string& outputZeroPoint="0", + const std::string& dataType = "UINT8", + const std::string& filterDataType = "UINT8", + const std::string& biasDataType = "INT32") { m_JsonString = R"( { @@ -125,7 +128,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture "tensors": [ { "shape": )" + inputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 0, "name": "inputTensor", "quantization": { @@ -137,7 +140,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + outputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 1, "name": "outputTensor", "quantization": { @@ -149,7 +152,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + filterShape + R"( , - "type": "UINT8", + "type": )" + filterDataType + R"(, "buffer": 2, "name": "filterTensor", "quantization": { @@ -161,7 +164,7 @@ struct Conv2DWithBiasesFixture : public ParserFlatbuffersFixture }, { "shape": )" + biasShape + R"( , - "type": "INT32", + "type": )" + biasDataType + R"(, "buffer": 3, "name": "biasTensor", "quantization": { @@ -662,4 +665,41 @@ TEST_CASE_FIXTURE(PerChannelConv2DFixture, "ParsePerChannelConv2D") }); } +struct Conv2FloatWithInt8WeightsAndBiasesFixture : Conv2DWithBiasesFixture +{ + Conv2FloatWithInt8WeightsAndBiasesFixture() + : Conv2DWithBiasesFixture("[ 1, 2, 2, 1 ]", // inputShape + "[ 1, 2, 2, 1 ]", // outputShape + "[ 1, 2, 2, 1 ]", // filterShape + "[ 2,1, 0,6 ]", // filterData + "[ 1 ]", // biasShape + "[ 10, 0, 0, 0 ]", // biasData + "1", // stride w and h + "NONE", // activation + "1.0", // filterScale + "0", // filterZeroPoint + "2.0", // outputScale + "0", // outputZeroPoint + "FLOAT32", // dataType + "INT8", // filterDataType + "INT8") // biasDataType + {} +}; + +TEST_CASE_FIXTURE(Conv2FloatWithInt8WeightsAndBiasesFixture, "ParseConv2FloatWithInt8WeightsAndBiasesFixture") +{ + RunTest<4, armnn::DataType::Float32>( + 0, + { + 1, 2, + 3, 4, + }, + { + (1*2 + 2*1 + 3*0 + 4*6 + 10), + (2*2 + 0*1 + 4*0 + 0*6 + 10), + (3*2 + 4*1 + 0*0 + 0*6 + 10), + (4*2 + 0*1 + 0*0 + 0*6 + 10) + }); +} + } diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp index fc000bf95b..108b878e20 100644 --- a/src/armnnTfLiteParser/test/FullyConnected.cpp +++ b/src/armnnTfLiteParser/test/FullyConnected.cpp @@ -15,7 +15,10 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture const std::string& filterShape, const std::string& filterData, const std::string biasShape = "", - const std::string biasData = "") + const std::string biasData = "", + const std::string dataType = "UINT8", + const std::string weightsDataType = "UINT8", + const std::string biasDataType = "INT32") { std::string inputTensors = "[ 0, 2 ]"; std::string biasTensor = ""; @@ -26,7 +29,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture biasTensor = R"( { "shape": )" + biasShape + R"( , - "type": "INT32", + "type": )" + biasDataType + R"(, "buffer": 3, "name": "biasTensor", "quantization": { @@ -47,7 +50,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture "tensors": [ { "shape": )" + inputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 0, "name": "inputTensor", "quantization": { @@ -59,7 +62,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture }, { "shape": )" + outputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 1, "name": "outputTensor", "quantization": { @@ -71,7 +74,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture }, { "shape": )" + filterShape + R"(, - "type": "UINT8", + "type": )" + weightsDataType + R"(, "buffer": 2, "name": "filterTensor", "quantization": { @@ -353,4 +356,27 @@ TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonCo {{"output", { 20 }}}); } +struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture +{ + FullyConnectedWeightsBiasFloat() + : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape + "[ 1, 1 ]", // outputShape + "[ 1, 4 ]", // filterShape + "[ 2, 3, 4, 5 ]", // filterData + "[ 1 ]", // biasShape + "[ 10, 0, 0, 0 ]", // filterShape + "FLOAT32", // input and output dataType + "INT8", // weights dataType + "FLOAT32") // bias dataType + {} +}; + +TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat") +{ + RunTest<2, armnn::DataType::Float32>( + 0, + { 10, 20, 30, 40 }, + { 400 }); +} + } -- cgit v1.2.1