aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/test/ActivationTest.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/test/ActivationTest.cpp')
-rw-r--r--delegate/src/test/ActivationTest.cpp98
1 files changed, 94 insertions, 4 deletions
diff --git a/delegate/src/test/ActivationTest.cpp b/delegate/src/test/ActivationTest.cpp
index f894d67372..69041d77a2 100644
--- a/delegate/src/test/ActivationTest.cpp
+++ b/delegate/src/test/ActivationTest.cpp
@@ -22,7 +22,6 @@ namespace armnnDelegate
void ActivationReLuTest(std::vector<armnn::BackendId>& backends)
{
-
std::vector<float> inputData = {
-0.1f, -0.2f, -0.3f, -0.4f,
0.1f, 0.2f, 0.3f, 0.4f,
@@ -116,6 +115,64 @@ void ActivationTanHTest(std::vector<armnn::BackendId>& backends)
outputExpectedData);
}
+void ActivationEluTest(std::vector<armnn::BackendId>& backends)
+{
+ std::vector<float> inputData = {
+ -0.1f, -0.2f, -0.3f, -0.4f,
+ 0.1f, 0.2f, 0.3f, 0.4f,
+ -1.0f, -2.0f, -3.0f, -4.0f,
+ 1.0f, 2.0f, 3.0f, 4.0f
+ };
+
+ // Calculate output values for input.
+ auto f = [](float value)
+ {
+ if (value < 0)
+ {
+ // alpha * (exp(x) - 1)
+ return 1 * (std::exp(value) - 1);
+ }
+ return value;
+ };
+ std::vector<float> outputExpectedData(inputData.size());
+ std::transform(inputData.begin(), inputData.end(), outputExpectedData.begin(), f);
+
+ ActivationTest(tflite::BuiltinOperator_ELU,
+ backends,
+ inputData,
+ outputExpectedData);
+}
+
+void ActivationHardSwishTest(std::vector<armnn::BackendId>& backends)
+{
+ std::vector<float> inputData = {
+ -0.1f, -0.2f, -0.3f, -0.4f,
+ 0.1f, 0.2f, 0.3f, 0.4f,
+ -1.0f, -2.0f, -3.0f, -4.0f,
+ 1.0f, 2.0f, 3.0f, 4.0f
+ };
+
+ // Calculate output values for input.
+ auto f = [](float x)
+ {
+ // Break down the calculation to help with verification.
+ // hard_swish(x) = x * relu6(x+3) / 6
+ // relu6(x) = min(max(x,0),6)
+ float reLu6_step1 = std::max((x + 3),0.0f);
+ float reLu6Complete = std::min(reLu6_step1, 6.0f);
+ float hardSwish_step1 = x * reLu6Complete;
+ float result = hardSwish_step1 / 6;
+ return result;
+ };
+ std::vector<float> outputExpectedData(inputData.size());
+ std::transform(inputData.begin(), inputData.end(), outputExpectedData.begin(), f);
+
+ ActivationTest(tflite::BuiltinOperator_HARD_SWISH,
+ backends,
+ inputData,
+ outputExpectedData);
+}
+
TEST_SUITE("Activation_CpuRefTests")
{
@@ -137,13 +194,24 @@ TEST_CASE ("Activation_Sigmoid_CpuRef_Test")
ActivationSigmoidTest(backends);
}
-
TEST_CASE ("Activation_TanH_CpuRef_Test")
{
std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
ActivationTanHTest(backends);
}
+TEST_CASE ("Activation_Elu_CpuRef_Test")
+{
+ std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
+ ActivationEluTest(backends);
+}
+
+TEST_CASE ("Activation_HardSwish_CpuRef_Test")
+{
+ std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
+ ActivationHardSwishTest(backends);
+}
+
}
TEST_SUITE("Activation_CpuAccTests")
@@ -167,13 +235,24 @@ TEST_CASE ("Activation_Sigmoid_CpuAcc_Test")
ActivationSigmoidTest(backends);
}
-
TEST_CASE ("Activation_TanH_CpuAcc_Test")
{
std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
ActivationTanHTest(backends);
}
+TEST_CASE ("Activation_Elu_CpuAcc_Test")
+{
+ std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
+ ActivationEluTest(backends);
+}
+
+TEST_CASE ("Activation_HardSwish_CpuAcc_Test")
+{
+ std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
+ ActivationHardSwishTest(backends);
+}
+
}
TEST_SUITE("Activation_GpuAccTests")
@@ -197,13 +276,24 @@ TEST_CASE ("Activation_Sigmoid_GpuAcc_Test")
ActivationSigmoidTest(backends);
}
-
TEST_CASE ("Activation_TanH_GpuAcc_Test")
{
std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
ActivationTanHTest(backends);
}
+TEST_CASE ("Activation_Elu_GpuAcc_Test")
+{
+ std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
+ ActivationEluTest(backends);
+}
+
+TEST_CASE ("Activation_HardSwish_GpuAcc_Test")
+{
+ std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
+ ActivationHardSwishTest(backends);
+}
+
}
} // namespace armnnDelegate \ No newline at end of file