From 2f746b3f346d3efa9071fc53592652425869d6b3 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Tue, 28 Jul 2020 14:00:06 +0100 Subject: Github#433 Add HardSwish support to TfLiteParser Signed-off-by: Jan Eilers Change-Id: Ic476f8d80bba080ab459db9e6a59cbafd307d129 --- src/armnnTfLiteParser/TfLiteParser.cpp | 10 +++++++++- src/armnnTfLiteParser/TfLiteParser.hpp | 1 + src/armnnTfLiteParser/test/Activations.cpp | 12 ++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 69430134df..1a4449302e 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -538,6 +538,7 @@ TfLiteParser::TfLiteParser(const Optional& o m_ParserFunctions[tflite::BuiltinOperator_DEQUANTIZE] = &TfLiteParser::ParseDequantize; m_ParserFunctions[tflite::BuiltinOperator_EXP] = &TfLiteParser::ParseExp; m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParser::ParseFullyConnected; + m_ParserFunctions[tflite::BuiltinOperator_HARD_SWISH] = &TfLiteParser::ParseHardSwish; m_ParserFunctions[tflite::BuiltinOperator_LEAKY_RELU] = &TfLiteParser::ParseLeakyRelu; m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC] = &TfLiteParser::ParseLogistic; m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION] = &TfLiteParser::ParseL2Normalization; @@ -1992,7 +1993,7 @@ void TfLiteParser::ParseRelu6(size_t subgraphIndex, size_t operatorIndex) void TfLiteParser::ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex) { - ParseActivation(subgraphIndex,operatorIndex, ActivationFunction::LeakyReLu); + ParseActivation(subgraphIndex, operatorIndex, ActivationFunction::LeakyReLu); } void TfLiteParser::ParseLogistic(size_t subgraphIndex, size_t operatorIndex) @@ -2005,6 +2006,10 @@ void TfLiteParser::ParseTanH(size_t subgraphIndex, size_t operatorIndex) ParseActivation(subgraphIndex,operatorIndex,ActivationFunction::TanH); } +void TfLiteParser::ParseHardSwish(size_t subgraphIndex, size_t operatorIndex) +{ + ParseActivation(subgraphIndex, operatorIndex, ActivationFunction::HardSwish); +} void TfLiteParser::ParseActivation(size_t subgraphIndex, size_t operatorIndex, ActivationFunction activationType) { @@ -2055,6 +2060,9 @@ void TfLiteParser::ParseActivation(size_t subgraphIndex, size_t operatorIndex, A activationDesc.m_A = options->alpha; break; } + case ActivationFunction::HardSwish: + layerName += str(boost::format("HARDSWISH:%1%:%2%") % subgraphIndex % operatorIndex); + break; default: { throw ParseException( diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp index c252b0f73f..6a611509f4 100644 --- a/src/armnnTfLiteParser/TfLiteParser.hpp +++ b/src/armnnTfLiteParser/TfLiteParser.hpp @@ -104,6 +104,7 @@ private: void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex); void ParseExp(size_t subgraphIndex, size_t operatorIndex); void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex); + void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex); void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex); void ParseLogistic(size_t subgraphIndex, size_t operatorIndex); void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex); diff --git a/src/armnnTfLiteParser/test/Activations.cpp b/src/armnnTfLiteParser/test/Activations.cpp index e8153a28cb..e57477e620 100644 --- a/src/armnnTfLiteParser/test/Activations.cpp +++ b/src/armnnTfLiteParser/test/Activations.cpp @@ -105,4 +105,16 @@ BOOST_FIXTURE_TEST_CASE(ParseTanH, TanHFixture) { -0.1f, -0.2f, -0.3f, -0.4f, 0.1f, 0.2f, 0.3f }, { -0.09966799f, -0.19737528f, -0.29131261f, -0.379949f, 0.09966799f, 0.19737528f, 0.29131261f }); } + +struct HardSwishFixture : ActivationFixture +{ + HardSwishFixture() : ActivationFixture("HARD_SWISH", "FLOAT32") {} +}; + +BOOST_FIXTURE_TEST_CASE(ParseHardSwish, HardSwishFixture) +{ + RunTest<2, armnn::DataType::Float32>(0, + { -4.0f, -3.0f, -2.9f, 1.2f, 2.2f, 3.0f, 4.0f }, + { -0.0f, -0.0f, -0.04833334f, 0.84f, 1.90666667f, 3.0f, 4.0f }); +} BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1