diff options
-rw-r--r--  src/armnn/QuantizerVisitor.cpp   | 13
-rw-r--r--  src/armnn/QuantizerVisitor.hpp   |  4
-rw-r--r--  src/armnn/test/QuantizerTest.cpp | 56
3 files changed, 71 insertions(+), 2 deletions(-)
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp index 9819d71ea9..7158c99995 100644 --- a/src/armnn/QuantizerVisitor.cpp +++ b/src/armnn/QuantizerVisitor.cpp @@ -319,9 +319,18 @@ void QuantizerVisitor::VisitInstanceNormalizationLayer(const IConnectableLayer* SetQuantizedInputConnections(layer, newLayer); } +void QuantizerVisitor::VisitLogSoftmaxLayer(const IConnectableLayer* layer, + const LogSoftmaxDescriptor& logSoftmaxDescriptor, + const char* name) +{ + IConnectableLayer* newLayer = m_QuantizedNetwork->AddLogSoftmaxLayer(logSoftmaxDescriptor, name); + RecordLayer(layer, newLayer); + SetQuantizedInputConnections(layer, newLayer); +} + void QuantizerVisitor::VisitMeanLayer(const IConnectableLayer* layer, - const MeanDescriptor& meanDescriptor, - const char* name) + const MeanDescriptor& meanDescriptor, + const char* name) { IConnectableLayer* newLayer = m_QuantizedNetwork->AddMeanLayer(meanDescriptor, name); RecordLayer(layer, newLayer); diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp index d1c4375b59..89d1932a08 100644 --- a/src/armnn/QuantizerVisitor.hpp +++ b/src/armnn/QuantizerVisitor.hpp @@ -93,6 +93,10 @@ public: const InstanceNormalizationDescriptor& instanceNormalizationDescriptor, const char* name = nullptr) override; + void VisitLogSoftmaxLayer(const IConnectableLayer* layer, + const LogSoftmaxDescriptor& logSoftmaxDescriptor, + const char* name = nullptr) override; + void VisitMeanLayer(const IConnectableLayer* layer, const MeanDescriptor& meanDescriptor, const char* name = nullptr) override; diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp index 6f7c115164..101be1fb57 100644 --- a/src/armnn/test/QuantizerTest.cpp +++ b/src/armnn/test/QuantizerTest.cpp @@ -1062,6 +1062,62 @@ BOOST_AUTO_TEST_CASE(QuantizeInstanceNormalization) VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16); } +BOOST_AUTO_TEST_CASE(QuantizeLogSoftmax) +{ + class 
TestLogSoftmaxQuantization : public TestQuantization + { + public: + TestLogSoftmaxQuantization(const TensorShape& inputShape, const TensorShape& outputShape) + : TestQuantization(inputShape, outputShape) {} + + TestLogSoftmaxQuantization(const QuantizerOptions& options, + const TensorShape& inputShape, + const TensorShape& outputShape) + : TestQuantization(options, inputShape, outputShape) {} + + void VisitLogSoftmaxLayer(const IConnectableLayer* layer, + const SoftmaxDescriptor& descriptor, + const char* name = nullptr) override + { + TensorInfo info = layer->GetOutputSlot(0).GetTensorInfo(); + + const OffsetScalePair qAsymm8Params{ 30.0f / g_Asymm8QuantizationBase, 128 }; + const OffsetScalePair qSymm16Params{ 15.0f / g_Symm16QuantizationBase, 0 }; + + TestQuantizationParams(info, qAsymm8Params, qSymm16Params); + } + }; + + const TensorShape tensorShape{ 1U }; + const TensorInfo tensorInfo(tensorShape, DataType::Float32); + + INetworkPtr network = INetwork::Create(); + + LogSoftmaxDescriptor descriptor; + descriptor.m_Beta = 1.0f; + + IConnectableLayer* inputLayer = network->AddInputLayer(0); + IConnectableLayer* logSoftmaxLayer = network->AddLogSoftmaxLayer(descriptor); + IConnectableLayer* outputLayer = network->AddOutputLayer(0); + + inputLayer->GetOutputSlot(0).Connect(logSoftmaxLayer->GetInputSlot(0)); + logSoftmaxLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo); + logSoftmaxLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo); + + // test QAsymm8 quantization + INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get())->ExportNetwork(); + TestLogSoftmaxQuantization validatorQAsymm8(tensorShape, tensorShape); + VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8); + + // test QuantisedSymm16 quantization + const QuantizerOptions options(DataType::QuantisedSymm16); + INetworkPtr quantizedNetworkQSymm16 = INetworkQuantizer::Create(network.get(), 
options)->ExportNetwork(); + TestLogSoftmaxQuantization validatorQSymm16(options, tensorShape, tensorShape); + VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16); +} + INetworkPtr CreateNetworkWithSoftmaxLayer(const SoftmaxDescriptor& descriptor, const TensorShape& shape) { INetworkPtr network = INetwork::Create(); |