aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-05-25 14:26:24 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-05-25 15:05:06 +0100
commit4a4af11f2e02465df8f5aa2b2de19d00c2a8ea3d (patch)
treed9205277fff0528310f36859dbd76aea7e7f17f3
parentbfaee6b574301a54eab07a6021c39ae710977f7f (diff)
downloadarmnn-4a4af11f2e02465df8f5aa2b2de19d00c2a8ea3d.tar.gz
IVGCVSW-3649 Add Prelu with different alpha dimension test to TfLiteParser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I982ecd66ea3ed4d88934cd8254832eecb4a7adb4
-rw-r--r--docs/01_01_parsers.dox1
-rw-r--r--src/armnn/optimizations/AddBroadcastReshapeLayer.hpp3
-rw-r--r--src/armnnTfLiteParser/test/Prelu.cpp28
3 files changed, 26 insertions, 6 deletions
diff --git a/docs/01_01_parsers.dox b/docs/01_01_parsers.dox
index af87eba7af..761380c939 100644
--- a/docs/01_01_parsers.dox
+++ b/docs/01_01_parsers.dox
@@ -133,6 +133,7 @@ The Arm NN SDK TensorFlow Lite parser currently supports the following operators
- NEG
- PACK
- PAD
+- PRELU
- QUANTIZE
- RELU
- RELU6
diff --git a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
index 4cfe2e4898..d243a807fd 100644
--- a/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
+++ b/src/armnn/optimizations/AddBroadcastReshapeLayer.hpp
@@ -17,7 +17,8 @@ namespace optimizations
static const std::set<armnn::LayerType> broadcastOps{ LayerType::Addition, LayerType::Division,
LayerType::Maximum, LayerType::Minimum,
- LayerType::Multiplication, LayerType::Subtraction };
+ LayerType::Multiplication, LayerType::Prelu,
+ LayerType::Subtraction };
class AddBroadcastReshapeLayerImpl
{
diff --git a/src/armnnTfLiteParser/test/Prelu.cpp b/src/armnnTfLiteParser/test/Prelu.cpp
index b4aa8d7f4d..48a86dcefc 100644
--- a/src/armnnTfLiteParser/test/Prelu.cpp
+++ b/src/armnnTfLiteParser/test/Prelu.cpp
@@ -106,7 +106,7 @@ struct PreluFixture : public ParserFlatbuffersFixture
struct SimplePreluFixture : PreluFixture
{
SimplePreluFixture() : PreluFixture("[ 2, 3 ]",
- "[ 1, 1 ]",
+ "[ 1 ]",
"[ 2, 3 ]",
"[ 0, 1 ]",
"") {}
@@ -115,13 +115,23 @@ struct SimplePreluFixture : PreluFixture
struct PreluConstAlphaFixture : PreluFixture
{
PreluConstAlphaFixture() : PreluFixture(
- "[ 2, 3 ]",
- "[ 2, 3 ]",
- "[ 2, 3 ]",
+ "[ 1, 2, 3 ]",
+ "[ 1, 2, 3 ]",
+ "[ 1, 2, 3 ]",
"[ 0 ]",
"\"data\": [ 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62 ]"){}
};
+struct PreluBroadcastAlphaFixture : PreluFixture
+{
+ PreluBroadcastAlphaFixture() : PreluFixture(
+ "[ 1, 1, 2, 3 ]",
+ "[ 1, 3 ]",
+ "[ 1, 1, 2, 3 ]",
+ "[ 0 ]",
+ "\"data\": [ 0, 0, 128, 62, 0, 0, 128, 62, 0, 0, 128, 62 ]"){}
+};
+
struct PreluDynamicTensorFixture : PreluFixture
{
PreluDynamicTensorFixture() : PreluFixture("[ 2, 3 ]",
@@ -141,7 +151,15 @@ BOOST_FIXTURE_TEST_CASE(SimplePrelu, SimplePreluFixture)
BOOST_FIXTURE_TEST_CASE(PreluConstAlpha, PreluConstAlphaFixture)
{
- RunTest<2, armnn::DataType::Float32>(
+ RunTest<3, armnn::DataType::Float32>(
+ 0,
+ {{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
+ {{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});
+}
+
+BOOST_FIXTURE_TEST_CASE(PreluBroadcastAlpha, PreluBroadcastAlphaFixture)
+{
+ RunTest<4, armnn::DataType::Float32>(
0,
{{"input0", { -14.f, 2.f, 0.f, 1.f, -5.f, 14.f }}},
{{"output", { -3.5f, 2.f, 0.f, 1.f, -1.25f, 14.f }}});