aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEROIAlignLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEROIAlignLayerKernel.cpp 122
1 files changed, 68 insertions, 54 deletions
diff --git a/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp b/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp
index c48cda8b8e..e937dadba7 100644
--- a/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp
@@ -47,7 +47,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *rois, ITe
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, rois, output);
ARM_COMPUTE_RETURN_ERROR_ON(rois->dimension(0) != 5);
ARM_COMPUTE_RETURN_ERROR_ON(rois->num_dimensions() > 2);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32, DataType::F16);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F32, DataType::F16);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(input, DataLayout::NHWC, DataLayout::NCHW);
ARM_COMPUTE_RETURN_ERROR_ON((pool_info.pooled_width() == 0) || (pool_info.pooled_height() == 0));
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
@@ -59,7 +59,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *rois, ITe
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(compute_roi_align_shape(*input, *rois, pool_info), output->tensor_shape());
}
- if(input->data_type() == DataType::QASYMM8)
+ if(input->data_type() == DataType::QASYMM8 || input->data_type() == DataType::QASYMM8_SIGNED)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(rois, 1, DataType::QASYMM16);
@@ -116,7 +116,7 @@ Status NEROIAlignLayerKernel::validate(const ITensorInfo *input, const ITensorIn
}
/** Average pooling over an aligned window */
-template <typename input_data_type, DataLayout data_layout>
+template <typename input_data_type>
inline input_data_type roi_align_1x1(const ITensor *input,
unsigned int roi_batch,
float region_start_x,
@@ -135,7 +135,8 @@ inline input_data_type roi_align_1x1(const ITensor *input,
}
else
{
- float avg = 0;
+ const DataLayout data_layout = input->info()->data_layout();
+ float avg = 0;
// Iterate through the aligned pooling region
for(int iy = 0; iy < grid_size_y; ++iy)
{
@@ -185,7 +186,7 @@ inline input_data_type roi_align_1x1(const ITensor *input,
}
/** Average pooling over an aligned window */
-template <typename input_data_type, DataLayout data_layout>
+template <typename input_data_type>
inline input_data_type roi_align_1x1_qasymm8(const ITensor *input,
unsigned int roi_batch,
float region_start_x,
@@ -205,8 +206,11 @@ inline input_data_type roi_align_1x1_qasymm8(const ITensor *input,
}
else
{
- float avg = 0;
- const UniformQuantizationInfo input_qinfo = input->info()->quantization_info().uniform();
+ float avg = 0;
+ const UniformQuantizationInfo input_qinfo = input->info()->quantization_info().uniform();
+ const bool is_qasymm_signed = is_data_type_quantized_asymmetric_signed(input->info()->data_type());
+ const DataLayout data_layout = input->info()->data_layout();
+
// Iterate through the aligned pooling region
for(int iy = 0; iy < grid_size_y; ++iy)
{
@@ -234,26 +238,57 @@ inline input_data_type roi_align_1x1_qasymm8(const ITensor *input,
if(data_layout == DataLayout::NCHW)
{
- float data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch))), input_qinfo);
- float data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch))), input_qinfo);
- float data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch))), input_qinfo);
- float data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch))), input_qinfo);
- avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+ if(is_qasymm_signed)
+ {
+ float data1 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch))), input_qinfo);
+ float data2 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch))), input_qinfo);
+ float data3 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch))), input_qinfo);
+ float data4 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch))), input_qinfo);
+ avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+ }
+ else
+ {
+ float data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch))), input_qinfo);
+ float data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch))), input_qinfo);
+ float data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch))), input_qinfo);
+ float data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch))), input_qinfo);
+ avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+ }
}
else
{
- const auto data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch))), input_qinfo);
- const auto data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch))), input_qinfo);
- const auto data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch))), input_qinfo);
- const auto data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch))), input_qinfo);
- avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+ if(is_qasymm_signed)
+ {
+ const auto data1 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch))), input_qinfo);
+ const auto data2 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch))), input_qinfo);
+ const auto data3 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch))), input_qinfo);
+ const auto data4 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch))), input_qinfo);
+ avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+ }
+ else
+ {
+ const auto data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch))), input_qinfo);
+ const auto data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch))), input_qinfo);
+ const auto data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch))), input_qinfo);
+ const auto data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch))), input_qinfo);
+ avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+ }
}
}
}
avg /= grid_size_x * grid_size_y;
- return quantize_qasymm8(avg, out_qinfo);
+ input_data_type res = 0;
+ if(is_qasymm_signed)
+ {
+ res = quantize_qasymm8_signed(avg, out_qinfo);
+ }
+ else
+ {
+ res = quantize_qasymm8(avg, out_qinfo);
+ }
+ return res;
}
}
@@ -265,52 +300,30 @@ inline float compute_region_coordinate(int p, float bin_size, float roi_anchor,
void NEROIAlignLayerKernel::run(const Window &window, const ThreadInfo &info)
{
- if(_input->info()->data_layout() == DataLayout::NCHW)
+ const DataLayout data_layout = _input->info()->data_layout();
+ if(data_layout == DataLayout::NCHW || data_layout == DataLayout::NHWC)
{
switch(_input->info()->data_type())
{
case DataType::QASYMM8:
{
- NEROIAlignLayerKernel::internal_run<DataLayout::NCHW, uint8_t, uint16_t>(window, info);
- break;
- }
- case DataType::F32:
- {
- NEROIAlignLayerKernel::internal_run<DataLayout::NCHW, float>(window, info);
- break;
- }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- NEROIAlignLayerKernel::internal_run<DataLayout::NCHW, float16_t>(window, info);
+ NEROIAlignLayerKernel::internal_run<uint8_t, uint16_t>(window, info);
break;
}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- default:
- {
- ARM_COMPUTE_ERROR("DataType not supported");
- break;
- }
- }
- }
- else if(_input->info()->data_layout() == DataLayout::NHWC)
- {
- switch(_input->info()->data_type())
- {
- case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
{
- NEROIAlignLayerKernel::internal_run<DataLayout::NHWC, uint8_t, uint16_t>(window, info);
+ NEROIAlignLayerKernel::internal_run<int8_t, uint16_t>(window, info);
break;
}
case DataType::F32:
{
- NEROIAlignLayerKernel::internal_run<DataLayout::NHWC, float>(window, info);
+ NEROIAlignLayerKernel::internal_run<float>(window, info);
break;
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- NEROIAlignLayerKernel::internal_run<DataLayout::NHWC, float16_t>(window, info);
+ NEROIAlignLayerKernel::internal_run<float16_t>(window, info);
break;
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -327,21 +340,22 @@ void NEROIAlignLayerKernel::run(const Window &window, const ThreadInfo &info)
}
}
-template <DataLayout data_layout, typename input_data_type, typename roi_data_type>
+template <typename input_data_type, typename roi_data_type>
void NEROIAlignLayerKernel::internal_run(const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- const size_t values_per_roi = _rois->info()->dimension(0);
+ const DataLayout data_layout = _input->info()->data_layout();
+ const size_t values_per_roi = _rois->info()->dimension(0);
const int roi_list_start = window.x().start();
const int roi_list_end = window.x().end();
- const unsigned int idx_width = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::WIDTH);
- const unsigned int idx_height = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::HEIGHT);
- const unsigned int idx_depth = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::CHANNEL);
+ const unsigned int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const unsigned int idx_depth = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
const int input_width = _input->info()->dimension(idx_width);
const int input_height = _input->info()->dimension(idx_height);
@@ -397,14 +411,14 @@ void NEROIAlignLayerKernel::internal_run(const Window &window, const ThreadInfo
input_data_type out_val(0);
if(is_qasymm)
{
- out_val = roi_align_1x1_qasymm8<input_data_type, data_layout>(
+ out_val = roi_align_1x1_qasymm8<input_data_type>(
_input, roi_batch, region_start_x, bin_size_x,
roi_bin_grid_x, region_end_x, region_start_y, bin_size_y,
roi_bin_grid_y, region_end_y, ch, _output->info()->quantization_info());
}
else
{
- out_val = roi_align_1x1<input_data_type, data_layout>(
+ out_val = roi_align_1x1<input_data_type>(
_input, roi_batch, region_start_x, bin_size_x,
roi_bin_grid_x, region_end_x, region_start_y, bin_size_y,
roi_bin_grid_y, region_end_y, ch);