aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorFinnWilliamsArm <Finn.Williams@arm.com>2019-09-06 10:04:08 +0100
committerfinn.williams <finn.williams@arm.com>2019-09-06 13:57:36 +0000
commit9e0deb76fc25be6a0e898f53a115af3aeed9e5b8 (patch)
tree9d7c0157c0c1e6d8c17336c6c8f8c7c64c93aa4d /src
parent366023fcb63d0ade5bd742ede2fc27899296877c (diff)
downloadarmnn-9e0deb76fc25be6a0e898f53a115af3aeed9e5b8.tar.gz
IVGCVSW-3742 Add Quantizer support for ABS
Signed-off-by: FinnWilliamsArm <Finn.Williams@arm.com> Change-Id: I3f8e716ae432737bb314e26792d18aa518aa1952
Diffstat (limited to 'src')
-rw-r--r--src/armnn/QuantizerVisitor.cpp7
-rw-r--r--src/armnn/QuantizerVisitor.hpp2
-rw-r--r--src/armnn/test/QuantizerTest.cpp49
3 files changed, 58 insertions, 0 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index cf4164ee7d..9b13056da8 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -113,6 +113,13 @@ void QuantizerVisitor::RecordLayer(const IConnectableLayer* srcLayer, IConnectab
m_QuantizedGuidToLayerMap[quantizedLayer->GetGuid()] = quantizedLayer;
}
+void QuantizerVisitor::VisitAbsLayer(const IConnectableLayer* layer, const char* name)
+{
+ IConnectableLayer* newLayer = m_QuantizedNetwork->AddAbsLayer(name);
+ RecordLayer(layer, newLayer);
+ SetQuantizedInputConnections(layer, newLayer);
+}
+
void QuantizerVisitor::VisitActivationLayer(const IConnectableLayer* layer,
const ActivationDescriptor& activationDescriptor,
const char* name)
diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp
index 7480a0c0c6..e087fe6a61 100644
--- a/src/armnn/QuantizerVisitor.hpp
+++ b/src/armnn/QuantizerVisitor.hpp
@@ -32,6 +32,8 @@ public:
~QuantizerVisitor() = default;
/// Functions to quantize the individual layers, overridden from ILayerVisitor
+ void VisitAbsLayer(const IConnectableLayer* layer, const char* name = nullptr) override;
+
void VisitActivationLayer(const IConnectableLayer* layer,
const ActivationDescriptor& activationDescriptor,
const char* name = nullptr) override;
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index ff1cfc4020..7a5d27bd52 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -182,6 +182,7 @@ public:
}
};
+
BOOST_AUTO_TEST_CASE(QuantizeAddition)
{
INetworkPtr network = INetwork::Create();
@@ -1299,6 +1300,54 @@ BOOST_AUTO_TEST_CASE(QuantizeConstant)
VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16);
}
+BOOST_AUTO_TEST_CASE(QuantizeAbs)
+{
+ class TestAbsQuantization : public TestLeakyReLuActivationQuantization
+ {
+ public:
+ TestAbsQuantization(const TensorShape& inputShape, const TensorShape& outputShape) :
+ TestLeakyReLuActivationQuantization(inputShape, outputShape)
+ {}
+
+ TestAbsQuantization(const QuantizerOptions& options,
+ const TensorShape& inputShape,
+ const TensorShape& outputShape) :
+ TestLeakyReLuActivationQuantization(options, inputShape, outputShape)
+ {}
+
+ void VisitAbsLayer(const IConnectableLayer *layer,
+ const char *name = nullptr) override
+ {
+ TensorInfo outputInfo = layer->GetOutputSlot(0).GetTensorInfo();
+
+ TestQuantizationParams(outputInfo,
+ { 30.0f / g_Asymm8QuantizationBase, 128 },
+ { 15.0f / g_Symm16QuantizationBase, 0 });
+ }
+ };
+
+ INetworkPtr network = INetwork::Create();
+
+ //Add the layer being tested
+ IConnectableLayer* absLayer = network->AddAbsLayer();
+
+ const TensorShape shape{1U};
+ TensorInfo info(shape, DataType::Float32);
+
+ IConnectableLayer* activation = CreateStartOfLeakyReluNetwork(network.get(), info);
+
+ CompleteLeakyReluNetwork(network.get(), activation, absLayer, info);
+
+ INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get())->ExportNetwork();
+ TestAbsQuantization validatorQAsymm8(shape, shape);
+ VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8);
+
+ const QuantizerOptions options(DataType::QuantisedSymm16);
+ INetworkPtr quantizedNetworkQSymm16 = INetworkQuantizer::Create(network.get(), options)->ExportNetwork();
+ TestAbsQuantization validatorQSymm16(options, shape, shape);
+ VisitLayersTopologically(quantizedNetworkQSymm16.get(), validatorQSymm16);
+}
+
BOOST_AUTO_TEST_CASE(QuantizeConcat)
{
class TestConcatQuantization : public TestQuantization