Diffstat (limited to 'src/armnn/test/QuantizerTest.cpp')
-rw-r--r--  src/armnn/test/QuantizerTest.cpp  112
1 file changed, 106 insertions(+), 6 deletions(-)
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 259e90fcca..2103de062c 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -38,15 +38,15 @@ class TestQuantization : public LayerVisitorBase<VisitorThrowingPolicy>
public:
TestQuantization(const TensorShape& inputShape, const TensorShape& outputShape)
: LayerVisitorBase<VisitorThrowingPolicy>()
- , m_QuantizerOptions(QuantizerOptions())
, m_InputShape(inputShape)
- , m_OutputShape(outputShape) {}
+ , m_OutputShape(outputShape)
+ , m_QuantizerOptions(QuantizerOptions()) {}
TestQuantization(const QuantizerOptions& options, const TensorShape& inputShape, const TensorShape& outputShape)
: LayerVisitorBase<VisitorThrowingPolicy>()
- , m_QuantizerOptions(options)
, m_InputShape(inputShape)
- , m_OutputShape(outputShape) {}
+ , m_OutputShape(outputShape)
+ , m_QuantizerOptions(options) {}
void VisitInputLayer(const IConnectableLayer* layer,
LayerBindingId id,
@@ -91,6 +91,9 @@ protected:
TestQuantizationParamsImpl(info, DataType::QuantisedAsymm8, params.first, params.second);
}
+ TensorShape m_InputShape;
+ TensorShape m_OutputShape;
+
private:
void TestQuantizationParamsImpl(const TensorInfo& info, DataType dataType, float scale, int32_t offset)
{
@@ -100,8 +103,6 @@ private:
}
QuantizerOptions m_QuantizerOptions;
- TensorShape m_InputShape;
- TensorShape m_OutputShape;
};
void VisitLayersTopologically(const INetwork* inputNetwork, ILayerVisitor& visitor)
@@ -1574,5 +1575,104 @@ BOOST_AUTO_TEST_CASE(QuantizeNegativeInf)
BOOST_CHECK_EQUAL(SetupQuantize(-1 * std::numeric_limits<float>::infinity())[0], 0);
}
+class TestPreserveType : public TestAdditionQuantization
+{
+public:
+ TestPreserveType(const QuantizerOptions& options,
+ const DataType& dataType,
+ const TensorShape& inputShape,
+ const TensorShape& outputShape)
+ : TestAdditionQuantization(options, inputShape, outputShape)
+ , m_DataType(dataType)
+ , m_VisitedQuantizeLayer(false)
+ , m_VisitedDequantizeLayer(false) {}
+
+ void VisitInputLayer(const IConnectableLayer* layer,
+ LayerBindingId id,
+ const char* name = nullptr) override
+ {
+ const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo();
+ BOOST_TEST(GetDataTypeName(info.GetDataType()) == GetDataTypeName(m_DataType));
+ BOOST_TEST(m_InputShape == info.GetShape());
+ }
+
+ void VisitOutputLayer(const IConnectableLayer* layer,
+ LayerBindingId id,
+ const char* name = nullptr) override
+ {
+ const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ BOOST_TEST(GetDataTypeName(info.GetDataType()) == GetDataTypeName(m_DataType));
+ BOOST_TEST(m_OutputShape == info.GetShape());
+ }
+
+ void VisitQuantizeLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override
+ {
+ m_VisitedQuantizeLayer = true;
+ }
+
+ void VisitDequantizeLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) override
+ {
+ m_VisitedDequantizeLayer = true;
+ }
+
+ void CheckQuantizeDequantizeLayerVisited(bool expected)
+ {
+ if (expected)
+ {
+ BOOST_CHECK(m_VisitedQuantizeLayer);
+ BOOST_CHECK(m_VisitedDequantizeLayer);
+ }
+ else
+ {
+ BOOST_CHECK(!m_VisitedQuantizeLayer);
+ BOOST_CHECK(!m_VisitedDequantizeLayer);
+ }
+ }
+private:
+ const DataType m_DataType;
+ bool m_VisitedQuantizeLayer;
+ bool m_VisitedDequantizeLayer;
+};
+
+void PreserveTypeTestImpl(const DataType& dataType)
+{
+ INetworkPtr network = INetwork::Create();
+
+ // Add the layers
+ IConnectableLayer* input0 = network->AddInputLayer(0);
+ IConnectableLayer* input1 = network->AddInputLayer(1);
+ IConnectableLayer* addition = network->AddAdditionLayer();
+ IConnectableLayer* output = network->AddOutputLayer(2);
+
+ input0->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
+ input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
+ addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ const TensorShape shape{1U, 2U, 3U};
+ const TensorInfo info(shape, dataType);
+ input0->GetOutputSlot(0).SetTensorInfo(info);
+ input1->GetOutputSlot(0).SetTensorInfo(info);
+ addition->GetOutputSlot(0).SetTensorInfo(info);
+
+ const QuantizerOptions options(DataType::QuantisedAsymm8, true);
+ INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get(), options)->ExportNetwork();
+ TestPreserveType validatorQAsymm8(options, dataType, shape, shape);
+ VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8);
+ validatorQAsymm8.CheckQuantizeDequantizeLayerVisited(
+ dataType == DataType::Float32 || dataType == DataType::Float16);
+}
+
+BOOST_AUTO_TEST_CASE(PreserveTypeFloat32)
+{
+ PreserveTypeTestImpl(DataType::Float32);
+}
+
+BOOST_AUTO_TEST_CASE(PreserveTypeQAsymm8)
+{
+ PreserveTypeTestImpl(DataType::QuantisedAsymm8);
+}
+
BOOST_AUTO_TEST_SUITE_END()
} // namespace armnn
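
For readers unfamiliar with the option being exercised above: the second constructor argument in QuantizerOptions(DataType::QuantisedAsymm8, true) is the preserveType flag. What follows is a minimal standalone sketch of the same flow the new test drives, assembled only from calls that appear in this diff; the include paths and the main() scaffolding are assumptions and may differ between ArmNN releases.

#include <armnn/INetwork.hpp>
#include <armnn/INetworkQuantizer.hpp> // assumed header location; may vary per release

using namespace armnn;

int main()
{
    // Build the same two-input addition graph PreserveTypeTestImpl constructs.
    INetworkPtr network = INetwork::Create();

    IConnectableLayer* input0   = network->AddInputLayer(0);
    IConnectableLayer* input1   = network->AddInputLayer(1);
    IConnectableLayer* addition = network->AddAdditionLayer();
    IConnectableLayer* output   = network->AddOutputLayer(2);

    input0->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
    input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
    addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    const TensorShape shape{1U, 2U, 3U};
    const TensorInfo info(shape, DataType::Float32);
    input0->GetOutputSlot(0).SetTensorInfo(info);
    input1->GetOutputSlot(0).SetTensorInfo(info);
    addition->GetOutputSlot(0).SetTensorInfo(info);

    // preserveType = true: the exported network keeps the original Float32
    // type at its input and output boundaries, so Quantize/Dequantize layers
    // are inserted around the QAsymm8 body. This is exactly what the test's
    // CheckQuantizeDequantizeLayerVisited(true) asserts for Float32 networks.
    const QuantizerOptions options(DataType::QuantisedAsymm8, /*preserveType=*/true);
    INetworkPtr quantized = INetworkQuantizer::Create(network.get(), options)->ExportNetwork();

    return 0;
}

As the PreserveTypeQAsymm8 case shows, when the source network is already QuantisedAsymm8 no Quantize/Dequantize pair is expected, which is why the test passes dataType == Float32 || dataType == Float16 to the check.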