diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-01-28 16:18:54 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-04-22 15:46:50 +0000 |
commit | 5880b911bf4b7fd8308c93e299d77ac78f282c19 (patch) | |
tree | b256346d6fc78e78735cc50ec822286f809dd37f /src/armnnTfLiteParser/test/FullyConnected.cpp | |
parent | 4dae5794644b44be8c93bc6db553a205551bc077 (diff) | |
download | armnn-5880b911bf4b7fd8308c93e299d77ac78f282c19.tar.gz |
MLCE-604 Add Unidirectional Sequence Lstm support to TFLite
* Added Unidirectional Sequence Lstm support to TFLite Parser
* Added support for float operations with int8 weights to TFLite Parser
* Added to Conv2d, Conv3D, DepthwiseConv2D, FullyConnected,
TransposeConv and UnidirectionalSequenceLstm
* Renamed subgraphIndex to subgraph to fix name-shadowing warning.
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I818976ab88abc05dcb4bad246fb4108e6e879283
Diffstat (limited to 'src/armnnTfLiteParser/test/FullyConnected.cpp')
-rw-r--r-- | src/armnnTfLiteParser/test/FullyConnected.cpp | 36 |
1 files changed, 31 insertions, 5 deletions
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp index fc000bf95b..108b878e20 100644 --- a/src/armnnTfLiteParser/test/FullyConnected.cpp +++ b/src/armnnTfLiteParser/test/FullyConnected.cpp @@ -15,7 +15,10 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture const std::string& filterShape, const std::string& filterData, const std::string biasShape = "", - const std::string biasData = "") + const std::string biasData = "", + const std::string dataType = "UINT8", + const std::string weightsDataType = "UINT8", + const std::string biasDataType = "INT32") { std::string inputTensors = "[ 0, 2 ]"; std::string biasTensor = ""; @@ -26,7 +29,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture biasTensor = R"( { "shape": )" + biasShape + R"( , - "type": "INT32", + "type": )" + biasDataType + R"(, "buffer": 3, "name": "biasTensor", "quantization": { @@ -47,7 +50,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture "tensors": [ { "shape": )" + inputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 0, "name": "inputTensor", "quantization": { @@ -59,7 +62,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture }, { "shape": )" + outputShape + R"(, - "type": "UINT8", + "type": )" + dataType + R"(, "buffer": 1, "name": "outputTensor", "quantization": { @@ -71,7 +74,7 @@ struct FullyConnectedFixture : public ParserFlatbuffersFixture }, { "shape": )" + filterShape + R"(, - "type": "UINT8", + "type": )" + weightsDataType + R"(, "buffer": 2, "name": "filterTensor", "quantization": { @@ -353,4 +356,27 @@ TEST_CASE_FIXTURE(FullyConnectedNonConstWeightsNoBias, "ParseFullyConnectedNonCo {{"output", { 20 }}}); } +struct FullyConnectedWeightsBiasFloat : FullyConnectedFixture +{ + FullyConnectedWeightsBiasFloat() + : FullyConnectedFixture("[ 1, 4, 1, 1 ]", // inputShape + "[ 1, 1 ]", // outputShape + "[ 1, 4 ]", // filterShape + "[ 2, 3, 4, 5 ]", // filterData + "[ 1 ]", // biasShape + "[ 10, 0, 0, 0 ]", // filterShape + "FLOAT32", // input and output dataType + "INT8", // weights dataType + "FLOAT32") // bias dataType + {} +}; + +TEST_CASE_FIXTURE(FullyConnectedWeightsBiasFloat, "FullyConnectedWeightsBiasFloat") +{ + RunTest<2, armnn::DataType::Float32>( + 0, + { 10, 20, 30, 40 }, + { 400 }); +} + } |