From 20b4313365ea2ed31f59fd757f68f791f076e6bc Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 14 May 2018 16:05:23 +0100 Subject: COMPMID-814: Add validate method to scale. Change-Id: I5004c79ac7b10f988f25e14847f1ea2be01629da Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/131143 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/runtime/NEON/functions/NEScale.cpp | 47 +++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) (limited to 'src/runtime/NEON/functions/NEScale.cpp') diff --git a/src/runtime/NEON/functions/NEScale.cpp b/src/runtime/NEON/functions/NEScale.cpp index 43ef6199ba..9407273c1f 100644 --- a/src/runtime/NEON/functions/NEScale.cpp +++ b/src/runtime/NEON/functions/NEScale.cpp @@ -45,7 +45,6 @@ namespace void precompute_dx_dy_offsets(ITensor *dx, ITensor *dy, ITensor *offsets, float wr, float hr, size_t input_element_size, SamplingPolicy sampling_policy) { ARM_COMPUTE_ERROR_ON(nullptr == offsets); - ARM_COMPUTE_ERROR_ON(sampling_policy != SamplingPolicy::CENTER); ARM_COMPUTE_UNUSED(sampling_policy); Window win; @@ -99,8 +98,8 @@ NEScale::NEScale() // NOLINT void NEScale::configure(ITensor *input, ITensor *output, InterpolationPolicy policy, BorderMode border_mode, PixelValue constant_border_value, SamplingPolicy sampling_policy) { - ARM_COMPUTE_ERROR_ON(nullptr == input); - ARM_COMPUTE_ERROR_ON(nullptr == output); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_ERROR_THROW_ON(NEScale::validate(input->info(), output->info(), policy, border_mode, constant_border_value, sampling_policy)); // Get data layout and width/height indices const DataLayout data_layout = input->info()->data_layout(); @@ -171,6 +170,48 @@ void NEScale::configure(ITensor *input, ITensor *output, InterpolationPolicy pol _border_handler.configure(input, _scale_kernel.border_size(), border_mode, PixelValue(constant_border_value)); } +Status NEScale::validate(const ITensorInfo *input, const ITensorInfo *output, InterpolationPolicy policy, + BorderMode border_mode, PixelValue constant_border_value, SamplingPolicy sampling_policy) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_RETURN_ERROR_ON(sampling_policy != SamplingPolicy::CENTER); + ARM_COMPUTE_UNUSED(border_mode, constant_border_value); + + ITensorInfo *offsets = nullptr; + ITensorInfo *dx = nullptr; + ITensorInfo *dy = nullptr; + + // Get data layout and width/height indices + const DataLayout data_layout = input->data_layout(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + + // Get the tensor shape of auxilary buffers + const TensorShape shape(output->dimension(idx_width), output->dimension(idx_height)); + + TensorInfo tensor_info_offsets(shape, Format::S32); + TensorInfo tensor_info_dx(shape, Format::F32); + TensorInfo tensor_info_dy(shape, Format::F32); + + switch(policy) + { + case InterpolationPolicy::NEAREST_NEIGHBOR: + offsets = &tensor_info_offsets; + break; + case InterpolationPolicy::BILINEAR: + offsets = &tensor_info_offsets; + dx = &tensor_info_dx; + dy = &tensor_info_dy; + break; + default: + break; + } + + ARM_COMPUTE_RETURN_ON_ERROR(NEScaleKernel::validate(input->clone().get(), dx, dy, offsets, output->clone().get(), + policy, border_mode, sampling_policy)); + return Status{}; +} + void NEScale::run() { NEScheduler::get().schedule(&_border_handler, Window::DimZ); -- cgit v1.2.1