aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-04 13:04:16 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-24 14:56:23 +0000
commit3d13af8a39f408318328a95d5329bc17fd923438 (patch)
treeb0d9c82062e229f8938d2c9f762ee67758196bf3 /tests
parentdb09b3783ff9af67c6d373b12aa9a6aff3c5d0f1 (diff)
downloadComputeLibrary-3d13af8a39f408318328a95d5329bc17fd923438.tar.gz
COMPMID-2235: Extend type support for CL/NEON DequantizationLayer.
Adds support for: - QSYMM8 Change-Id: Ia0b839fc844ce0f968dad1b69a001f9a660dbcd5 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/1378 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/AssetsLibrary.h3
-rw-r--r--tests/datasets/DatatypeDataset.h53
-rw-r--r--tests/validation/CL/DequantizationLayer.cpp21
-rw-r--r--tests/validation/CL/UNIT/TensorAllocator.cpp6
-rw-r--r--tests/validation/NEON/DequantizationLayer.cpp21
-rw-r--r--tests/validation/UNIT/TensorInfo.cpp20
-rw-r--r--tests/validation/fixtures/DequantizationLayerFixture.h55
-rw-r--r--tests/validation/reference/DequantizationLayer.cpp74
-rw-r--r--tests/validation/reference/DequantizationLayer.h4
9 files changed, 196 insertions, 61 deletions
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h
index 5c8019bdff..2f2665f381 100644
--- a/tests/AssetsLibrary.h
+++ b/tests/AssetsLibrary.h
@@ -634,6 +634,7 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t
break;
}
case DataType::S8:
+ case DataType::QSYMM8:
{
std::uniform_int_distribution<int8_t> distribution_s8(std::numeric_limits<int8_t>::lowest(), std::numeric_limits<int8_t>::max());
fill(tensor, distribution_s8, seed_offset);
@@ -728,6 +729,7 @@ void AssetsLibrary::fill_tensor_uniform_ranged(T
break;
}
case DataType::S8:
+ case DataType::QSYMM8:
{
const auto converted_pairs = detail::convert_range_pair<int8_t>(excluded_range_pairs);
RangedUniformDistribution<int8_t> distribution_s8(std::numeric_limits<int8_t>::lowest(),
@@ -808,6 +810,7 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t
break;
}
case DataType::S8:
+ case DataType::QSYMM8:
{
ARM_COMPUTE_ERROR_ON(!(std::is_same<int8_t, D>::value));
std::uniform_int_distribution<int8_t> distribution_s8(low, high);
diff --git a/tests/datasets/DatatypeDataset.h b/tests/datasets/DatatypeDataset.h
new file mode 100644
index 0000000000..bb2774b4b3
--- /dev/null
+++ b/tests/datasets/DatatypeDataset.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TEST_DATATYPE_DATASET_H__
+#define __ARM_COMPUTE_TEST_DATATYPE_DATASET_H__
+
+#include "arm_compute/core/Types.h"
+#include "tests/framework/datasets/ContainerDataset.h"
+
+#include <vector>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace datasets
+{
+class QuantizedTypes final : public framework::dataset::ContainerDataset<std::vector<DataType>>
+{
+public:
+ QuantizedTypes()
+ : ContainerDataset("QuantizedTypes",
+ {
+ DataType::QSYMM8,
+ DataType::QASYMM8,
+ })
+ {
+ }
+};
+} // namespace datasets
+} // namespace test
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_TEST_DATATYPE_DATASET_H__ */
diff --git a/tests/validation/CL/DequantizationLayer.cpp b/tests/validation/CL/DequantizationLayer.cpp
index b1b0d81c6d..2ef8c60998 100644
--- a/tests/validation/CL/DequantizationLayer.cpp
+++ b/tests/validation/CL/DequantizationLayer.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/CL/functions/CLDequantizationLayer.h"
#include "tests/CL/CLAccessor.h"
#include "tests/PaddingCalculator.h"
+#include "tests/datasets/DatatypeDataset.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
@@ -96,16 +97,14 @@ template <typename T>
using CLDequantizationLayerFixture = DequantizationValidationFixture<CLTensor, CLAccessor, CLDequantizationLayer, T>;
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLDequantizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLDequantizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -113,16 +112,14 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture<half>, framework::
TEST_SUITE_END() // FP16
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLDequantizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLDequantizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLDequantizationLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(CLAccessor(_target), _reference);
diff --git a/tests/validation/CL/UNIT/TensorAllocator.cpp b/tests/validation/CL/UNIT/TensorAllocator.cpp
index 4b8e105240..d91f4dd022 100644
--- a/tests/validation/CL/UNIT/TensorAllocator.cpp
+++ b/tests/validation/CL/UNIT/TensorAllocator.cpp
@@ -249,9 +249,9 @@ TEST_CASE(Symm8PerChannelQuantizationInfo, framework::DatasetMode::ALL)
// Check quantization information
ARM_COMPUTE_EXPECT(!tensor.info()->quantization_info().empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!tensor.info()->quantization_info().scale.empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(tensor.info()->quantization_info().scale.size() == scale.size(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(tensor.info()->quantization_info().offset.empty(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!tensor.info()->quantization_info().scale().empty(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(tensor.info()->quantization_info().scale().size() == scale.size(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(tensor.info()->quantization_info().offset().empty(), framework::LogLevel::ERRORS);
CLQuantization quantization = tensor.quantization();
ARM_COMPUTE_ASSERT(quantization.scale != nullptr);
diff --git a/tests/validation/NEON/DequantizationLayer.cpp b/tests/validation/NEON/DequantizationLayer.cpp
index 0ae20b7b5d..a4606fe8a0 100644
--- a/tests/validation/NEON/DequantizationLayer.cpp
+++ b/tests/validation/NEON/DequantizationLayer.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/TensorAllocator.h"
#include "tests/NEON/Accessor.h"
#include "tests/PaddingCalculator.h"
+#include "tests/datasets/DatatypeDataset.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
@@ -106,16 +107,14 @@ using NEDequantizationLayerFixture = DequantizationValidationFixture<Tensor, Acc
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -124,16 +123,14 @@ TEST_SUITE_END() // FP16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDequantizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), datasets::QuantizedTypes()),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEDequantizationLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), datasets::QuantizedTypes()),
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference);
diff --git a/tests/validation/UNIT/TensorInfo.cpp b/tests/validation/UNIT/TensorInfo.cpp
index 96d07da2b4..009c757925 100644
--- a/tests/validation/UNIT/TensorInfo.cpp
+++ b/tests/validation/UNIT/TensorInfo.cpp
@@ -141,9 +141,9 @@ TEST_CASE(SymmQuantizationInfo, framework::DatasetMode::ALL)
// Check quantization information
ARM_COMPUTE_EXPECT(!info.quantization_info().empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!info.quantization_info().scale.empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(info.quantization_info().scale.size() == 1, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(info.quantization_info().offset.empty(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!info.quantization_info().scale().empty(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.quantization_info().scale().size() == 1, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.quantization_info().offset().empty(), framework::LogLevel::ERRORS);
UniformQuantizationInfo qinfo = info.quantization_info().uniform();
ARM_COMPUTE_EXPECT(qinfo.scale == scale, framework::LogLevel::ERRORS);
@@ -160,10 +160,10 @@ TEST_CASE(AsymmQuantizationInfo, framework::DatasetMode::ALL)
// Check quantization information
ARM_COMPUTE_EXPECT(!info.quantization_info().empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!info.quantization_info().scale.empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(info.quantization_info().scale.size() == 1, framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!info.quantization_info().offset.empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(info.quantization_info().offset.size() == 1, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!info.quantization_info().scale().empty(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.quantization_info().scale().size() == 1, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!info.quantization_info().offset().empty(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.quantization_info().offset().size() == 1, framework::LogLevel::ERRORS);
UniformQuantizationInfo qinfo = info.quantization_info().uniform();
ARM_COMPUTE_EXPECT(qinfo.scale == scale, framework::LogLevel::ERRORS);
@@ -179,9 +179,9 @@ TEST_CASE(SymmPerChannelQuantizationInfo, framework::DatasetMode::ALL)
// Check quantization information
ARM_COMPUTE_EXPECT(!info.quantization_info().empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!info.quantization_info().scale.empty(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(info.quantization_info().scale.size() == scale.size(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(info.quantization_info().offset.empty(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!info.quantization_info().scale().empty(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.quantization_info().scale().size() == scale.size(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(info.quantization_info().offset().empty(), framework::LogLevel::ERRORS);
}
TEST_SUITE_END() // TensorInfoValidation
diff --git a/tests/validation/fixtures/DequantizationLayerFixture.h b/tests/validation/fixtures/DequantizationLayerFixture.h
index 2e3712dff2..15f3711189 100644
--- a/tests/validation/fixtures/DequantizationLayerFixture.h
+++ b/tests/validation/fixtures/DequantizationLayerFixture.h
@@ -47,10 +47,11 @@ class DequantizationValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type, QuantizationInfo qinfo)
+ void setup(TensorShape shape, DataType src_data_type, DataType dst_datatype)
{
- _target = compute_target(shape, data_type, qinfo);
- _reference = compute_reference(shape, data_type, qinfo);
+ _quantization_info = generate_quantization_info(src_data_type);
+ _target = compute_target(shape, src_data_type, dst_datatype);
+ _reference = compute_reference(shape, src_data_type);
}
protected:
@@ -60,11 +61,11 @@ protected:
library->fill_tensor_uniform(tensor, 0);
}
- TensorType compute_target(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo)
+ TensorType compute_target(const TensorShape &shape, DataType src_data_type, DataType dst_datatype)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, DataType::QASYMM8, 1, qinfo);
- TensorType dst = create_tensor<TensorType>(shape, data_type);
+ TensorType src = create_tensor<TensorType>(shape, src_data_type, 1, _quantization_info);
+ TensorType dst = create_tensor<TensorType>(shape, dst_datatype);
// Create and configure function
FunctionType dequantization_layer;
@@ -89,19 +90,43 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo qinfo)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, DataType src_data_type)
{
- // Create reference
- SimpleTensor<uint8_t> src{ shape, DataType::QASYMM8, 1, qinfo };
-
- // Fill reference
- fill(src);
+ if(is_data_type_quantized_asymmetric(src_data_type))
+ {
+ SimpleTensor<uint8_t> src{ shape, src_data_type, 1, _quantization_info };
+ fill(src);
+ return reference::dequantization_layer<T>(src);
+ }
+ else
+ {
+ SimpleTensor<int8_t> src{ shape, src_data_type, 1, _quantization_info };
+ fill(src);
+ return reference::dequantization_layer<T>(src);
+ }
+ }
- return reference::dequantization_layer<T>(src);
+protected:
+ QuantizationInfo generate_quantization_info(DataType data_type)
+ {
+ std::uniform_int_distribution<> distribution(1, 127);
+ std::mt19937 gen(library.get()->seed());
+
+ switch(data_type)
+ {
+ case DataType::QSYMM8:
+ return QuantizationInfo(1.f / distribution(gen));
+ case DataType::QASYMM8:
+ return QuantizationInfo(1.f / distribution(gen), distribution(gen));
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
}
- TensorType _target{};
- SimpleTensor<T> _reference{};
+protected:
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+ QuantizationInfo _quantization_info{};
};
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp
index 286a609d79..d07371c883 100644
--- a/tests/validation/reference/DequantizationLayer.cpp
+++ b/tests/validation/reference/DequantizationLayer.cpp
@@ -23,6 +23,8 @@
*/
#include "DequantizationLayer.h"
+#include "Permute.h"
+
namespace arm_compute
{
namespace test
@@ -31,24 +33,82 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> dequantization_layer(const SimpleTensor<uint8_t> &src)
+namespace
+{
+template <typename TOut>
+TOut dequantize(int8_t val, const UniformQuantizationInfo qinfo)
+{
+ return static_cast<TOut>(dequantize_qsymm8(val, qinfo));
+}
+template <typename TOut>
+TOut dequantize(uint8_t val, const UniformQuantizationInfo qinfo)
+{
+ return static_cast<TOut>(dequantize_qasymm8(val, qinfo));
+}
+
+template <typename TOut, typename TIn>
+SimpleTensor<TOut> dequantization_layer_nchw(const SimpleTensor<TIn> &src)
{
- const DataType dst_data_type = std::is_same<T, float>::value ? DataType::F32 : DataType::F16;
- const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
+ const DataType src_data_type = src.data_type();
+ const DataType dst_data_type = std::is_same<TOut, float>::value ? DataType::F32 : DataType::F16;
- SimpleTensor<T> dst{ src.shape(), dst_data_type };
+ SimpleTensor<TOut> dst{ src.shape(), dst_data_type };
- for(int i = 0; i < src.num_elements(); ++i)
+ if(src_data_type == DataType::QSYMM8_PER_CHANNEL)
{
- dst[i] = static_cast<T>(dequantize_qasymm8(src[i], quantization_info));
+ const int WH = src.shape().x() * src.shape().y();
+ const int C = src.shape().z();
+ const int N = src.shape().total_size() / (WH * C);
+
+ const std::vector<float> qscales = src.quantization_info().scale();
+
+ for(int n = 0; n < N; ++n)
+ {
+ for(int c = 0; c < C; ++c)
+ {
+ const size_t idx = n * C * WH + c * WH;
+ const UniformQuantizationInfo channel_qinfo = { qscales[c], 0 };
+
+ // Dequantize slice
+ for(int s = 0; s < WH; ++s)
+ {
+ dst[idx + s] = dequantize<TOut>(src[idx + s], channel_qinfo);
+ }
+ }
+ }
+ }
+ else
+ {
+ const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
+ ARM_COMPUTE_ERROR_ON(quantization_info.offset != 0 && src_data_type == DataType::QSYMM8);
+
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+ dst[i] = static_cast<TOut>(dequantize<TOut>(src[i], quantization_info));
+ }
}
return dst;
}
+} // namespace
+template <typename TOut, typename TIn>
+SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
+{
+ if(src.data_layout() == DataLayout::NHWC && src.data_type() == DataType::QSYMM8_PER_CHANNEL)
+ {
+ SimpleTensor<TIn> src_nchw = reference::permute<TIn>(src, PermutationVector(1U, 2U, 0U));
+ return reference::permute<TOut>(dequantization_layer_nchw<TOut>(src_nchw), PermutationVector(2U, 0U, 1U));
+ }
+ else
+ {
+ return dequantization_layer_nchw<TOut>(src);
+ }
+}
template SimpleTensor<half> dequantization_layer(const SimpleTensor<uint8_t> &src);
template SimpleTensor<float> dequantization_layer(const SimpleTensor<uint8_t> &src);
+template SimpleTensor<half> dequantization_layer(const SimpleTensor<int8_t> &src);
+template SimpleTensor<float> dequantization_layer(const SimpleTensor<int8_t> &src);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/DequantizationLayer.h b/tests/validation/reference/DequantizationLayer.h
index 1d0e54b442..8c780849fd 100644
--- a/tests/validation/reference/DequantizationLayer.h
+++ b/tests/validation/reference/DequantizationLayer.h
@@ -35,8 +35,8 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> dequantization_layer(const SimpleTensor<uint8_t> &src);
+template <typename TOut, typename TIn>
+SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src);
} // namespace reference
} // namespace validation
} // namespace test