From 655e8c6334580a570008243af1896d269fdd60ad Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Thu, 28 Jan 2021 12:51:02 +0000 Subject: Make data_layout an attribute of the Scale function Resolves COMPMID-4208 Change-Id: I61ca670134a005462ad0528a5aff9507a90860e7 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4942 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/NEScaleKernel.cpp | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) (limited to 'src/core/NEON/kernels/NEScaleKernel.cpp') diff --git a/src/core/NEON/kernels/NEScaleKernel.cpp b/src/core/NEON/kernels/NEScaleKernel.cpp index 39ed6317a1..f2c11b203c 100644 --- a/src/core/NEON/kernels/NEScaleKernel.cpp +++ b/src/core/NEON/kernels/NEScaleKernel.cpp @@ -158,7 +158,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *dx, const ARM_COMPUTE_UNUSED(info.constant_border_value); ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.use_padding, "Padding is not supported"); - const DataLayout data_layout = input->data_layout(); + const DataLayout data_layout = info.data_layout == DataLayout::UNKNOWN ? input->data_layout() : info.data_layout; const auto width_index = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); const auto height_index = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); const auto output_width = output->dimension(width_index); @@ -192,7 +192,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *dx, const NEScaleKernel::NEScaleKernel() : _func(nullptr), _offsets(nullptr), _dx(nullptr), _dy(nullptr), _input(nullptr), _output(nullptr), _policy(), _border_mode(), _constant_border_value(PixelValue()), _sampling_offset(0), - _align_corners(false) + _align_corners(false), _data_layout(DataLayout::UNKNOWN) { } @@ -209,9 +209,9 @@ void NEScaleKernel::configure(const ITensor *input, const ITensor *dx, const ITe info)); // Get data layout and width/height indices - const DataLayout data_layout = input->info()->data_layout(); - const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + _data_layout = info.data_layout == DataLayout::UNKNOWN ? input->info()->data_layout() : info.data_layout; + const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); _input = input; _output = output; @@ -242,11 +242,11 @@ void NEScaleKernel::configure(const ITensor *input, const ITensor *dx, const ITe } // Configure scale function to run - if(_input->info()->data_layout() == DataLayout::NCHW) + if(_data_layout == DataLayout::NCHW) { std::string function_to_call("scale_"); function_to_call += string_from_data_type(_input->info()->data_type()) + "_"; - function_to_call += string_from_data_layout(_input->info()->data_layout()) + "_"; + function_to_call += string_from_data_layout(_data_layout) + "_"; function_to_call += string_from_interpolation_policy(_policy); static std::map map_function = @@ -471,9 +471,8 @@ template void NEScaleKernel::scale_bilinear_qasymm(const Window &window) { // Get data layout and width/height indices - const DataLayout data_layout = _input->info()->data_layout(); - const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); // Compute the ratio between source height and destination height const auto hr = scale_utils::calculate_resize_ratio(_input->info()->dimension(idx_height), _output->info()->dimension(idx_height), _align_corners); @@ -586,9 +585,9 @@ void NEScaleKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_UNUSED(info); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - ARM_COMPUTE_ERROR_ON(_func == nullptr && _input->info()->data_layout() == DataLayout::NCHW); + ARM_COMPUTE_ERROR_ON(_func == nullptr && _data_layout == DataLayout::NCHW); - if(_input->info()->data_layout() == DataLayout::NCHW) + if(_data_layout == DataLayout::NCHW) { (this->*_func)(window); } -- cgit v1.2.1