From bab9dc64cb4fad9cc0c4d48678f3e7f841b6504d Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Wed, 18 Sep 2019 14:49:29 +0100 Subject: IVGCVSW-3881 Add Quantizer support for SLICE Signed-off-by: Aron Virginas-Tar Change-Id: I72bc00888d416fee177ea2e6e5006f8ff04f612e --- src/armnn/QuantizerVisitor.cpp | 9 +++++++ src/armnn/QuantizerVisitor.hpp | 4 +++ src/armnn/test/QuantizerTest.cpp | 55 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+) diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp index 9b13056da8..5a86264efd 100644 --- a/src/armnn/QuantizerVisitor.cpp +++ b/src/armnn/QuantizerVisitor.cpp @@ -413,6 +413,15 @@ void QuantizerVisitor::VisitRsqrtLayer(const IConnectableLayer* layer, SetQuantizedInputConnections(layer, newLayer); } +void QuantizerVisitor::VisitSliceLayer(const IConnectableLayer* layer, + const SliceDescriptor& sliceDescriptor, + const char* name) +{ + IConnectableLayer* newLayer = m_QuantizedNetwork->AddSliceLayer(sliceDescriptor, name); + RecordLayer(layer, newLayer); + SetQuantizedInputConnections(layer, newLayer); +} + void QuantizerVisitor::VisitSoftmaxLayer(const IConnectableLayer* layer, const SoftmaxDescriptor& softmaxDescriptor, const char* name) diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp index e087fe6a61..3a1e300b7f 100644 --- a/src/armnn/QuantizerVisitor.hpp +++ b/src/armnn/QuantizerVisitor.hpp @@ -124,6 +124,10 @@ public: void VisitRsqrtLayer(const IConnectableLayer*, const char* name = nullptr) override; + void VisitSliceLayer(const IConnectableLayer* layer, + const SliceDescriptor& sliceDescriptor, + const char* name = nullptr) override; + void VisitSoftmaxLayer(const IConnectableLayer* layer, const SoftmaxDescriptor& softmaxDescriptor, const char* name = nullptr) override; diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp index 7a5d27bd52..d902b8df40 100644 --- a/src/armnn/test/QuantizerTest.cpp +++ b/src/armnn/test/QuantizerTest.cpp @@ -1905,6 +1905,61 @@ BOOST_AUTO_TEST_CASE(QuantizeStack) VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16); } +BOOST_AUTO_TEST_CASE(QuantizeSlice) +{ + class TestSliceQuantization : public TestQuantization + { + public: + TestSliceQuantization(const TensorShape& inputShape, const TensorShape& outputShape) + : TestQuantization(inputShape, outputShape) + {} + + TestSliceQuantization(const QuantizerOptions& options, + const TensorShape& inputShape, + const TensorShape& outputShape) + : TestQuantization(options, inputShape, outputShape) + {} + + virtual void VisitSliceLayer(const IConnectableLayer* layer, + const SliceDescriptor& desc, + const char* name = nullptr) + { + const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo(); + + const OffsetScalePair qAsymm8Params{ 30.0f / g_Asymm8QuantizationBase, 128 }; + const OffsetScalePair qSymm16Params{ 15.0f / g_Symm16QuantizationBase, 0 }; + + TestQuantizationParams(info, qAsymm8Params, qSymm16Params); + } + }; + + TensorShape shape{ 3 }; + TensorInfo info(shape, DataType::Float32); + + INetworkPtr network = INetwork::Create(); + + IConnectableLayer* inputLayer = network->AddInputLayer(0); + IConnectableLayer* sliceLayer = network->AddSliceLayer(SliceDescriptor()); + IConnectableLayer* outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(sliceLayer->GetInputSlot(0)); + sliceLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(info); + sliceLayer->GetOutputSlot(0).SetTensorInfo(info); + + // test QAsymm8 quantization + INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get())->ExportNetwork(); + TestSliceQuantization validatorQAsymm8(shape, shape); + VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8); + + // test QSymm16 quantization + const QuantizerOptions options(DataType::QuantisedSymm16); + INetworkPtr quantizedNetworkQSymm16 = INetworkQuantizer::Create(network.get(), options)->ExportNetwork(); + TestSliceQuantization validatorQSymm16(options, shape, shape); + VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16); +} + std::vector SetupQuantize(float value) { armnn::TensorInfo inputInfo({ 1, 2, 2 }, armnn::DataType::Float32); -- cgit v1.2.1