aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohn Kesapides <john.kesapides@arm.com>2019-03-04 16:29:22 +0000
committerPablo Marquez <pablo.tello@arm.com>2019-03-13 13:34:48 +0000
commitadfb2737046028c042f0aecaff87733a442da29f (patch)
tree23b08fb9529075277e51dc1ae7e6489f690c9698
parent381fcf20c3ee028e14c154ff4b75bc7410f91168 (diff)
downloadComputeLibrary-adfb2737046028c042f0aecaff87733a442da29f.tar.gz
COMPMID-1935 Add support for QASYMM8 in NEQuantizeLayer
Change-Id: I2b63a644d8e34f91c830d9ac398debcbdca3e497 Signed-off-by: John Kesapides <john.kesapides@arm.com> Reviewed-on: https://review.mlplatform.org/c/829 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/NEON/kernels/NEQuantizationLayerKernel.h24
-rw-r--r--arm_compute/runtime/NEON/functions/NEQuantizationLayer.h27
-rw-r--r--src/core/NEON/kernels/NEQuantizationLayerKernel.cpp194
-rw-r--r--src/core/Rounding.cpp8
-rw-r--r--src/runtime/NEON/functions/NEQuantizationLayer.cpp35
-rw-r--r--tests/validation/NEON/QuantizationLayer.cpp53
-rw-r--r--tests/validation/fixtures/QuantizationLayerFixture.h65
-rw-r--r--tests/validation/reference/QuantizationLayer.cpp21
-rw-r--r--tests/validation/reference/QuantizationLayer.h5
9 files changed, 245 insertions, 187 deletions
diff --git a/arm_compute/core/NEON/kernels/NEQuantizationLayerKernel.h b/arm_compute/core/NEON/kernels/NEQuantizationLayerKernel.h
index ca7658bb7e..391a72c6db 100644
--- a/arm_compute/core/NEON/kernels/NEQuantizationLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NEQuantizationLayerKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,32 +54,30 @@ public:
NEQuantizationLayerKernel &operator=(NEQuantizationLayerKernel &&) = default;
/** Default destructor */
~NEQuantizationLayerKernel() = default;
- /** Set the input, output, min and max.
+ /** Set the input, output.
*
- * @param[in] input Source tensor with at least 3 dimensions. The dimensions over the third will be interpreted as batches. Data types supported: F32.
- * @param[out] output Destination tensor with the same dimensions of input. Data types supported: U8.
- * @param[in] min_max Pointer to the tensor with shape [2, batches] which stores the minimum and maximum value for each 3D input tensor.
- * The dimensions over the second must match the batched dimensions of the input tensor. Data type supported: F32
+ * @param[in] input Source tensor. The dimensions over the third will be interpreted as batches. Data types supported: F32/F16.
+ * @param[out] output Destination tensor with the same dimensions of input. Data types supported: QASYMM8.
*/
- void configure(const ITensor *input, ITensor *output, const ITensor *min_max);
+ void configure(const ITensor *input, ITensor *output);
/** Static function to check if given info will lead to a valid configuration of @ref NEQuantizationLayerKernel
*
- * @param[in] input Input tensor info. Data types supported: F32.
- * @param[in] output Output tensor info. Data types supported: U8.
- * @param[in] min_max Info for the tensor with shape [2, batches] which stores the minimum and maximum value for each 3D input tensor.
- * The dimensions over the second must match the batched dimensions of the input tensor. Data type supported: F32.
+ * @param[in] input Input tensor info. Data types supported: F32/F16.
+ * @param[in] output Output tensor info. Data types supported: QASYMM8.
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *min_max);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output);
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
private:
+ template <typename T>
+ void quantize(const Window &window, const QuantizationInfo &qinfo);
+
const ITensor *_input;
ITensor *_output;
- const ITensor *_min_max;
};
} // namespace arm_compute
#endif /*__ARM_COMPUTE_NEQUANTIZATIONLAYERKERNEL_H__ */
diff --git a/arm_compute/runtime/NEON/functions/NEQuantizationLayer.h b/arm_compute/runtime/NEON/functions/NEQuantizationLayer.h
index 9cc1666b4c..9ca199d1ee 100644
--- a/arm_compute/runtime/NEON/functions/NEQuantizationLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEQuantizationLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,9 +26,8 @@
#include "arm_compute/runtime/IFunction.h"
-#include "arm_compute/core/NEON/kernels/NEMinMaxLayerKernel.h"
#include "arm_compute/core/NEON/kernels/NEQuantizationLayerKernel.h"
-#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/runtime/NEON/INESimpleFunctionNoBorder.h"
#include "arm_compute/core/Types.h"
@@ -38,39 +37,29 @@ class ITensor;
/** Basic function to simulate a quantization layer. This function calls the following NEON kernels:
*
- * @note The implementation supports only 3D input tensors
*
- * -# @ref NEMinMaxLayerKernel
* -# @ref NEQuantizationLayerKernel
*
*/
-class NEQuantizationLayer : public IFunction
+class NEQuantizationLayer : public INESimpleFunctionNoBorder
{
public:
/** Default constructor */
- NEQuantizationLayer();
+ NEQuantizationLayer() = default;
/** Set the input and output tensors.
*
- * @param[in] input Source tensor with at least 3 dimensions. The dimensions over the third will be interpreted as batches. Data types supported: F32
- * @param[out] output Destination tensor with the same dimensions of input. Data types supported: U8
+ * @param[in] input Source tensor. The dimensions over the third will be interpreted as batches. Data types supported: F32
+ * @param[out] output Destination tensor with the same dimensions of input. Data types supported: QASYMM8
*/
void configure(const ITensor *input, ITensor *output);
/** Static function to check if given info will lead to a valid configuration of @ref NEQuantizationLayer
*
* @param[in] input Input tensor info. The dimensions over the third will be interpreted as batches. Data types supported: F32.
- * @param[in] output Output tensor info. Data types supported: U8
+ * @param[in] output Output tensor info. Data types supported: QASYMM8
*
* @return a status
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *output);
-
- // Inherited methods overridden:
- void run() override;
-
-private:
- NEQuantizationLayerKernel _quantize_kernel;
- NEMinMaxLayerKernel _min_max_kernel;
- Tensor _min_max;
};
-}
+} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEQUANTIZATIONLAYER_H__ */
diff --git a/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp b/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
index b49400ab7d..136457c34e 100644
--- a/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,152 +23,140 @@
*/
#include "arm_compute/core/NEON/kernels/NEQuantizationLayerKernel.h"
-#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/NEON/NEAsymm.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/CPP/Validate.h"
+
#include <arm_neon.h>
using namespace arm_compute;
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *min_max)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output, min_max);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() < 3);
-
- if(output->tensor_shape().total_size() > 0)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
- }
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape().total_size() == 0);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
return Status{};
}
-std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, ITensorInfo *min_max)
+inline const float32x4x4_t load_value(const float *input_ptr)
{
- // Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, input->tensor_shape(), 1, DataType::U8);
-
- constexpr unsigned int num_elems_processed_per_iteration = 8;
-
- // Configure window
- Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
- AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
- AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
- AccessWindowStatic min_max_access(min_max, 0, 0, 2, min_max->dimension(1));
-
- // Update window and padding
- bool window_changed = update_window_and_padding(win, input_access, output_access, min_max_access);
-
- output_access.set_valid_region(win, input->valid_region());
-
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_tuple(err, win);
+ return { wrapper::vloadq(input_ptr),
+ wrapper::vloadq(input_ptr + 4),
+ wrapper::vloadq(input_ptr + 8),
+ wrapper::vloadq(input_ptr + 12) };
+}
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+inline const float32x4x4_t load_value(const float16_t *input_ptr)
+{
+ return { vcvt_f32_f16(wrapper::vload(input_ptr)),
+ vcvt_f32_f16(wrapper::vload(input_ptr + 4)),
+ vcvt_f32_f16(wrapper::vload(input_ptr + 8)),
+ vcvt_f32_f16(wrapper::vload(input_ptr + 12)) };
}
+
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
} // namespace
NEQuantizationLayerKernel::NEQuantizationLayerKernel()
- : _input(nullptr), _output(nullptr), _min_max(nullptr)
+ : _input(nullptr), _output(nullptr)
{
}
-void NEQuantizationLayerKernel::configure(const ITensor *input, ITensor *output, const ITensor *min_max)
+void NEQuantizationLayerKernel::configure(const ITensor *input, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, output, min_max);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), min_max->info()));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info()));
- _input = input;
- _output = output;
- _min_max = min_max;
+ _input = input;
+ _output = output;
// Configure kernel window
- auto win_config = validate_and_configure_window(input->info(), output->info(), min_max->info());
+ Window win_config = calculate_max_window(*input->info(), Steps());
- ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config));
+ Coordinates coord;
+ coord.set_num_dimensions(output->info()->num_dimensions());
+ output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
- INEKernel::configure(std::get<1>(win_config));
+ INEKernel::configure(win_config);
}
-Status NEQuantizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *min_max)
+Status NEQuantizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, min_max));
- ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), output->clone().get(), min_max->clone().get())));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
return Status{};
}
+template <typename T>
+void NEQuantizationLayerKernel::quantize(const Window &window, const QuantizationInfo &qinfo)
+{
+ constexpr auto window_step = 16;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+#ifdef __aarch64__
+ constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN;
+#else //__aarch64__
+ constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO;
+#endif //__aarch64__
+
+ // Collapse window and reset first dimension to handle tail calculations manually
+ Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+ win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ Iterator input(_input, win_collapsed);
+ Iterator output(_output, win_collapsed);
+ execute_window_loop(win_collapsed, [&](const Coordinates & id)
+ {
+ auto input_ptr = reinterpret_cast<const T *>(input.ptr());
+ auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
+
+ int x = window_start_x;
+ for(; x <= (window_end_x - window_step); x += window_step)
+ {
+ wrapper::vstore(&output_ptr[x], vquantize(load_value(&input_ptr[x]), qinfo));
+ }
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ output_ptr[x] = qinfo.quantize(input_ptr[x], rounding_policy);
+ }
+ },
+ input, output);
+}
+
void NEQuantizationLayerKernel::run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- Window window_input_output(window);
- window_input_output.set(3, Window::Dimension(0, 1, 1));
-
- Window window_min_max;
- window_min_max.use_tensor_dimensions(_min_max->info()->tensor_shape());
- window_min_max.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input(_input, window_input_output);
- Iterator output(_output, window_input_output);
- Iterator min_max(_min_max, window_min_max);
+ const QuantizationInfo &qinfo = _output->info()->quantization_info();
- execute_window_loop(window_min_max, [&](const Coordinates & id_batch)
+ switch(_input->info()->data_type())
{
- // Get the min and max
- float min = *(reinterpret_cast<const float *>(min_max.ptr()) + 0);
- float max = *(reinterpret_cast<const float *>(min_max.ptr()) + 1);
-
- // Saturate the result if min = max
- if(min == max)
- {
- min = 0.0f;
- max = 1.0f;
- }
-
- const float32x4_t vmin = vdupq_n_f32(min);
- const float32x4_t inv_range = vdupq_n_f32(1.0f / (max - min));
- const float32x4_t quantization_max = vdupq_n_f32(255.0f);
- const float32x4_t quantization_mul = vdupq_n_f32(256.0f);
-
- // Uniformly map values to range 8bit integers, i.e. [min, max] -> [0, 255]
- execute_window_loop(window_input_output, [&](const Coordinates & id)
- {
- // Get the input values
- const auto input_ptr = reinterpret_cast<const float *>(input.ptr() + id_batch[1] * _input->info()->strides_in_bytes()[3]);
- float32x4x2_t val = vld2q_f32(input_ptr);
-
- // Map float values to range [0.0, 1.0]
- val.val[0] = vsubq_f32(val.val[0], vmin);
- val.val[1] = vsubq_f32(val.val[1], vmin);
- val.val[0] = vmulq_f32(val.val[0], inv_range);
- val.val[1] = vmulq_f32(val.val[1], inv_range);
-
- // Quantize
- val.val[0] = vmulq_f32(val.val[0], quantization_mul);
- val.val[1] = vmulq_f32(val.val[1], quantization_mul);
- val.val[0] = vminq_f32(val.val[0], quantization_max);
- val.val[1] = vminq_f32(val.val[1], quantization_max);
-
- const uint32x4_t val_u32_low = vcvtq_u32_f32(val.val[0]);
- const uint32x4_t val_u32_high = vcvtq_u32_f32(val.val[1]);
- const uint16x4x2_t val_u16 = vzip_u16(vmovn_u32(val_u32_low), vmovn_u32(val_u32_high));
-
- const uint8x8_t quantized = vmovn_u16(vcombine_u16(val_u16.val[0], val_u16.val[1]));
-
- // Store the quantized values
- auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr() + id_batch[1] * _output->info()->strides_in_bytes()[3]);
- vst1_u8(output_ptr, quantized);
- },
- input, output);
- },
- min_max);
+ case DataType::F32:
+ NEQuantizationLayerKernel::quantize<float>(window, qinfo);
+ break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ NEQuantizationLayerKernel::quantize<float16_t>(window, qinfo);
+ break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type.");
+ }
}
diff --git a/src/core/Rounding.cpp b/src/core/Rounding.cpp
index fea635be97..da6e5f6099 100644
--- a/src/core/Rounding.cpp
+++ b/src/core/Rounding.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,7 +50,13 @@ int arm_compute::round(float x, RoundingPolicy rounding_policy)
}
case RoundingPolicy::TO_NEAREST_EVEN:
{
+#ifdef __aarch64__
+ asm("fcvtns %x[res], %s[value]"
+ : [res] "=r"(rounded)
+ : [value] "w"(x));
+#else // __aarch64__
ARM_COMPUTE_ERROR("TO_NEAREST_EVEN rounding policy is not supported.");
+#endif // __aarch64__
break;
}
default:
diff --git a/src/runtime/NEON/functions/NEQuantizationLayer.cpp b/src/runtime/NEON/functions/NEQuantizationLayer.cpp
index 8f7db96de8..65873b1b14 100644
--- a/src/runtime/NEON/functions/NEQuantizationLayer.cpp
+++ b/src/runtime/NEON/functions/NEQuantizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,22 +26,13 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
using namespace arm_compute;
-NEQuantizationLayer::NEQuantizationLayer()
- : _quantize_kernel(), _min_max_kernel(), _min_max()
-{
-}
-
Status NEQuantizationLayer::validate(const ITensorInfo *input, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
-
- TensorInfo min_max{ input->num_channels(), input->data_type() };
- ARM_COMPUTE_RETURN_ON_ERROR(NEMinMaxLayerKernel::validate(input, &min_max));
- ARM_COMPUTE_RETURN_ON_ERROR(NEQuantizationLayerKernel::validate(input, output, &min_max));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEQuantizationLayerKernel::validate(input, output));
return Status{};
}
@@ -50,24 +41,8 @@ void NEQuantizationLayer::configure(const ITensor *input, ITensor *output)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
- // Configure min-max kernel. _min_max tensor will be auto-configured within the kernel
- _min_max_kernel.configure(input, &_min_max);
-
// Configure quantize kernel
- _quantize_kernel.configure(input, output, &_min_max);
-
- // Allocate min_max tensor
- _min_max.allocator()->allocate();
-}
-
-void NEQuantizationLayer::run()
-{
- // Reset min and max
- _min_max_kernel.reset();
-
- // Run min and max kernel
- NEScheduler::get().schedule(&_min_max_kernel, Window::DimY);
-
- // Run quantize kernel
- NEScheduler::get().schedule(&_quantize_kernel, Window::DimY);
+ auto k = arm_compute::support::cpp14::make_unique<NEQuantizationLayerKernel>();
+ k->configure(input, output);
+ _kernel = std::move(k);
}
diff --git a/tests/validation/NEON/QuantizationLayer.cpp b/tests/validation/NEON/QuantizationLayer.cpp
index 6526539d1a..487eb70120 100644
--- a/tests/validation/NEON/QuantizationLayer.cpp
+++ b/tests/validation/NEON/QuantizationLayer.cpp
@@ -55,21 +55,17 @@ TEST_SUITE(QuantizationLayer)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
- framework::dataset::make("InputInfo", { TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::U8), // Wrong input data type
- TensorInfo(TensorShape(16U, 5U, 16U), 1, DataType::U8), // Invalid shape
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::QASYMM8), // Wrong input data type
TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::F32), // Wrong output data type
- TensorInfo(TensorShape(16U, 16U, 2U, 5U), 1, DataType::U8), // Missmatching shapes
- TensorInfo(TensorShape(17U, 16U, 16U, 5U), 1, DataType::U8), // Shrink window
+ TensorInfo(TensorShape(16U, 16U, 2U, 5U), 1, DataType::F32), // Missmatching shapes
TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::F32), // Valid
}),
framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::F32),
- TensorInfo(TensorShape(16U, 5U, 16U), 1, DataType::U8),
TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::U16),
- TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::F32),
- TensorInfo(TensorShape(17U, 16U, 16U, 5U), 1, DataType::F32),
- TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::U8),
+ TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(16U, 16U, 16U, 5U), 1, DataType::QASYMM8),
})),
- framework::dataset::make("Expected", { false, false, false, false, false, true})),
+ framework::dataset::make("Expected", { false, false, false, true})),
input_info, output_info, expected)
{
ARM_COMPUTE_EXPECT(bool(NEQuantizationLayer::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
@@ -81,7 +77,7 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(QuantizationS
{
// Create tensors
Tensor src = create_tensor<Tensor>(shape, data_type);
- Tensor dst = create_tensor<Tensor>(shape, DataType::U8);
+ Tensor dst = create_tensor<Tensor>(shape, DataType::QASYMM8);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -96,30 +92,51 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(QuantizationS
validate(dst.info()->valid_region(), valid_region);
// Validate padding
- const PaddingSize padding = PaddingCalculator(shape.x(), 8).required_padding();
- validate(src.info()->padding(), padding);
- validate(dst.info()->padding(), padding);
+ validate(src.info()->padding(), PaddingSize());
+ validate(dst.info()->padding(), PaddingSize());
}
template <typename T>
-using NEQuantizationLayerFixture = QuantizationValidationFixture<Tensor, Accessor, NEQuantizationLayer, T>;
+using NEQuantizationLayerFixture = QAsymm8QuantizationValidationFixture<Tensor, Accessor, NEQuantizationLayer, T>;
TEST_SUITE(Float)
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEQuantizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(concat(datasets::Small3DShapes(), datasets::Small4DShapes()),
- framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEQuantizationLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(concat(datasets::Small3DShapes(), datasets::Small4DShapes()),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, 10) })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_u8);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEQuantizationLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(concat(datasets::Large3DShapes(), datasets::Large4DShapes()),
- framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEQuantizationLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(concat(datasets::Large3DShapes(), datasets::Large4DShapes()),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, 10) })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_u8);
}
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+TEST_SUITE(Half)
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEQuantizationLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(concat(datasets::Small3DShapes(), datasets::Small4DShapes()),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, 10) })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_u8);
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, NEQuantizationLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(concat(datasets::Large3DShapes(), datasets::Large4DShapes()),
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, 10) })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_u8);
+}
+TEST_SUITE_END() // FP16
+TEST_SUITE_END() // Half
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE_END() // QuantizationLayer
TEST_SUITE_END() // NEON
diff --git a/tests/validation/fixtures/QuantizationLayerFixture.h b/tests/validation/fixtures/QuantizationLayerFixture.h
index 8590b7193b..65de405788 100644
--- a/tests/validation/fixtures/QuantizationLayerFixture.h
+++ b/tests/validation/fixtures/QuantizationLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -103,6 +103,69 @@ protected:
TensorType _target{};
SimpleTensor<uint8_t> _reference{};
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class QAsymm8QuantizationValidationFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape, DataType data_type, QuantizationInfo quant_info)
+ {
+ _target = compute_target(shape, data_type, quant_info);
+ _reference = compute_reference(shape, data_type, quant_info);
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor)
+ {
+ library->fill_tensor_uniform(tensor, 0);
+ }
+
+ TensorType compute_target(const TensorShape &shape, DataType data_type, QuantizationInfo quant_info)
+ {
+ // Create tensors
+ TensorType src = create_tensor<TensorType>(shape, data_type);
+ TensorType dst = create_tensor<TensorType>(shape, DataType::QASYMM8, 1, quant_info);
+
+ // Create and configure function
+ FunctionType quantization_layer;
+ quantization_layer.configure(&src, &dst);
+
+ ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Allocate tensors
+ src.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Fill tensors
+ fill(AccessorType(src));
+
+ // Compute function
+ quantization_layer.run();
+
+ return dst;
+ }
+
+ SimpleTensor<uint8_t> compute_reference(const TensorShape &shape, DataType data_type, QuantizationInfo quant_info)
+ {
+ // Create reference
+ SimpleTensor<T> src{ shape, data_type };
+
+ // Fill reference
+ fill(src);
+
+ return reference::quantization_layer<T>(src, quant_info);
+ }
+
+ TensorType _target{};
+ SimpleTensor<uint8_t> _reference{};
+};
+
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/reference/QuantizationLayer.cpp b/tests/validation/reference/QuantizationLayer.cpp
index d7ce490209..3d6c5bc13d 100644
--- a/tests/validation/reference/QuantizationLayer.cpp
+++ b/tests/validation/reference/QuantizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,6 +81,25 @@ SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<T> &src)
}
template SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<float> &src);
+
+template <typename T>
+SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<T> &src, const QuantizationInfo quantization_info)
+{
+ // Create reference
+ SimpleTensor<uint8_t> dst{ src.shape(), DataType::QASYMM8, 1, quantization_info };
+
+ for(int i = 0; i < src.num_elements(); ++i)
+ {
+#ifdef __aarch64__
+ dst[i] = quantization_info.quantize((src[i]), RoundingPolicy::TO_NEAREST_EVEN);
+#else // __aarch64__
+ dst[i] = quantization_info.quantize((src[i]), RoundingPolicy::TO_ZERO);
+#endif // __aarch64__
+ }
+ return dst;
+}
+template SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<half> &src, const QuantizationInfo quantization_info);
+template SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<float> &src, const QuantizationInfo quantization_info);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/QuantizationLayer.h b/tests/validation/reference/QuantizationLayer.h
index 7c5572ccf8..60d8ea4023 100644
--- a/tests/validation/reference/QuantizationLayer.h
+++ b/tests/validation/reference/QuantizationLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,6 +37,9 @@ namespace reference
{
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<T> &src);
+
+template <typename T>
+SimpleTensor<uint8_t> quantization_layer(const SimpleTensor<T> &src, const QuantizationInfo quantization_info);
} // namespace reference
} // namespace validation
} // namespace test