aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEScaleKernel.cpp
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2021-01-28 12:51:02 +0000
committerMichele Di Giorgio <michele.digiorgio@arm.com>2021-02-01 08:46:01 +0000
commit655e8c6334580a570008243af1896d269fdd60ad (patch)
tree18f5e6e5cc9148a1afcc54c8c6ac54620242bc6d /src/core/NEON/kernels/NEScaleKernel.cpp
parentcc438f23e206d9b5bb55e491b97fbc9b0962dabc (diff)
downloadComputeLibrary-655e8c6334580a570008243af1896d269fdd60ad.tar.gz
Make data_layout an attribute of the Scale function
Resolves COMPMID-4208 Change-Id: I61ca670134a005462ad0528a5aff9507a90860e7 Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4942 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEScaleKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEScaleKernel.cpp23
1 files changed, 11 insertions, 12 deletions
diff --git a/src/core/NEON/kernels/NEScaleKernel.cpp b/src/core/NEON/kernels/NEScaleKernel.cpp
index 39ed6317a1..f2c11b203c 100644
--- a/src/core/NEON/kernels/NEScaleKernel.cpp
+++ b/src/core/NEON/kernels/NEScaleKernel.cpp
@@ -158,7 +158,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *dx, const
ARM_COMPUTE_UNUSED(info.constant_border_value);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.use_padding, "Padding is not supported");
- const DataLayout data_layout = input->data_layout();
+ const DataLayout data_layout = info.data_layout == DataLayout::UNKNOWN ? input->data_layout() : info.data_layout;
const auto width_index = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
const auto height_index = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
const auto output_width = output->dimension(width_index);
@@ -192,7 +192,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *dx, const
NEScaleKernel::NEScaleKernel()
: _func(nullptr), _offsets(nullptr), _dx(nullptr), _dy(nullptr), _input(nullptr), _output(nullptr), _policy(), _border_mode(), _constant_border_value(PixelValue()), _sampling_offset(0),
- _align_corners(false)
+ _align_corners(false), _data_layout(DataLayout::UNKNOWN)
{
}
@@ -209,9 +209,9 @@ void NEScaleKernel::configure(const ITensor *input, const ITensor *dx, const ITe
info));
// Get data layout and width/height indices
- const DataLayout data_layout = input->info()->data_layout();
- const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ _data_layout = info.data_layout == DataLayout::UNKNOWN ? input->info()->data_layout() : info.data_layout;
+ const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
_input = input;
_output = output;
@@ -242,11 +242,11 @@ void NEScaleKernel::configure(const ITensor *input, const ITensor *dx, const ITe
}
// Configure scale function to run
- if(_input->info()->data_layout() == DataLayout::NCHW)
+ if(_data_layout == DataLayout::NCHW)
{
std::string function_to_call("scale_");
function_to_call += string_from_data_type(_input->info()->data_type()) + "_";
- function_to_call += string_from_data_layout(_input->info()->data_layout()) + "_";
+ function_to_call += string_from_data_layout(_data_layout) + "_";
function_to_call += string_from_interpolation_policy(_policy);
static std::map<std::string, ScaleFunctionPtr> map_function =
@@ -471,9 +471,8 @@ template <typename T>
void NEScaleKernel::scale_bilinear_qasymm(const Window &window)
{
// Get data layout and width/height indices
- const DataLayout data_layout = _input->info()->data_layout();
- const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
- const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
// Compute the ratio between source height and destination height
const auto hr = scale_utils::calculate_resize_ratio(_input->info()->dimension(idx_height), _output->info()->dimension(idx_height), _align_corners);
@@ -586,9 +585,9 @@ void NEScaleKernel::run(const Window &window, const ThreadInfo &info)
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- ARM_COMPUTE_ERROR_ON(_func == nullptr && _input->info()->data_layout() == DataLayout::NCHW);
+ ARM_COMPUTE_ERROR_ON(_func == nullptr && _data_layout == DataLayout::NCHW);
- if(_input->info()->data_layout() == DataLayout::NCHW)
+ if(_data_layout == DataLayout::NCHW)
{
(this->*_func)(window);
}