From 5324782821dce525cf7c1636c659f998fae2fb85 Mon Sep 17 00:00:00 2001 From: Ferran Balaguer Date: Mon, 18 Feb 2019 12:47:35 +0000 Subject: IVGCVSW-2613 Support static quantization of BatchToSpace Change-Id: I44b12c5c246b7aacc789420dbe55a16efaab6f98 --- src/armnn/test/QuantizerTest.cpp | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) (limited to 'src/armnn/test') diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp index 3edea9cc6e..548203a6a9 100644 --- a/src/armnn/test/QuantizerTest.cpp +++ b/src/armnn/test/QuantizerTest.cpp @@ -1176,6 +1176,7 @@ BOOST_AUTO_TEST_CASE(QuantizeStridedSlice) // Add the layer under test StridedSliceDescriptor stridedSliceDesc; IConnectableLayer* stridedSlice = network->AddStridedSliceLayer(stridedSliceDesc); + CompleteLeakyReluNetwork(network.get(), activation, stridedSlice, info); auto quantizedNetwork = INetworkQuantizer::Create(network.get())->ExportNetwork(); @@ -1183,5 +1184,36 @@ BOOST_AUTO_TEST_CASE(QuantizeStridedSlice) VisitLayersTopologically(quantizedNetwork.get(), validator); } +BOOST_AUTO_TEST_CASE(QuantizeBatchToSpace) +{ + class TestBatchToSpaceQuantization : public TestLeakyReLuActivationQuantization + { + public: + void VisitBatchToSpaceNdLayer(const IConnectableLayer* layer, + const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor, + const char* name = nullptr) override + { + CheckForwardedQuantizationSettings(layer); + } + }; + + INetworkPtr network = INetwork::Create(); + + TensorShape shape{1U}; + TensorInfo info(shape, DataType::Float32); + + IConnectableLayer* activation = CreateStartOfLeakyReluNetwork(network.get(), info); + + // Add the layer under test + BatchToSpaceNdDescriptor descriptor; + IConnectableLayer* batchToSpace = network->AddBatchToSpaceNdLayer(descriptor); + + CompleteLeakyReluNetwork(network.get(), activation, batchToSpace, info); + + auto quantizedNetwork = INetworkQuantizer::Create(network.get())->ExportNetwork(); + TestBatchToSpaceQuantization validator; + VisitLayersTopologically(quantizedNetwork.get(), validator); +} + BOOST_AUTO_TEST_SUITE_END() } // namespace armnn -- cgit v1.2.1