about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEDequantizationLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEDequantizationLayerKernel.cpp 41
1 files changed, 15 insertions, 26 deletions
diff --git a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp
index 5abd6a122d..f555df3828 100644
--- a/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDequantizationLayerKernel.cpp
@@ -43,7 +43,7 @@ namespace
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_PER_CHANNEL, DataType::QSYMM8, DataType::QSYMM16);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8, DataType::QSYMM16);
if(output->tensor_shape().total_size() > 0)
{
@@ -160,10 +160,9 @@ void run_dequantization_qasymm8(const ITensor *input, ITensor *output, const Win
}
template <typename T>
-void run_dequantization_qasymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window)
+void run_dequantization_qsymm8_per_channel_nchw(const ITensor *input, ITensor *output, const Window &window)
{
- const std::vector<float> scale = input->info()->quantization_info().scale();
- const std::vector<int32_t> offset = input->info()->quantization_info().offset();
+ const auto scale = input->info()->quantization_info().scale();
const int window_step_x = 16;
const auto window_start_x = static_cast<int>(window.x().start());
@@ -179,14 +178,14 @@ void run_dequantization_qasymm8_per_channel_nchw(const ITensor *input, ITensor *
execute_window_loop(win, [&](const Coordinates & id)
{
- const auto in_ptr = reinterpret_cast<const uint8_t *>(in.ptr());
+ const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr());
const auto out_ptr = reinterpret_cast<T *>(out.ptr());
int x = window_start_x;
for(; x <= (window_end_x - window_step_x); x += window_step_x)
{
const auto vin = wrapper::vloadq(in_ptr + x);
- const auto vdeq = vdequantize(vin, scale[id.z()], offset[id.z()]);
+ const auto vdeq = vdequantize(vin, scale[id.z()]);
store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
}
@@ -194,18 +193,17 @@ void run_dequantization_qasymm8_per_channel_nchw(const ITensor *input, ITensor *
// Compute left-over elements
for(; x < window_end_x; ++x)
{
- uint8_t val = *(in_ptr + x);
- *(out_ptr + x) = static_cast<T>(dequantize(val, scale[id.z()], offset[id.z()]));
+ int8_t val = *(in_ptr + x);
+ *(out_ptr + x) = static_cast<T>(dequantize(val, scale[id.z()]));
}
},
in, out);
}
template <typename T>
-void run_dequantization_qasymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window)
+void run_dequantization_qsymm8_per_channel_nhwc(const ITensor *input, ITensor *output, const Window &window)
{
- const std::vector<float> scale = input->info()->quantization_info().scale();
- const std::vector<int32_t> offset = input->info()->quantization_info().offset();
+ const auto scale = input->info()->quantization_info().scale();
const int window_step_x = 16;
const auto window_start_x = static_cast<int>(window.x().start());
@@ -221,7 +219,7 @@ void run_dequantization_qasymm8_per_channel_nhwc(const ITensor *input, ITensor *
execute_window_loop(win, [&](const Coordinates &)
{
- const auto in_ptr = reinterpret_cast<const uint8_t *>(in.ptr());
+ const auto in_ptr = reinterpret_cast<const int8_t *>(in.ptr());
const auto out_ptr = reinterpret_cast<T *>(out.ptr());
int x = window_start_x;
@@ -236,17 +234,8 @@ void run_dequantization_qasymm8_per_channel_nhwc(const ITensor *input, ITensor *
scale[x + 12], scale[x + 13], scale[x + 14], scale[x + 15]
}
};
- const int32x4x4_t voffset =
- {
- {
- offset[x + 0], offset[x + 1], offset[x + 2], offset[x + 3],
- offset[x + 4], offset[x + 5], offset[x + 6], offset[x + 7],
- offset[x + 8], offset[x + 9], offset[x + 10], offset[x + 11],
- offset[x + 12], offset[x + 13], offset[x + 14], offset[x + 15]
- }
- };
const auto vin = wrapper::vloadq(in_ptr + x);
- const auto vdeq = vdequantize(vin, vscale, voffset);
+ const auto vdeq = vdequantize(vin, vscale);
store_result<T>(reinterpret_cast<T *>(out_ptr + x), vdeq);
}
@@ -254,8 +243,8 @@ void run_dequantization_qasymm8_per_channel_nhwc(const ITensor *input, ITensor *
// Compute left-over elements
for(; x < window_end_x; ++x)
{
- uint8_t val = *(in_ptr + x);
- *(out_ptr + x) = static_cast<T>(dequantize(val, scale[x], offset[x]));
+ int8_t val = *(in_ptr + x);
+ *(out_ptr + x) = static_cast<T>(dequantize(val, scale[x]));
}
},
in, out);
@@ -353,8 +342,8 @@ void run_dequantization_core(const ITensor *input, ITensor *output, const Window
case DataType::QASYMM8:
run_dequantization_qasymm8<T>(input, output, window);
break;
- case DataType::QASYMM8_PER_CHANNEL:
- input->info()->data_layout() == DataLayout::NHWC ? run_dequantization_qasymm8_per_channel_nhwc<T>(input, output, window) : run_dequantization_qasymm8_per_channel_nchw<T>(input, output, window);
+ case DataType::QSYMM8_PER_CHANNEL:
+ input->info()->data_layout() == DataLayout::NHWC ? run_dequantization_qsymm8_per_channel_nhwc<T>(input, output, window) : run_dequantization_qsymm8_per_channel_nchw<T>(input, output, window);
break;
case DataType::QSYMM8:
run_dequantization_qsymm8<T>(input, output, window);