From 5880b911bf4b7fd8308c93e299d77ac78f282c19 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Fri, 28 Jan 2022 16:18:54 +0000 Subject: MLCE-604 Add Unidirectional Sequence Lstm support to TFLite * Added Unidirectional Sequence Lstm support to TFLite Parser * Added support for float operations with int8 weights to TFLite Parser * Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected, TransposeConv and UnidirectionalSequenceLstm * Renamed subgraphIndex to subgraph to fix name-shadowing warning. Signed-off-by: Mike Kelly Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283 --- src/armnnTfLiteParser/test/FullyConnected.cpp | 36 +++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 5 deletions(-) (limited to 'src/armnnTfLiteParser/test/FullyConnected.cpp') diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp index fc000bf95b..108b878e20 100644 --- a/src/armnnTfLiteParser/test/FullyConnected.cpp +++ b/src/armnnTfLiteParser/test/FullyConnected.cpp @@ -15,7 +15,10 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture const std::string& filterShape, const std::string& filterData, const std::string biasShape = "", - const std::string biasData = "") + const std::string biasData = "", + const std::string dataType = "UINT8", + const std::string weightsDataType = "UINT8", + const std::string biasDataType = "INT32") { std::string inputTensors = "[ 0, 2 ]"; std::string biasTensor = ""; @@ -26,7 +29,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture biasTensor = R"( { "shape": )" + biasShape + R"( , - "type": "INT32", + "type": )" + biasDataType + R"(, "buffer": 3, "name": "biasTensor", "quantization": { @@ -47,7 +50,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture "tensors": [ { "shape": )" + inputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 0, "name": "inputTensor", "quantization": { @@ -59,7 +62,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture }, { "shape": )" + outputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 1, "name": "outputTensor", "quantization": { @@ -71,7 +74,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture }, { "shape": )" + filterShape + R"(, - "type": "UINT8", + "type": )" + weightsDataType + R"(, "buffer": 2, "name": "filterTensor", "quantization": { @@ -353,4 +356,27 @@ TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonCo {{"output", { 20 }}}); } +struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture +{ + FullyConnectedWeightsBiasFloat() + : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape + "[ 1, 1 ]", // outputShape + "[ 1, 4 ]", // filterShape + "[ 2, 3, 4, 5 ]", // filterData + "[ 1 ]", // biasShape + "[ 10, 0, 0, 0 ]", // filterShape + "FLOAT32", // input and output dataType + "INT8", // weights dataType + "FLOAT32") // bias dataType + {} +}; + +TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat") +{ + RunTest<2, armnn::DataType::Float32>( + 0, + { 10, 20, 30, 40 }, + { 400 }); +} + } -- cgit v1.2.1