From 724e48013142562b7f09c9c819f57c314c4ee3d4 Mon Sep 17 00:00:00 2001
From: Nattapat Chaimanowong
Date: Thu, 9 May 2019 10:13:20 +0100
Subject: IVGCVSW-3061 Modify NetworkQuantizer to support option to preserve input/output types

* Also add unit tests for new preserve type option

Signed-off-by: Nattapat Chaimanowong
Change-Id: I860759072f2e3546698118d1bcd5e79eb4e805ec
---
 include/armnnQuantizer/INetworkQuantizer.hpp |  10 ++-
 src/armnn/NetworkQuantizer.cpp               |   2 +-
 src/armnn/QuantizerVisitor.cpp               |  41 ++++++++--
 src/armnn/QuantizerVisitor.hpp               |   7 +-
 src/armnn/test/QuantizerTest.cpp             | 112 +++++++++++++++++++++++++--
 5 files changed, 156 insertions(+), 16 deletions(-)

diff --git a/include/armnnQuantizer/INetworkQuantizer.hpp b/include/armnnQuantizer/INetworkQuantizer.hpp
index 89548d1057..826b077f6e 100644
--- a/include/armnnQuantizer/INetworkQuantizer.hpp
+++ b/include/armnnQuantizer/INetworkQuantizer.hpp
@@ -14,10 +14,16 @@ namespace armnn
 
 struct QuantizerOptions
 {
-    QuantizerOptions() : m_ActivationFormat(DataType::QuantisedAsymm8) {}
-    QuantizerOptions(DataType activationFormat) : m_ActivationFormat(activationFormat) {}
+    QuantizerOptions() : QuantizerOptions(DataType::QuantisedAsymm8, false) {}
+
+    QuantizerOptions(DataType activationFormat) : QuantizerOptions(activationFormat, false) {}
+
+    QuantizerOptions(DataType activationFormat, bool preserveType)
+    : m_ActivationFormat(activationFormat)
+    , m_PreserveType(preserveType) {}
 
     DataType m_ActivationFormat;
+    bool m_PreserveType;
 };
 
 using INetworkQuantizerPtr = std::unique_ptr<class INetworkQuantizer, void(*)(INetworkQuantizer* quantizer)>;
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp
index 12e459d276..f308d54d49 100644
--- a/src/armnn/NetworkQuantizer.cpp
+++ b/src/armnn/NetworkQuantizer.cpp
@@ -171,7 +171,7 @@ INetworkPtr NetworkQuantizer::ExportNetwork()
         throw InvalidArgumentException("Unsupported quantization target");
     }
 
-    QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get());
+    QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get(), m_Options.m_PreserveType);
     VisitLayers(graph, quantizerVisitor);
 
     // clear the ranges
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index 919eda1c7e..61e0e60779 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -11,10 +11,13 @@
 namespace armnn
 {
 
-QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme)
+QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker,
+                                   const IQuantizationScheme* quantizationScheme,
+                                   bool preserveType)
     : m_Ranges(rangeTracker)
     , m_QuantizedNetwork(INetwork::Create())
     , m_QuantizationScheme(quantizationScheme)
+    , m_PreserveType(preserveType)
 {
 }
 
@@ -106,15 +109,41 @@ void QuantizerVisitor::VisitFullyConnectedLayer(const IConnectableLayer *layer,
 
 void QuantizerVisitor::VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name)
 {
-    IConnectableLayer* newLayer = m_QuantizedNetwork->AddInputLayer(id, name);
-    RecordLayer(layer, newLayer);
+    const DataType dataType = layer->GetOutputSlot(0).GetTensorInfo().GetDataType();
+    IConnectableLayer* inputLayer = m_QuantizedNetwork->AddInputLayer(id, name);
+
+    if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+    {
+        IConnectableLayer* quantizeLayer = m_QuantizedNetwork->AddQuantizeLayer();
+        inputLayer->GetOutputSlot(0).Connect(quantizeLayer->GetInputSlot(0));
+        inputLayer->GetOutputSlot(0).SetTensorInfo(layer->GetOutputSlot(0).GetTensorInfo());
+        RecordLayer(layer, quantizeLayer);
+    }
+    else
+    {
+        RecordLayer(layer, inputLayer);
+    }
 }
 
 void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
 {
-    IConnectableLayer* newLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
-    RecordLayer(layer, newLayer);
-    SetQuantizedInputConnections(layer, newLayer);
+    const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+    const DataType& dataType = info.GetDataType();
+    IConnectableLayer* outputLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
+
+    if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+    {
+        IConnectableLayer* dequantizeLayer = m_QuantizedNetwork->AddDequantizeLayer();
+        RecordLayer(layer, dequantizeLayer);
+        SetQuantizedInputConnections(layer, dequantizeLayer);
+        dequantizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+        dequantizeLayer->GetOutputSlot(0).SetTensorInfo(info);
+    }
+    else
+    {
+        RecordLayer(layer, outputLayer);
+        SetQuantizedInputConnections(layer, outputLayer);
+    }
 }
 
 void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer* layer,
diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp
index eb9ebac3d9..300ac164de 100644
--- a/src/armnn/QuantizerVisitor.hpp
+++ b/src/armnn/QuantizerVisitor.hpp
@@ -25,7 +25,10 @@ class StaticRangeVisitor;
 class QuantizerVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>
 {
 public:
-    QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme);
+    QuantizerVisitor(const RangeTracker& rangeTracker,
+                     const IQuantizationScheme* quantizationScheme,
+                     bool preserveType = false);
+
     ~QuantizerVisitor() = default;
 
     /// Functions to quantize the individual layers, overridden from ILayerVisitor
@@ -132,6 +135,8 @@ private:
     std::unordered_map<LayerGuid, IConnectableLayer*> m_QuantizedGuidToLayerMap;
 
     const IQuantizationScheme* m_QuantizationScheme;
+
+    const bool m_PreserveType;
 };
 
 } //namespace armnn
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 259e90fcca..2103de062c 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -38,15 +38,15 @@ class TestQuantization : public LayerVisitorBase<VisitorThrowingPolicy>
 public:
     TestQuantization(const TensorShape& inputShape, const TensorShape& outputShape)
     : LayerVisitorBase<VisitorThrowingPolicy>()
-    , m_QuantizerOptions(QuantizerOptions())
     , m_InputShape(inputShape)
-    , m_OutputShape(outputShape) {}
+    , m_OutputShape(outputShape)
+    , m_QuantizerOptions(QuantizerOptions()) {}
 
     TestQuantization(const QuantizerOptions& options, const TensorShape& inputShape, const TensorShape& outputShape)
     : LayerVisitorBase<VisitorThrowingPolicy>()
-    , m_QuantizerOptions(options)
     , m_InputShape(inputShape)
-    , m_OutputShape(outputShape) {}
+    , m_OutputShape(outputShape)
+    , m_QuantizerOptions(options) {}
 
     void VisitInputLayer(const IConnectableLayer* layer,
                          LayerBindingId id,
@@ -91,6 +91,9 @@ protected:
         TestQuantizationParamsImpl(info, DataType::QuantisedAsymm8, params.first, params.second);
     }
 
+    TensorShape m_InputShape;
+    TensorShape m_OutputShape;
+
 private:
     void TestQuantizationParamsImpl(const TensorInfo& info, DataType dataType, float scale, int32_t offset)
     {
@@ -100,8 +103,6 @@
     }
 
     QuantizerOptions m_QuantizerOptions;
-    TensorShape m_InputShape;
-    TensorShape m_OutputShape;
 };
 
 void VisitLayersTopologically(const INetwork* inputNetwork, ILayerVisitor& visitor)
@@ -1574,5 +1575,104 @@ BOOST_AUTO_TEST_CASE(QuantizeNegativeInf)
     BOOST_CHECK_EQUAL(SetupQuantize(-1 * std::numeric_limits<float>::infinity())[0], 0);
} +class TestPreserveType : public TestAdditionQuantization +{ +public: + TestPreserveType(const QuantizerOptions& options, + const DataType& dataType, + const TensorShape& inputShape, + const TensorShape& outputShape) + : TestAdditionQuantization(options, inputShape, outputShape) + , m_DataType(dataType) + , m_VisitedQuantizeLayer(false) + , m_VisitedDequantizeLayer(false) {} + + void VisitInputLayer(const IConnectableLayer* layer, + LayerBindingId id, + const char* name = nullptr) override + { + const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo(); + BOOST_TEST(GetDataTypeName(info.GetDataType()) == GetDataTypeName(m_DataType)); + BOOST_TEST(m_InputShape == info.GetShape()); + } + + void VisitOutputLayer(const IConnectableLayer* layer, + LayerBindingId id, + const char* name = nullptr) override + { + const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo(); + BOOST_TEST(GetDataTypeName(info.GetDataType()) == GetDataTypeName(m_DataType)); + BOOST_TEST(m_OutputShape == info.GetShape()); + } + + void VisitQuantizeLayer(const IConnectableLayer* layer, + const char* name = nullptr) override + { + m_VisitedQuantizeLayer = true; + } + + void VisitDequantizeLayer(const IConnectableLayer* layer, + const char* name = nullptr) override + { + m_VisitedDequantizeLayer = true; + } + + void CheckQuantizeDequantizeLayerVisited(bool expected) + { + if (expected) + { + BOOST_CHECK(m_VisitedQuantizeLayer); + BOOST_CHECK(m_VisitedDequantizeLayer); + } + else + { + BOOST_CHECK(!m_VisitedQuantizeLayer); + BOOST_CHECK(!m_VisitedDequantizeLayer); + } + } +private: + const DataType m_DataType; + bool m_VisitedQuantizeLayer; + bool m_VisitedDequantizeLayer; +}; + +void PreserveTypeTestImpl(const DataType& dataType) +{ + INetworkPtr network = INetwork::Create(); + + // Add the layers + IConnectableLayer* input0 = network->AddInputLayer(0); + IConnectableLayer* input1 = network->AddInputLayer(1); + IConnectableLayer* addition = network->AddAdditionLayer(); + IConnectableLayer* output = network->AddOutputLayer(2); + + input0->GetOutputSlot(0).Connect(addition->GetInputSlot(0)); + input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1)); + addition->GetOutputSlot(0).Connect(output->GetInputSlot(0)); + + const TensorShape shape{1U, 2U, 3U}; + const TensorInfo info(shape, dataType); + input0->GetOutputSlot(0).SetTensorInfo(info); + input1->GetOutputSlot(0).SetTensorInfo(info); + addition->GetOutputSlot(0).SetTensorInfo(info); + + const QuantizerOptions options(DataType::QuantisedAsymm8, true); + INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get(), options)->ExportNetwork(); + TestPreserveType validatorQAsymm8(options, dataType, shape, shape); + VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8); + validatorQAsymm8.CheckQuantizeDequantizeLayerVisited( + dataType == DataType::Float32 || dataType == DataType::Float16); +} + +BOOST_AUTO_TEST_CASE(PreserveTypeFloat32) +{ + PreserveTypeTestImpl(DataType::Float32); +} + +BOOST_AUTO_TEST_CASE(PreserveTypeQAsymm8) +{ + PreserveTypeTestImpl(DataType::QuantisedAsymm8); +} + BOOST_AUTO_TEST_SUITE_END() } // namespace armnn -- cgit v1.2.1
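
Usage note (editor's sketch, not part of the applied patch): the PreserveTypeTestImpl
test above already shows the intended call pattern for the new option. The snippet
below assembles those same INetwork and INetworkQuantizer calls into a standalone
example; the main() scaffolding and header choices are illustrative assumptions,
not taken from the commit.

#include <armnn/ArmNN.hpp>
#include <armnnQuantizer/INetworkQuantizer.hpp>

int main()
{
    using namespace armnn;

    // Build a small Float32 network: two inputs -> addition -> output.
    INetworkPtr network = INetwork::Create();
    IConnectableLayer* input0   = network->AddInputLayer(0);
    IConnectableLayer* input1   = network->AddInputLayer(1);
    IConnectableLayer* addition = network->AddAdditionLayer();
    IConnectableLayer* output   = network->AddOutputLayer(2);

    input0->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
    input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
    addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    const TensorInfo info(TensorShape({1U, 2U, 3U}), DataType::Float32);
    input0->GetOutputSlot(0).SetTensorInfo(info);
    input1->GetOutputSlot(0).SetTensorInfo(info);
    addition->GetOutputSlot(0).SetTensorInfo(info);

    // With preserveType == true, the QuantizerVisitor inserts a Quantize layer
    // after each float input and a Dequantize layer before each float output,
    // so the exported network still consumes and produces Float32 tensors
    // while its internal layers run in QuantisedAsymm8.
    const QuantizerOptions options(DataType::QuantisedAsymm8, true);
    INetworkPtr quantized = INetworkQuantizer::Create(network.get(), options)->ExportNetwork();

    return 0;
}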