diff options
-rw-r--r-- | docs/05_01_parsers.dox | 1 | ||||
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 6 | ||||
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.hpp | 1 | ||||
-rw-r--r-- | src/armnnTfLiteParser/test/ElementWiseUnary.cpp | 13 |
4 files changed, 21 insertions, 0 deletions
diff --git a/docs/05_01_parsers.dox b/docs/05_01_parsers.dox index 03f2cba8d4..8e406e85cf 100644 --- a/docs/05_01_parsers.dox +++ b/docs/05_01_parsers.dox @@ -176,6 +176,7 @@ The Arm NN SDK TensorFlow Lite parser currently supports the following operators - SPLIT - SPLIT_V - SQUEEZE +- SQRT - STRIDED_SLICE - SUB - SUM diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 32b7c63fbb..7fe954d901 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -730,6 +730,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt m_ParserFunctions[tflite::BuiltinOperator_RESIZE_BILINEAR] = &TfLiteParserImpl::ParseResizeBilinear; m_ParserFunctions[tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR] = &TfLiteParserImpl::ParseResizeNearestNeighbor; m_ParserFunctions[tflite::BuiltinOperator_RSQRT] = &TfLiteParserImpl::ParseRsqrt; + m_ParserFunctions[tflite::BuiltinOperator_SQRT] = &TfLiteParserImpl::ParseSqrt; m_ParserFunctions[tflite::BuiltinOperator_SHAPE] = &TfLiteParserImpl::ParseShape; m_ParserFunctions[tflite::BuiltinOperator_SLICE] = &TfLiteParserImpl::ParseSlice; m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX] = &TfLiteParserImpl::ParseSoftmax; @@ -4180,6 +4181,11 @@ void TfLiteParserImpl::ParseRsqrt(size_t subgraphIndex, size_t operatorIndex) ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Rsqrt); } +void TfLiteParserImpl::ParseSqrt(size_t subgraphIndex, size_t operatorIndex) +{ + ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Sqrt); +} + void TfLiteParserImpl::ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, UnaryOperation unaryOperation) { CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index 6a8992fc0f..3fc72057c7 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -174,6 +174,7 @@ private: void ParseShape(size_t subgraphIndex, size_t operatorIndex); void ParseSlice(size_t subgraphIndex, size_t operatorIndex); void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex); + void ParseSqrt(size_t subgraphIndex, size_t operatorIndex); void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex); void ParseSplit(size_t subgraphIndex, size_t operatorIndex); void ParseSplitV(size_t subgraphIndex, size_t operatorIndex); diff --git a/src/armnnTfLiteParser/test/ElementWiseUnary.cpp b/src/armnnTfLiteParser/test/ElementWiseUnary.cpp index 2d784fb022..bab9a05bd2 100644 --- a/src/armnnTfLiteParser/test/ElementWiseUnary.cpp +++ b/src/armnnTfLiteParser/test/ElementWiseUnary.cpp @@ -141,4 +141,17 @@ TEST_CASE_FIXTURE(SimpleRsqrtFixture, "ParseRsqrt") 0.2f, 0.125f, 0.1f} }}); } +struct SimpleSqrtFixture : public ElementWiseUnaryFixture +{ + SimpleSqrtFixture() : ElementWiseUnaryFixture("SQRT", "FLOAT32", "[ 1, 2, 3, 1 ]", "[ 1, 2, 3, 1 ]") {} +}; + +TEST_CASE_FIXTURE(SimpleSqrtFixture, "ParseSqrt") +{ + RunTest<4, armnn::DataType::Float32>(0, {{ "inputTensor", { 9.0f, 4.0f, 16.0f, + 25.0f, 36.0f, 49.0f } }}, + {{ "outputTensor",{ 3.0f, 2.0f, 4.0f, + 5.0f, 6.0f, 7.0f} }}); +} + } |