diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2021-01-28 12:51:02 +0000 |
---|---|---|
committer | Michele Di Giorgio <michele.digiorgio@arm.com> | 2021-02-01 08:46:01 +0000 |
commit | 655e8c6334580a570008243af1896d269fdd60ad (patch) | |
tree | 18f5e6e5cc9148a1afcc54c8c6ac54620242bc6d /src/core/NEON | |
parent | cc438f23e206d9b5bb55e491b97fbc9b0962dabc (diff) | |
download | ComputeLibrary-655e8c6334580a570008243af1896d269fdd60ad.tar.gz |
Make data_layout an attribute of the Scale function
Resolves COMPMID-4208
Change-Id: I61ca670134a005462ad0528a5aff9507a90860e7
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4942
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON')
-rw-r--r-- | src/core/NEON/kernels/NEScaleKernel.cpp | 23 | ||||
-rw-r--r-- | src/core/NEON/kernels/NEScaleKernel.h | 1 |
2 files changed, 12 insertions, 12 deletions
diff --git a/src/core/NEON/kernels/NEScaleKernel.cpp b/src/core/NEON/kernels/NEScaleKernel.cpp index 39ed6317a1..f2c11b203c 100644 --- a/src/core/NEON/kernels/NEScaleKernel.cpp +++ b/src/core/NEON/kernels/NEScaleKernel.cpp @@ -158,7 +158,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *dx, const ARM_COMPUTE_UNUSED(info.constant_border_value); ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.use_padding, "Padding is not supported"); - const DataLayout data_layout = input->data_layout(); + const DataLayout data_layout = info.data_layout == DataLayout::UNKNOWN ? input->data_layout() : info.data_layout; const auto width_index = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); const auto height_index = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); const auto output_width = output->dimension(width_index); @@ -192,7 +192,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *dx, const NEScaleKernel::NEScaleKernel() : _func(nullptr), _offsets(nullptr), _dx(nullptr), _dy(nullptr), _input(nullptr), _output(nullptr), _policy(), _border_mode(), _constant_border_value(PixelValue()), _sampling_offset(0), - _align_corners(false) + _align_corners(false), _data_layout(DataLayout::UNKNOWN) { } @@ -209,9 +209,9 @@ void NEScaleKernel::configure(const ITensor *input, const ITensor *dx, const ITe info)); // Get data layout and width/height indices - const DataLayout data_layout = input->info()->data_layout(); - const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + _data_layout = info.data_layout == DataLayout::UNKNOWN ? input->info()->data_layout() : info.data_layout; + const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); _input = input; _output = output; @@ -242,11 +242,11 @@ void NEScaleKernel::configure(const ITensor *input, const ITensor *dx, const ITe } // Configure scale function to run - if(_input->info()->data_layout() == DataLayout::NCHW) + if(_data_layout == DataLayout::NCHW) { std::string function_to_call("scale_"); function_to_call += string_from_data_type(_input->info()->data_type()) + "_"; - function_to_call += string_from_data_layout(_input->info()->data_layout()) + "_"; + function_to_call += string_from_data_layout(_data_layout) + "_"; function_to_call += string_from_interpolation_policy(_policy); static std::map<std::string, ScaleFunctionPtr> map_function = @@ -471,9 +471,8 @@ template <typename T> void NEScaleKernel::scale_bilinear_qasymm(const Window &window) { // Get data layout and width/height indices - const DataLayout data_layout = _input->info()->data_layout(); - const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); // Compute the ratio between source height and destination height const auto hr = scale_utils::calculate_resize_ratio(_input->info()->dimension(idx_height), _output->info()->dimension(idx_height), _align_corners); @@ -586,9 +585,9 @@ void NEScaleKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_UNUSED(info); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - ARM_COMPUTE_ERROR_ON(_func == nullptr && _input->info()->data_layout() == DataLayout::NCHW); + ARM_COMPUTE_ERROR_ON(_func == nullptr && _data_layout == DataLayout::NCHW); - if(_input->info()->data_layout() == DataLayout::NCHW) + if(_data_layout == DataLayout::NCHW) { (this->*_func)(window); } diff --git a/src/core/NEON/kernels/NEScaleKernel.h b/src/core/NEON/kernels/NEScaleKernel.h index b93a213e99..f6ee3fa4c5 100644 --- a/src/core/NEON/kernels/NEScaleKernel.h +++ b/src/core/NEON/kernels/NEScaleKernel.h @@ -116,6 +116,7 @@ private: PixelValue _constant_border_value; float _sampling_offset; bool _align_corners; + DataLayout _data_layout; }; } // namespace arm_compute #endif /*ARM_COMPUTE_NESCALEKERNEL_H */ |