aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-05-09 10:13:20 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-05-10 08:30:55 +0000
commit724e48013142562b7f09c9c819f57c314c4ee3d4 (patch)
tree935b425961ccb13a488708312bcc70b3b32fc87e
parent5fa83938592db420914903235daf3f1d5c97d6bc (diff)
downloadarmnn-724e48013142562b7f09c9c819f57c314c4ee3d4.tar.gz
IVGCVSW-3061 Modify NetworkQuantizer to support option to preserve input/output types
* Also add unit tests for new preserve type option Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com> Change-Id: I860759072f2e3546698118d1bcd5e79eb4e805ec
-rw-r--r--include/armnnQuantizer/INetworkQuantizer.hpp10
-rw-r--r--src/armnn/NetworkQuantizer.cpp2
-rw-r--r--src/armnn/QuantizerVisitor.cpp41
-rw-r--r--src/armnn/QuantizerVisitor.hpp7
-rw-r--r--src/armnn/test/QuantizerTest.cpp112
5 files changed, 156 insertions, 16 deletions
diff --git a/include/armnnQuantizer/INetworkQuantizer.hpp b/include/armnnQuantizer/INetworkQuantizer.hpp
index 89548d1057..826b077f6e 100644
--- a/include/armnnQuantizer/INetworkQuantizer.hpp
+++ b/include/armnnQuantizer/INetworkQuantizer.hpp
@@ -14,10 +14,16 @@ namespace armnn
struct QuantizerOptions
{
- QuantizerOptions() : m_ActivationFormat(DataType::QuantisedAsymm8) {}
- QuantizerOptions(DataType activationFormat) : m_ActivationFormat(activationFormat) {}
+ QuantizerOptions() : QuantizerOptions(DataType::QuantisedAsymm8, false) {}
+
+ QuantizerOptions(DataType activationFormat) : QuantizerOptions(activationFormat, false) {}
+
+ QuantizerOptions(DataType activationFormat, bool preserveType)
+ : m_ActivationFormat(activationFormat)
+ , m_PreserveType(preserveType) {}
DataType m_ActivationFormat;
+ bool m_PreserveType;
};
using INetworkQuantizerPtr = std::unique_ptr<class INetworkQuantizer, void(*)(INetworkQuantizer* quantizer)>;
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp
index 12e459d276..f308d54d49 100644
--- a/src/armnn/NetworkQuantizer.cpp
+++ b/src/armnn/NetworkQuantizer.cpp
@@ -171,7 +171,7 @@ INetworkPtr NetworkQuantizer::ExportNetwork()
throw InvalidArgumentException("Unsupported quantization target");
}
- QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get());
+ QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get(), m_Options.m_PreserveType);
VisitLayers(graph, quantizerVisitor);
// clear the ranges
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index 919eda1c7e..61e0e60779 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -11,10 +11,13 @@
namespace armnn
{
-QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme)
+QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker,
+ const IQuantizationScheme* quantizationScheme,
+ bool preserveType)
: m_Ranges(rangeTracker)
, m_QuantizedNetwork(INetwork::Create())
, m_QuantizationScheme(quantizationScheme)
+ , m_PreserveType(preserveType)
{
}
@@ -106,15 +109,41 @@ void QuantizerVisitor::VisitFullyConnectedLayer(const IConnectableLayer *layer,
void QuantizerVisitor::VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name)
{
- IConnectableLayer* newLayer = m_QuantizedNetwork->AddInputLayer(id, name);
- RecordLayer(layer, newLayer);
+ const DataType dataType = layer->GetOutputSlot(0).GetTensorInfo().GetDataType();
+ IConnectableLayer* inputLayer = m_QuantizedNetwork->AddInputLayer(id, name);
+
+ if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+ {
+ IConnectableLayer* quantizeLayer = m_QuantizedNetwork->AddQuantizeLayer();
+ inputLayer->GetOutputSlot(0).Connect(quantizeLayer->GetInputSlot(0));
+ inputLayer->GetOutputSlot(0).SetTensorInfo(layer->GetOutputSlot(0).GetTensorInfo());
+ RecordLayer(layer, quantizeLayer);
+ }
+ else
+ {
+ RecordLayer(layer, inputLayer);
+ }
}
void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
{
- IConnectableLayer* newLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
- RecordLayer(layer, newLayer);
- SetQuantizedInputConnections(layer, newLayer);
+ const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const DataType& dataType = info.GetDataType();
+ IConnectableLayer* outputLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
+
+ if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+ {
+ IConnectableLayer* dequantizeLayer = m_QuantizedNetwork->AddDequantizeLayer();
+ RecordLayer(layer, dequantizeLayer);
+ SetQuantizedInputConnections(layer, dequantizeLayer);
+ dequantizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+ dequantizeLayer->GetOutputSlot(0).SetTensorInfo(info);
+ }
+ else
+ {
+ RecordLayer(layer, outputLayer);
+ SetQuantizedInputConnections(layer, outputLayer);
+ }
}
void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer* layer,
diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp
index eb9ebac3d9..300ac164de 100644
--- a/src/armnn/QuantizerVisitor.hpp
+++ b/src/armnn/QuantizerVisitor.hpp
@@ -25,7 +25,10 @@ class StaticRangeVisitor;
class QuantizerVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>
{
public:
- QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme);
+ QuantizerVisitor(const RangeTracker& rangeTracker,
+ const IQuantizationScheme* quantizationScheme,
+ bool preserveType = false);
+
~QuantizerVisitor() = default;
/// Functions to quantize the individual layers, overridden from ILayerVisitor
@@ -132,6 +135,8 @@ private:
std::unordered_map<LayerGuid, IConnectableLayer*> m_QuantizedGuidToLayerMap;
const IQuantizationScheme* m_QuantizationScheme;
+
+ const bool m_PreserveType;
};
} //namespace armnn
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 259e90fcca..2103de062c 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -38,15 +38,15 @@ class TestQuantization : public LayerVisitorBase<VisitorThrowingPolicy>
public:
TestQuantization(const TensorShape& inputShape, const TensorShape& outputShape)
: LayerVisitorBase<VisitorThrowingPolicy>()
- , m_QuantizerOptions(QuantizerOptions())
, m_InputShape(inputShape)
- , m_OutputShape(outputShape) {}
+ , m_OutputShape(outputShape)
+ , m_QuantizerOptions(QuantizerOptions()) {}
TestQuantization(const QuantizerOptions& options, const TensorShape& inputShape, const TensorShape& outputShape)
: LayerVisitorBase<VisitorThrowingPolicy>()
- , m_QuantizerOptions(options)
, m_InputShape(inputShape)
- , m_OutputShape(outputShape) {}
+ , m_OutputShape(outputShape)
+ , m_QuantizerOptions(options) {}
void VisitInputLayer(const IConnectableLayer* layer,
LayerBindingId id,
@@ -91,6 +91,9 @@ protected:
TestQuantizationParamsImpl(info, DataType::QuantisedAsymm8, params.first, params.second);
}
+ TensorShape m_InputShape;
+ TensorShape m_OutputShape;
+
private:
void TestQuantizationParamsImpl(const TensorInfo& info, DataType dataType, float scale, int32_t offset)
{
@@ -100,8 +103,6 @@ private:
}
QuantizerOptions m_QuantizerOptions;
- TensorShape m_InputShape;
- TensorShape m_OutputShape;
};
void VisitLayersTopologically(const INetwork* inputNetwork, ILayerVisitor& visitor)
@@ -1574,5 +1575,104 @@ BOOST_AUTO_TEST_CASE(QuantizeNegativeInf)
BOOST_CHECK_EQUAL(SetupQuantize(-1 * std::numeric_limits<float>::infinity())[0], 0);
}
+class TestPreserveType : public TestAdditionQuantization
+{
+public:
+ TestPreserveType(const QuantizerOptions& options,
+ const DataType& dataType,
+ const TensorShape& inputShape,
+ const TensorShape& outputShape)
+ : TestAdditionQuantization(options, inputShape, outputShape)
+ , m_DataType(dataType)
+ , m_VisitedQuantizeLayer(false)
+ , m_VisitedDequantizeLayer(false) {}
+
+ void VisitInputLayer(const IConnectableLayer* layer,
+ LayerBindingId id,
+ const char* name = nullptr) override
+ {
+ const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo();
+ BOOST_TEST(GetDataTypeName(info.GetDataType()) == GetDataTypeName(m_DataType));
+ BOOST_TEST(m_InputShape == info.GetShape());
+ }
+
+ void VisitOutputLayer(const IConnectableLayer* layer,
+ LayerBindingId id,
+ const char* name = nullptr) override
+ {
+ const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ BOOST_TEST(GetDataTypeName(info.GetDataType()) == GetDataTypeName(m_DataType));
+ BOOST_TEST(m_OutputShape == info.GetShape());
+ }
+
+ void VisitQuantizeLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override
+ {
+ m_VisitedQuantizeLayer = true;
+ }
+
+ void VisitDequantizeLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override
+ {
+ m_VisitedDequantizeLayer = true;
+ }
+
+ void CheckQuantizeDequantizeLayerVisited(bool expected)
+ {
+ if (expected)
+ {
+ BOOST_CHECK(m_VisitedQuantizeLayer);
+ BOOST_CHECK(m_VisitedDequantizeLayer);
+ }
+ else
+ {
+ BOOST_CHECK(!m_VisitedQuantizeLayer);
+ BOOST_CHECK(!m_VisitedDequantizeLayer);
+ }
+ }
+private:
+ const DataType m_DataType;
+ bool m_VisitedQuantizeLayer;
+ bool m_VisitedDequantizeLayer;
+};
+
+void PreserveTypeTestImpl(const DataType& dataType)
+{
+ INetworkPtr network = INetwork::Create();
+
+ // Add the layers
+ IConnectableLayer* input0 = network->AddInputLayer(0);
+ IConnectableLayer* input1 = network->AddInputLayer(1);
+ IConnectableLayer* addition = network->AddAdditionLayer();
+ IConnectableLayer* output = network->AddOutputLayer(2);
+
+ input0->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
+ input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
+ addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ const TensorShape shape{1U, 2U, 3U};
+ const TensorInfo info(shape, dataType);
+ input0->GetOutputSlot(0).SetTensorInfo(info);
+ input1->GetOutputSlot(0).SetTensorInfo(info);
+ addition->GetOutputSlot(0).SetTensorInfo(info);
+
+ const QuantizerOptions options(DataType::QuantisedAsymm8, true);
+ INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get(), options)->ExportNetwork();
+ TestPreserveType validatorQAsymm8(options, dataType, shape, shape);
+ VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8);
+ validatorQAsymm8.CheckQuantizeDequantizeLayerVisited(
+ dataType == DataType::Float32 || dataType == DataType::Float16);
+}
+
+BOOST_AUTO_TEST_CASE(PreserveTypeFloat32)
+{
+ PreserveTypeTestImpl(DataType::Float32);
+}
+
+BOOST_AUTO_TEST_CASE(PreserveTypeQAsymm8)
+{
+ PreserveTypeTestImpl(DataType::QuantisedAsymm8);
+}
+
BOOST_AUTO_TEST_SUITE_END()
} // namespace armnn