about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMatteo Martincigh <matteo.martincigh@arm.com>2019-06-19 12:53:27 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-06-19 16:02:00 +0000
commita4812b6cd54e4dc4903f457066281d8bf0ccf448 (patch)
tree8096dc837be68887f449f3f825e6b2bb7486d8a6
parent20b1f88309903b576ae030888022f38cce2bbc82 (diff)
downloadarmnn-a4812b6cd54e4dc4903f457066281d8bf0ccf448.tar.gz
IVGCVSW-3270 Add Quantizer support for the new Prelu Activation layer
* Implemented VisitPreluLayer
* Added unit test for Prelu layer quantization

Change-Id: I0442053f69608a400d295654b103cfd2429a0341
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
-rw-r--r--src/armnn/QuantizerVisitor.cpp8
-rw-r--r--src/armnn/QuantizerVisitor.hpp3
-rw-r--r--src/armnn/test/QuantizerTest.cpp104
3 files changed, 115 insertions, 0 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index c6a55f404c..ef6b068145 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -359,6 +359,14 @@ void QuantizerVisitor::VisitPooling2dLayer(const IConnectableLayer* layer,
SetQuantizedInputConnections(layer, newLayer);
}
+void QuantizerVisitor::VisitPreluLayer(const IConnectableLayer* layer,
+ const char* name)
+{
+ IConnectableLayer* newLayer = m_QuantizedNetwork->AddPreluLayer(name);
+ RecordLayer(layer, newLayer);
+ SetQuantizedInputConnections(layer, newLayer);
+}
+
void QuantizerVisitor::VisitReshapeLayer(const IConnectableLayer* layer,
const ReshapeDescriptor& reshapeDescriptor,
const char* name)
diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp
index 1f1f651c1a..ca9f6940a6 100644
--- a/src/armnn/QuantizerVisitor.hpp
+++ b/src/armnn/QuantizerVisitor.hpp
@@ -103,6 +103,9 @@ public:
const Pooling2dDescriptor& pooling2dDescriptor,
const char* name = nullptr) override;
+ void VisitPreluLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override;
+
void VisitReshapeLayer(const IConnectableLayer* layer,
const ReshapeDescriptor& reshapeDescriptor,
const char* name = nullptr) override;
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 581991b57c..2792d5c483 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -1611,6 +1611,110 @@ BOOST_AUTO_TEST_CASE(QuantizeBatchToSpace)
VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16);
}
+BOOST_AUTO_TEST_CASE(QuantizePrelu)
+{
+ class TestPreluQuantization : public TestQuantization
+ {
+ public:
+ TestPreluQuantization(const TensorShape& inputShape,
+ const TensorShape& alphaShape,
+ const TensorShape& outputShape)
+ : TestQuantization(inputShape, outputShape)
+ , m_AlphaShape(alphaShape)
+ {}
+
+ TestPreluQuantization(const QuantizerOptions& options,
+ const TensorShape& inputShape,
+ const TensorShape& alphaShape,
+ const TensorShape& outputShape)
+ : TestQuantization(options, inputShape, outputShape)
+ , m_AlphaShape(alphaShape)
+ {}
+
+ void VisitInputLayer(const IConnectableLayer* layer,
+ LayerBindingId id,
+ const char* name = nullptr) override
+ {
+ const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo();
+
+ switch (id)
+ {
+ case 0: // Input
+ BOOST_TEST(m_InputShape == info.GetShape());
+ break;
+ case 1: // Alpha
+ BOOST_TEST(m_AlphaShape == info.GetShape());
+ break;
+ default:
+ throw InvalidArgumentException("Invalid layer binding id for PReLU layer");
+ }
+
+ // Based off current default [-15.0f, 15.0f]
+ TestQuantizationParams(info,
+ { 30.0f / g_Asymm8QuantizationBase, 128 }, // QASymm8
+ { 15.0f / g_Symm16QuantizationBase, 0 }); // QSymm16
+ }
+
+ void VisitOutputLayer(const IConnectableLayer* layer,
+ LayerBindingId id,
+ const char* name = nullptr) override
+ {
+ const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ BOOST_TEST(m_OutputShape == info.GetShape());
+ }
+
+ void VisitPreluLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override
+ {
+ const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo();
+ TestQuantizationParams(info,
+ { 30.0f / g_Asymm8QuantizationBase, 128 }, // QASymm8
+ { 15.0f / g_Symm16QuantizationBase, 0 }); // QSymm16
+ }
+
+ private:
+ TensorShape m_AlphaShape;
+ };
+
+ INetworkPtr network = INetwork::Create();
+
+ const TensorShape inputShape{ 4, 1, 2 };
+ const TensorShape alphaShape{ 5, 4, 3, 1 };
+ const TensorShape outputShape{ 5, 4, 3, 2 };
+ TensorInfo inputInfo(inputShape, DataType::Float32);
+ TensorInfo alphaInfo(alphaShape, DataType::Float32);
+ TensorInfo outputInfo(outputShape, DataType::Float32);
+
+ // Add the input layers
+ IConnectableLayer* input = network->AddInputLayer(0);
+ IConnectableLayer* alpha = network->AddInputLayer(1);
+
+ // Add the layer under test
+ IConnectableLayer* prelu = network->AddPreluLayer("prelu");
+
+ // Add the output layers
+ IConnectableLayer* output = network->AddOutputLayer(0);
+
+ // Establish connections
+ input->GetOutputSlot(0).Connect(prelu->GetInputSlot(0));
+ alpha->GetOutputSlot(0).Connect(prelu->GetInputSlot(1));
+ prelu->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ // Set tensor info
+ input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+ alpha->GetOutputSlot(0).SetTensorInfo(alphaInfo);
+ prelu->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get())->ExportNetwork();
+ TestPreluQuantization validatorQAsymm8(inputShape, alphaShape, outputShape);
+ VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8);
+
+ const QuantizerOptions options(DataType::QuantisedSymm16);
+ INetworkPtr quantizedNetworkQSymm16 = INetworkQuantizer::Create(network.get(), options)->ExportNetwork();
+ TestPreluQuantization validatorQSymm16(options, inputShape, alphaShape, outputShape);
+ VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16);
+}
+
std::vector<uint8_t> SetupQuantize(float value)
{
armnn::TensorInfo inputInfo({ 1, 2, 2 }, armnn::DataType::Float32);