aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKeith Davis <keith.davis@arm.com>2020-06-04 16:34:23 +0100
committerKeithARM <keith.davis@arm.com>2020-06-17 15:37:13 +0000
commit2b6a6a4fe1b97e7321a8d19f95bd515207d49302 (patch)
tree61afb65f5f5df6fa2484c03ca04d0538542dd8dd /src
parent300ad5695e2a577d2a9292b3cd6d182aae3298a3 (diff)
downloadarmnn-2b6a6a4fe1b97e7321a8d19f95bd515207d49302.tar.gz
IVGCVSW-4909 Add Quantizer Support for FILL operator
Signed-off-by: Keith Davis <keith.davis@arm.com> Change-Id: I7ac9600b8956b4fb875f2f7efa061f8dac73d53c
Diffstat (limited to 'src')
-rw-r--r--src/armnn/QuantizerVisitor.cpp9
-rw-r--r--src/armnn/QuantizerVisitor.hpp4
-rw-r--r--src/armnn/test/QuantizerTest.cpp71
3 files changed, 84 insertions, 0 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index 9c1ac17d70..a30a29db75 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -284,6 +284,15 @@ void QuantizerVisitor::VisitElementwiseUnaryLayer(const IConnectableLayer* layer
SetQuantizedInputConnections(layer, newLayer);
}
+void QuantizerVisitor::VisitFillLayer(const IConnectableLayer* layer,
+ const FillDescriptor& desc,
+ const char* name)
+{
+ IConnectableLayer* newLayer = m_QuantizedNetwork->AddFillLayer(desc, name);
+ RecordLayer(layer, newLayer);
+ SetQuantizedInputConnections(layer, newLayer);
+}
+
void QuantizerVisitor::VisitFullyConnectedLayer(const IConnectableLayer *layer,
const FullyConnectedDescriptor& desc,
const ConstTensor& weights,
diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp
index 29500ab0c8..65bd67101e 100644
--- a/src/armnn/QuantizerVisitor.hpp
+++ b/src/armnn/QuantizerVisitor.hpp
@@ -89,6 +89,10 @@ public:
const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor,
const char* name = nullptr) override;
+ void VisitFillLayer(const IConnectableLayer* layer,
+ const FillDescriptor& desc,
+ const char* name) override;
+
void VisitFullyConnectedLayer(const IConnectableLayer *layer,
const FullyConnectedDescriptor& desc,
const ConstTensor& weights,
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 669703ca54..a3c458112d 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -1143,6 +1143,77 @@ void ValidateFullyConnectedLayer(const bool biasEnabled)
VisitLayersTopologically(quantizedNetworkQSymmS16.get(), validatorQSymmS16);
}
+BOOST_AUTO_TEST_CASE(QuantizeFill)
+{
+ class TestFillQuantization : public TestQuantization
+ {
+ public:
+ TestFillQuantization(const TensorShape& inputShape, const TensorShape& outputShape)
+ : TestQuantization(inputShape, outputShape) {}
+
+ TestFillQuantization(const QuantizerOptions& options,
+ const TensorShape& inputShape,
+ const TensorShape& outputShape)
+ : TestQuantization(options, inputShape, outputShape) {}
+
+ virtual void VisitFillLayer(const IConnectableLayer* layer,
+ const FillDescriptor& desc,
+ const char* name = nullptr)
+ {
+ IgnoreUnused(desc, name);
+ TensorInfo info = layer->GetOutputSlot(0).GetTensorInfo();
+
+ const OffsetScalePair qAsymmU8Params{ 30.0f / g_AsymmU8QuantizationBase, 128 };
+ const OffsetScalePair qAsymmS8Params { 30.0f / g_AsymmS8QuantizationBase, 0};
+ const OffsetScalePair qSymmS8Params { 15.0f / g_SymmS8QuantizationBase, 0};
+ const OffsetScalePair qSymmS16Params{ 15.0f / g_SymmS16QuantizationBase, 0 };
+
+ TestQuantizationParams(info, qAsymmU8Params, qAsymmS8Params, qSymmS8Params, qSymmS16Params);
+ }
+ };
+
+ const TensorShape tensorShape{ 1U };
+ const TensorInfo tensorInfo(tensorShape, DataType::Float32);
+
+ INetworkPtr network = INetwork::Create();
+
+ FillDescriptor descriptor;
+ descriptor.m_Value = 1;
+
+ IConnectableLayer* inputLayer = network->AddInputLayer(0);
+ IConnectableLayer* fillLayer = network->AddFillLayer(descriptor);
+ IConnectableLayer* outputLayer = network->AddOutputLayer(0);
+
+ inputLayer->GetOutputSlot(0).Connect(fillLayer->GetInputSlot(0));
+ fillLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+ fillLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+
+ // test QAsymmU8 quantization
+ INetworkPtr quantizedNetworkQAsymmU8 = INetworkQuantizer::Create(network.get())->ExportNetwork();
+ TestFillQuantization validatorQAsymmU8(tensorShape, tensorShape);
+ VisitLayersTopologically(quantizedNetworkQAsymmU8.get(), validatorQAsymmU8);
+
+ // test QAsymmS8 quantization
+ const QuantizerOptions qAsymmS8Options(DataType::QAsymmS8);
+ INetworkPtr quantizedNetworkQAsymmS8 = INetworkQuantizer::Create(network.get(), qAsymmS8Options)->ExportNetwork();
+ TestFillQuantization validatorQAsymmS8(qAsymmS8Options, tensorShape, tensorShape);
+ VisitLayersTopologically(quantizedNetworkQAsymmS8.get(), validatorQAsymmS8);
+
+ // test QSymmS8 quantization
+ const QuantizerOptions qSymmS8Options(DataType::QSymmS8);
+ INetworkPtr quantizedNetworkQSymmS8 = INetworkQuantizer::Create(network.get(), qSymmS8Options)->ExportNetwork();
+ TestFillQuantization validatorQSymmS8(qSymmS8Options, tensorShape, tensorShape);
+ VisitLayersTopologically(quantizedNetworkQSymmS8.get(), validatorQSymmS8);
+
+ // test QuantisedSymmS16 quantization
+ const QuantizerOptions qSymmS16options(DataType::QSymmS16);
+ INetworkPtr quantizedNetworkQSymmS16 = INetworkQuantizer::Create(network.get(), qSymmS16options)->ExportNetwork();
+ TestFillQuantization validatorQSymmS16(qSymmS16options, tensorShape, tensorShape);
+ VisitLayersTopologically(quantizedNetworkQSymmS16.get(), validatorQSymmS16);
+}
+
BOOST_AUTO_TEST_CASE(QuantizeFullyConnected)
{
ValidateFullyConnectedLayer(false);