diff options
author | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-19 12:53:27 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-19 16:02:00 +0000 |
commit | a4812b6cd54e4dc4903f457066281d8bf0ccf448 (patch) | |
tree | 8096dc837be68887f449f3f825e6b2bb7486d8a6 /src/armnn/test | |
parent | 20b1f88309903b576ae030888022f38cce2bbc82 (diff) | |
download | armnn-a4812b6cd54e4dc4903f457066281d8bf0ccf448.tar.gz |
IVGCVSW-3270 Add Quantizer support for the new Prelu Activation layer
* Implemented VisitPreluLayer
* Added unit test for Prelu layer quantization
Change-Id: I0442053f69608a400d295654b103cfd2429a0341
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/armnn/test')
-rw-r--r-- | src/armnn/test/QuantizerTest.cpp | 104 |
1 file changed, 104 insertions(+), 0 deletions(-)
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 581991b57c..2792d5c483 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -1611,6 +1611,110 @@ BOOST_AUTO_TEST_CASE(QuantizeBatchToSpace)
     VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16);
 }
 
+BOOST_AUTO_TEST_CASE(QuantizePrelu)
+{
+    class TestPreluQuantization : public TestQuantization
+    {
+    public:
+        TestPreluQuantization(const TensorShape& inputShape,
+                              const TensorShape& alphaShape,
+                              const TensorShape& outputShape)
+            : TestQuantization(inputShape, outputShape)
+            , m_AlphaShape(alphaShape)
+        {}
+
+        TestPreluQuantization(const QuantizerOptions& options,
+                              const TensorShape& inputShape,
+                              const TensorShape& alphaShape,
+                              const TensorShape& outputShape)
+            : TestQuantization(options, inputShape, outputShape)
+            , m_AlphaShape(alphaShape)
+        {}
+
+        void VisitInputLayer(const IConnectableLayer* layer,
+                             LayerBindingId id,
+                             const char* name = nullptr) override
+        {
+            const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo();
+
+            switch (id)
+            {
+            case 0: // Input
+                BOOST_TEST(m_InputShape == info.GetShape());
+                break;
+            case 1: // Alpha
+                BOOST_TEST(m_AlphaShape == info.GetShape());
+                break;
+            default:
+                throw InvalidArgumentException("Invalid layer binding id for PReLU layer");
+            }
+
+            // Based off current default [-15.0f, 15.0f]
+            TestQuantizationParams(info,
+                                   { 30.0f / g_Asymm8QuantizationBase, 128 }, // QASymm8
+                                   { 15.0f / g_Symm16QuantizationBase, 0 });  // QSymm16
+        }
+
+        void VisitOutputLayer(const IConnectableLayer* layer,
+                              LayerBindingId id,
+                              const char* name = nullptr) override
+        {
+            const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+            BOOST_TEST(m_OutputShape == info.GetShape());
+        }
+
+        void VisitPreluLayer(const IConnectableLayer* layer,
+                             const char* name = nullptr) override
+        {
+            const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo();
+            TestQuantizationParams(info,
+                                   { 30.0f / g_Asymm8QuantizationBase, 128 }, // QASymm8
+                                   { 15.0f / g_Symm16QuantizationBase, 0 });  // QSymm16
+        }
+
+    private:
+        TensorShape m_AlphaShape;
+    };
+
+    INetworkPtr network = INetwork::Create();
+
+    const TensorShape inputShape{ 4, 1, 2 };
+    const TensorShape alphaShape{ 5, 4, 3, 1 };
+    const TensorShape outputShape{ 5, 4, 3, 2 };
+    TensorInfo inputInfo(inputShape, DataType::Float32);
+    TensorInfo alphaInfo(alphaShape, DataType::Float32);
+    TensorInfo outputInfo(outputShape, DataType::Float32);
+
+    // Add the input layers
+    IConnectableLayer* input = network->AddInputLayer(0);
+    IConnectableLayer* alpha = network->AddInputLayer(1);
+
+    // Add the layer under test
+    IConnectableLayer* prelu = network->AddPreluLayer("prelu");
+
+    // Add the output layers
+    IConnectableLayer* output = network->AddOutputLayer(0);
+
+    // Establish connections
+    input->GetOutputSlot(0).Connect(prelu->GetInputSlot(0));
+    alpha->GetOutputSlot(0).Connect(prelu->GetInputSlot(1));
+    prelu->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+    // Set tensor info
+    input->GetOutputSlot(0).SetTensorInfo(inputInfo);
+    alpha->GetOutputSlot(0).SetTensorInfo(alphaInfo);
+    prelu->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+    INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get())->ExportNetwork();
+    TestPreluQuantization validatorQAsymm8(inputShape, alphaShape, outputShape);
+    VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8);
+
+    const QuantizerOptions options(DataType::QuantisedSymm16);
+    INetworkPtr quantizedNetworkQSymm16 = INetworkQuantizer::Create(network.get(), options)->ExportNetwork();
+    TestPreluQuantization validatorQSymm16(options, inputShape, alphaShape, outputShape);
+    VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16);
+}
+
 std::vector<uint8_t> SetupQuantize(float value)
 {
     armnn::TensorInfo inputInfo({ 1, 2, 2 }, armnn::DataType::Float32);