author     Teresa Charlin <teresa.charlinreyes@arm.com>   2023-09-15 15:19:21 +0100
committer  TeresaARM <teresa.charlinreyes@arm.com>        2023-09-29 11:05:29 +0000
commit     077cddbe9e956c6740557a9add499385f235c384 (patch)
tree       ae1816443bf4f85c7968aa3e542ef2b5e5400e7e
parent     9a45e8fab86f7078d22360794058f5550413df78 (diff)
download   armnn-077cddbe9e956c6740557a9add499385f235c384.tar.gz
IVGCVSW-8055 Add support for GELU activation function.
* Add support to CpuRef, CpuAcc and GpuAcc
* Add support to tflite parser, classic and opaque tflite delegates
* Add support to serializer and deserializer
* Add Unit tests
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: Ibc60ef2ef2a051e6d9af6e15d24c46316ec19de4
24 files changed, 217 insertions(+), 4 deletions(-)
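Every layer of this patch uses the same definition of GELU: gelu(x) = x * 1/2 * (1 + erf(x / sqrt(2))), i.e. x * Φ(x) where Φ is the standard normal CDF. As a minimal standalone sketch (illustrative only, not part of the patch), the formula can be evaluated at the inputs used by the ParseGelu test below to reproduce its expected outputs:

```cpp
#include <cmath>
#include <cstdio>

// gelu(x) = x * 1/2 * (1 + erf(x / sqrt(2))), where erf is the Gaussian error
// function -- the same expression used by the reference workload and the test
// lambdas in this patch.
float Gelu(float x)
{
    return x * (0.5f * (1.0f + std::erf(x / std::sqrt(2.0f))));
}

int main()
{
    // Inputs from the ParseGelu parser test; e.g. gelu(-4.0f) prints
    // roughly -0.000126361847, matching the test's expected output.
    const float inputs[] = { -4.0f, -3.0f, -2.9f, 1.2f, 2.2f, 3.0f, 4.0f };
    for (float x : inputs)
    {
        std::printf("gelu(%f) = %.9f\n", x, Gelu(x));
    }
    return 0;
}
```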
diff --git a/delegate/classic/src/Activation.hpp b/delegate/classic/src/Activation.hpp
index a93cee43a3..1c55c2e9b5 100644
--- a/delegate/classic/src/Activation.hpp
+++ b/delegate/classic/src/Activation.hpp
@@ -109,6 +109,11 @@ TfLiteStatus VisitActivationOperator(DelegateData& delegateData,
             activationDesc.m_A = leakyReluParameters->alpha;
             break;
         }
+        case kTfLiteBuiltinGelu:
+        {
+            activationDesc.m_Function = armnn::ActivationFunction::Gelu;
+            break;
+        }
         default:
         {
             return kTfLiteError;
diff --git a/delegate/classic/src/armnn_delegate.cpp b/delegate/classic/src/armnn_delegate.cpp
index c8f57d6cc3..6054de5c5e 100644
--- a/delegate/classic/src/armnn_delegate.cpp
+++ b/delegate/classic/src/armnn_delegate.cpp
@@ -729,6 +729,12 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData,
                                        tfLiteNode,
                                        nodeIndex,
                                        kTfLiteBuiltinGatherNd);
+        case kTfLiteBuiltinGelu:
+            return VisitActivationOperator(delegateData,
+                                           tfLiteContext,
+                                           tfLiteNode,
+                                           nodeIndex,
+                                           kTfLiteBuiltinGelu);
         case kTfLiteBuiltinGreater:
             return VisitComparisonOperator(delegateData,
                                            tfLiteContext,
diff --git a/delegate/opaque/src/Activation.hpp b/delegate/opaque/src/Activation.hpp
index dd9c2f68bc..ad242e5799 100644
--- a/delegate/opaque/src/Activation.hpp
+++ b/delegate/opaque/src/Activation.hpp
@@ -24,6 +24,9 @@ std::string GetLayerName(armnn::ActivationFunction activationFunction)
         case armnn::ActivationFunction::Elu:
             layerName += " ELU";
             break;
+        case armnn::ActivationFunction::Gelu:
+            layerName += " GELU";
+            break;
         case armnn::ActivationFunction::HardSwish:
             layerName += " HARD_SWISH";
             break;
@@ -175,6 +178,11 @@ TfLiteStatus VisitActivationOperator(DelegateData& delegateData,
             activationDesc.m_A = leakyReluParameters->alpha;
             break;
         }
+        case kTfLiteBuiltinGelu:
+        {
+            activationDesc.m_Function = armnn::ActivationFunction::Gelu;
+            break;
+        }
         default:
         {
             return kTfLiteError;
diff --git a/delegate/opaque/src/armnn_delegate.cpp b/delegate/opaque/src/armnn_delegate.cpp
index 08b1504efb..6abf7398cc 100644
--- a/delegate/opaque/src/armnn_delegate.cpp
+++ b/delegate/opaque/src/armnn_delegate.cpp
@@ -808,6 +808,12 @@ TfLiteStatus ArmnnSubgraph::VisitNode(DelegateData& delegateData,
                                        tfLiteNode,
                                        nodeIndex,
                                        kTfLiteBuiltinGatherNd);
+        case kTfLiteBuiltinGelu:
+            return VisitActivationOperator(delegateData,
+                                           tfLiteContext,
+                                           tfLiteNode,
+                                           nodeIndex,
+                                           kTfLiteBuiltinGelu);
         case kTfLiteBuiltinGreater:
             return VisitComparisonOperator(delegateData,
                                            tfLiteContext,
diff --git a/delegate/test/ActivationTest.cpp b/delegate/test/ActivationTest.cpp
index 620c299803..70321cd7e5 100644
--- a/delegate/test/ActivationTest.cpp
+++ b/delegate/test/ActivationTest.cpp
@@ -196,6 +196,33 @@ void ActivationLeakyReLuTest(std::vector<armnn::BackendId>& backends)
                    alpha);
 }
 
+void ActivationGeluTest(std::vector<armnn::BackendId>& backends)
+{
+    std::vector<float> inputData =
+    {
+        -0.1f, -0.2f, -0.3f, -0.4f,
+         0.1f,  0.2f,  0.3f,  0.4f,
+        -1.0f, -2.0f, -3.0f, -4.0f,
+         1.0f,  2.0f,  3.0f,  4.0f
+    };
+
+    // Calculate output values for input.
+    auto f = [](float x)
+    {
+        // gelu(x) = x * 1/2 * (1 + erf(x / sqrt(2))),
+        // where erf is Gaussian error function
+        auto result = x * (0.5f * (1.0f + erff(static_cast<float>(x / std::sqrt(2)))));
+        return result;
+    };
+    std::vector<float> outputExpectedData(inputData.size());
+    std::transform(inputData.begin(), inputData.end(), outputExpectedData.begin(), f);
+
+    ActivationTest(tflite::BuiltinOperator_GELU,
+                   backends,
+                   inputData,
+                   outputExpectedData);
+}
+
 TEST_SUITE("Activation_CpuRefTests")
 {
@@ -241,6 +268,12 @@ TEST_CASE ("Activation_LeakyRelu_CpuRef_Test")
     ActivationLeakyReLuTest(backends);
 }
 
+TEST_CASE ("Activation_Gelu_CpuRef_Test")
+{
+    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
+    ActivationGeluTest(backends);
+}
+
 }
 
 TEST_SUITE("Activation_CpuAccTests")
 {
@@ -288,6 +321,12 @@ TEST_CASE ("Activation_LeakyRelu_CpuAcc_Test")
     ActivationLeakyReLuTest(backends);
 }
 
+TEST_CASE ("Activation_Gelu_CpuAcc_Test")
+{
+    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
+    ActivationGeluTest(backends);
+}
+
 }
 
 TEST_SUITE("Activation_GpuAccTests")
 {
@@ -335,6 +374,12 @@ TEST_CASE ("Activation_LeakyRelu_GpuAcc_Test")
     ActivationLeakyReLuTest(backends);
 }
 
+TEST_CASE ("Activation_Gelu_GpuAcc_Test")
+{
+    std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
+    ActivationGeluTest(backends);
+}
+
 }
 
 } // namespace armnnDelegate
\ No newline at end of file
diff --git a/delegate/test/ActivationTestHelper.hpp b/delegate/test/ActivationTestHelper.hpp
index b0a4d6785d..7beb53ba3d 100644
--- a/delegate/test/ActivationTestHelper.hpp
+++ b/delegate/test/ActivationTestHelper.hpp
@@ -73,7 +73,11 @@ std::vector<char> CreateActivationTfLiteModel(tflite::BuiltinOperator activation
     flatbuffers::Offset <flatbuffers::String> modelDescription =
         flatBufferBuilder.CreateString("ArmnnDelegate: Activation Operator Model");
 
-    flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, activationOperatorCode);
+    flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
+                                                                         activationOperatorCode,
+                                                                         0,
+                                                                         1,
+                                                                         activationOperatorCode);
 
     flatbuffers::Offset <Model> flatbufferModel =
         CreateModel(flatBufferBuilder,
diff --git a/docs/05_01_parsers.dox b/docs/05_01_parsers.dox
index 4454f44218..7dcf8e2553 100644
--- a/docs/05_01_parsers.dox
+++ b/docs/05_01_parsers.dox
@@ -141,6 +141,7 @@ The Arm NN SDK TensorFlow Lite parser currently supports the following operators
 - FULLY_CONNECTED, Supported Fused Activation: RELU , RELU6 , TANH, NONE
 - GATHER
 - GATHER_ND
+- GELU
 - GREATER
 - GREATER_EQUAL
 - HARD_SWISH
diff --git a/docs/05_03_delegate.dox b/docs/05_03_delegate.dox
index 9a40a8ac9e..dde50e13f7 100644
--- a/docs/05_03_delegate.dox
+++ b/docs/05_03_delegate.dox
@@ -86,6 +86,8 @@ The Arm NN SDK TensorFlow Lite delegate currently supports the following operato
 - GATHER_ND
 
+- GELU
+
 - GREATER
 
 - GREATER_EQUAL
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 933e7da412..d87e7f7147 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -96,7 +96,8 @@ enum class ActivationFunction
     Sqrt = 8,
     Square = 9,
     Elu = 10,
-    HardSwish = 11
+    HardSwish = 11,
+    Gelu = 12
 };
 
 enum class ArgMinMaxFunction
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index ca098f60fb..b5f224ed75 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -44,6 +44,7 @@ constexpr char const* GetActivationFunctionAsCString(ActivationFunction activati
         case ActivationFunction::Square:    return "Square";
         case ActivationFunction::Elu:       return "Elu";
         case ActivationFunction::HardSwish: return "HardSwish";
+        case ActivationFunction::Gelu:      return "Gelu";
         default:                            return "Unknown";
     }
 }
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 505f4d88a4..6e2c07bf37 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -502,6 +502,8 @@ armnn::ActivationFunction ToActivationFunction(armnnSerializer::ActivationFuncti
             return armnn::ActivationFunction::Elu;
         case armnnSerializer::ActivationFunction_HardSwish:
             return armnn::ActivationFunction::HardSwish;
+        case armnnSerializer::ActivationFunction_Gelu:
+            return armnn::ActivationFunction::Gelu;
         default:
             return armnn::ActivationFunction::Sigmoid;
     }
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index ec4b48639d..131970e449 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -21,7 +21,8 @@ enum ActivationFunction : byte {
     Sqrt = 8,
     Square = 9,
     Elu = 10,
-    HardSwish = 11
+    HardSwish = 11,
+    Gelu = 12
 }
 
 enum ArgMinMaxFunction : byte {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 6cadb598a2..0df675d1db 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -78,6 +78,8 @@ serializer::ActivationFunction GetFlatBufferActivationFunction(armnn::Activation
             return serializer::ActivationFunction::ActivationFunction_Elu;
         case armnn::ActivationFunction::HardSwish:
             return serializer::ActivationFunction::ActivationFunction_HardSwish;
+        case armnn::ActivationFunction::Gelu:
+            return serializer::ActivationFunction::ActivationFunction_Gelu;
         default:
             return serializer::ActivationFunction::ActivationFunction_Sigmoid;
     }
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 3f4f0d811f..c2be54f5f5 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -771,6 +771,7 @@ TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOpt
     m_ParserFunctions[tflite::BuiltinOperator_FLOOR_DIV]       = &TfLiteParserImpl::ParseFloorDiv;
     m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED] = &TfLiteParserImpl::ParseFullyConnected;
     m_ParserFunctions[tflite::BuiltinOperator_GATHER]          = &TfLiteParserImpl::ParseGather;
+    m_ParserFunctions[tflite::BuiltinOperator_GELU]            = &TfLiteParserImpl::ParseGelu;
     m_ParserFunctions[tflite::BuiltinOperator_GATHER_ND]       = &TfLiteParserImpl::ParseGatherNd;
     m_ParserFunctions[tflite::BuiltinOperator_GREATER]         = &TfLiteParserImpl::ParseGreater;
     m_ParserFunctions[tflite::BuiltinOperator_GREATER_EQUAL]   = &TfLiteParserImpl::ParseGreaterOrEqual;
@@ -3159,6 +3160,11 @@ void TfLiteParserImpl::ParseHardSwish(size_t subgraphIndex, size_t operatorIndex
     ParseActivation(subgraphIndex, operatorIndex, ActivationFunction::HardSwish);
 }
 
+void TfLiteParserImpl::ParseGelu(size_t subgraphIndex, size_t operatorIndex)
+{
+    ParseActivation(subgraphIndex,operatorIndex,ActivationFunction::Gelu);
+}
+
 void TfLiteParserImpl::ParseActivation(size_t subgraphIndex, size_t operatorIndex, ActivationFunction activationType)
 {
     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
@@ -3219,6 +3225,11 @@ void TfLiteParserImpl::ParseActivation(size_t subgraphIndex, size_t operatorInde
             layerName += fmt::format("HARDSWISH:{}:{}", subgraphIndex, operatorIndex);
             break;
         }
+        case ActivationFunction::Gelu:
+        {
+            layerName += fmt::format("GELU:{}:{}", subgraphIndex, operatorIndex);
+            break;
+        }
         default:
         {
             throw ParseException(
diff --git a/src/armnnTfLiteParser/TfLiteParser.hpp b/src/armnnTfLiteParser/TfLiteParser.hpp
index f0c7ddefb9..4ee0650626 100644
--- a/src/armnnTfLiteParser/TfLiteParser.hpp
+++ b/src/armnnTfLiteParser/TfLiteParser.hpp
@@ -141,6 +141,7 @@ private:
     void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
     void ParseGather(size_t subgraphIndex, size_t operatorIndex);
    void ParseGatherNd(size_t subgraphIndex, size_t operatorIndex);
+    void ParseGelu(size_t subgraphIndex, size_t operatorIndex);
     void ParseGreater(size_t subgraphIndex, size_t operatorIndex);
     void ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex);
     void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
diff --git a/src/armnnTfLiteParser/test/Activations.cpp b/src/armnnTfLiteParser/test/Activations.cpp
index a3d75edab5..9cd03a3c9d 100644
--- a/src/armnnTfLiteParser/test/Activations.cpp
+++ b/src/armnnTfLiteParser/test/Activations.cpp
@@ -130,4 +130,16 @@ TEST_CASE_FIXTURE(HardSwishFixture, "ParseHardSwish")
         { -0.0f, -0.0f, -0.04833334f, 0.84f, 1.90666667f, 3.0f, 4.0f });
 }
 
+struct GeluFixture : ActivationFixture
+{
+    GeluFixture() : ActivationFixture("GELU", "FLOAT32") {}
+};
+
+TEST_CASE_FIXTURE(GeluFixture, "ParseGelu")
+{
+    RunTest<2, armnn::DataType::Float32>(
+        0,
+        { -4.0f, -3.0f, -2.9f, 1.2f, 2.2f, 3.0f, 4.0f },
+        { -0.000126361847f, -0.00404950976f, -0.00541083235f, 1.06191647f, 2.16941237f, 2.9959507f, 3.99987364f });
+}
+
 }
diff --git a/src/backends/aclCommon/ArmComputeUtils.hpp b/src/backends/aclCommon/ArmComputeUtils.hpp
index fc59b281b5..f466ab1777 100644
--- a/src/backends/aclCommon/ArmComputeUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeUtils.hpp
@@ -77,6 +77,7 @@ ConvertActivationFunctionToAclActivationFunction(ActivationFunction armnnFunctio
         case ActivationFunction::TanH:      return AclActivationFunction::TANH;
         case ActivationFunction::Elu:       return AclActivationFunction::ELU;
         case ActivationFunction::HardSwish: return AclActivationFunction::HARD_SWISH;
+        case ActivationFunction::Gelu:      return AclActivationFunction::GELU;
         default:                            throw InvalidArgumentException("Unsupported activation function");
     }
 }
diff --git a/src/backends/backendsCommon/test/layerTests/ActivationTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/ActivationTestImpl.cpp
index 1dcbdfac9e..b562a8af32 100644
--- a/src/backends/backendsCommon/test/layerTests/ActivationTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/ActivationTestImpl.cpp
@@ -1218,6 +1218,71 @@ LayerTestResult<int16_t, 4> HardSwishInt16Test(
 
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> GeluTestCommon(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory,
+    float qScale,
+    int32_t qOffset)
+{
+    std::vector<float> inputData =
+    {
+        -0.1f, -0.2f, -0.3f, -0.4f,
+         0.1f,  0.2f,  0.3f,  0.4f,
+        -1.0f, -2.0f, -3.0f, -4.0f,
+         1.0f,  2.0f,  3.0f,  4.0f
+    };
+    // Calculate output values for input.
+    auto f = [](float x)
+    {
+        // gelu(x) = x * 1/2 * (1 + erf(x / sqrt(2))),
+        // where erf is Gaussian error function
+        auto result = x * (0.5f * (1.0f + erff(static_cast<float>(x / std::sqrt(2)))));
+        return result;
+    };
+    std::vector<float> expectedOutput(inputData.size());
+    std::transform(inputData.begin(), inputData.end(), expectedOutput.begin(), f);
+
+    return SimpleActivationTest<ArmnnType>(workloadFactory,
+                                           memoryManager,
+                                           tensorHandleFactory,
+                                           armnn::ActivationFunction::Gelu,
+                                           0.f,
+                                           0.f,
+                                           qScale,
+                                           qOffset,
+                                           inputData,
+                                           qScale,
+                                           qOffset,
+                                           expectedOutput);
+}
+
+LayerTestResult<float, 4> GeluTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    return GeluTestCommon<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory, 0.1f, 0);
+}
+
+LayerTestResult<uint8_t, 4> GeluUint8Test(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    return GeluTestCommon<armnn::DataType::QAsymmU8>(workloadFactory, memoryManager, tensorHandleFactory, 0.1f, 64);
+}
+
+LayerTestResult<int16_t, 4> GeluInt16Test(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    return GeluTestCommon<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, tensorHandleFactory, 0.1f, 0);
+}
+
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 4> CompareActivationTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
diff --git a/src/backends/backendsCommon/test/layerTests/ActivationTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/ActivationTestImpl.hpp
index e23cd32583..5df6813466 100644
--- a/src/backends/backendsCommon/test/layerTests/ActivationTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/ActivationTestImpl.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017-2021, 2023 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
@@ -274,6 +274,25 @@ LayerTestResult<int16_t, 4> HardSwishInt16Test(
     const armnn::ITensorHandleFactory& tensorHandleFactory);
 
 //
+// Gelu
+//
+
+LayerTestResult<float, 4> GeluTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+LayerTestResult<uint8_t, 4> GeluUint8Test(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+LayerTestResult<int16_t, 4> GeluInt16Test(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+//
 // Other
 //
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index 33e1b69ade..a596a01be8 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -73,6 +73,9 @@ ARMNN_AUTO_TEST_FIXTURE_WITH_THF(Tanh, ClContextControlFixture, TanhTest)
 // Elu Activation
 ARMNN_AUTO_TEST_FIXTURE_WITH_THF(Elu, ClContextControlFixture, EluTest)
 
+// Gelu Activation
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(Gelu, ClContextControlFixture, GeluTest)
+
 // Batch Mat Mul
 ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DSimpleFloat32,
                                  ClContextControlFixture,
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index 658d718b19..a938ceb9c3 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -722,6 +722,9 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(Tanh, TanhTest)
 // Elu Activation
 ARMNN_AUTO_TEST_CASE_WITH_THF(Elu, EluTest)
 
+// Gelu Activation
+ARMNN_AUTO_TEST_CASE_WITH_THF(Gelu, GeluTest)
+
 // Softmax
 // Moved to NeonLayerTests_NDK_Bug.cpp
 //ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleSoftmaxBeta1, SimpleSoftmaxTest, 1.0f)
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index defdf0d807..167639a733 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -588,6 +588,7 @@ bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
         case ActivationFunction::Abs:
         case ActivationFunction::BoundedReLu:
         case ActivationFunction::Elu:
+        case ActivationFunction::Gelu:
         case ActivationFunction::HardSwish:
         case ActivationFunction::LeakyReLu:
         case ActivationFunction::Linear:
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index af4ed966b2..cfe85594b3 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -770,11 +770,17 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(TanhInt16, TanhInt16Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(Elu, EluTest)
 ARMNN_AUTO_TEST_CASE_WITH_THF(EluUint8, EluUint8Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(EluInt16, EluInt16Test)
+
 // HardSwish Activation
 ARMNN_AUTO_TEST_CASE_WITH_THF(HardSwish, HardSwishTest)
 ARMNN_AUTO_TEST_CASE_WITH_THF(HardSwishUint8, HardSwishUint8Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(HardSwishInt16, HardSwishInt16Test)
 
+// Gelu Activation
+ARMNN_AUTO_TEST_CASE_WITH_THF(Gelu, GeluTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(GeluUint8, GeluUint8Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(GeluInt16, GeluInt16Test)
+
 // Fully Connected
 ARMNN_AUTO_TEST_CASE_WITH_THF(SimpleFullyConnected, FullyConnectedFloat32Test, false, false)
 ARMNN_AUTO_TEST_CASE_WITH_THF(FullyConnectedUint8, FullyConnectedTest<DataType::QAsymmU8>, false, true)
diff --git a/src/backends/reference/workloads/Activation.cpp b/src/backends/reference/workloads/Activation.cpp
index 8de0e8b3b2..1577543fe4 100644
--- a/src/backends/reference/workloads/Activation.cpp
+++ b/src/backends/reference/workloads/Activation.cpp
@@ -82,6 +82,13 @@ float Activation(float in,
             output = in * (std::min(std::max((in + 3), 0.0f), 6.0f)) / 6;
             break;
         }
+        case ActivationFunction::Gelu:
+        {
+            // gelu(x) = x * 1/2 * (1 + erf(x / sqrt(2))),
+            // where erf is Gaussian error function
+            output = in * (0.5f * (1.0f + erff(static_cast<float>(in / std::sqrt(2)))));
+            break;
+        }
         default:
         {
             throw InvalidArgumentException("Unsupported activation function");
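The quantized test variants above (GeluUint8Test with qScale 0.1f / qOffset 64, GeluInt16Test with qScale 0.1f / qOffset 0) rely on SimpleActivationTest quantizing both the input and the expected output with the same parameters. A rough sketch of that affine mapping for QAsymmU8 — an assumption about the helper's behaviour, since its body is not part of this diff:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Affine quantization q = round(x / scale) + offset, clamped to the uint8
// range -- the usual scheme for armnn::DataType::QAsymmU8. This mirrors what
// SimpleActivationTest is assumed to do with its qScale/qOffset arguments.
uint8_t QuantizeU8(float x, float scale, int32_t offset)
{
    int32_t q = static_cast<int32_t>(std::round(x / scale)) + offset;
    return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}

int main()
{
    // Parameters from GeluUint8Test: qScale = 0.1f, qOffset = 64.
    // gelu(1.0f) ~= 0.841345f, which lands on quantized level 72.
    std::printf("%u\n", static_cast<unsigned>(QuantizeU8(0.841345f, 0.1f, 64)));
    return 0;
}
```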