aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NECol2ImKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NECol2ImKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NECol2ImKernel.cpp48
1 files changed, 23 insertions, 25 deletions
diff --git a/src/core/NEON/kernels/NECol2ImKernel.cpp b/src/core/NEON/kernels/NECol2ImKernel.cpp
index bb8e758b78..d6517ac072 100644
--- a/src/core/NEON/kernels/NECol2ImKernel.cpp
+++ b/src/core/NEON/kernels/NECol2ImKernel.cpp
@@ -29,26 +29,17 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include <arm_neon.h>
#include <cstddef>
#include <cstdint>
using namespace arm_compute;
+using namespace misc::shape_calculator;
namespace
{
-TensorShape get_output_shape(const ITensorInfo *input, const Size2D &convolved_dims)
-{
- TensorShape output_shape = input->tensor_shape();
- output_shape.set(0, convolved_dims.width);
- output_shape.set(1, convolved_dims.height);
- output_shape.set(2, input->tensor_shape()[0]);
- output_shape.set(3, input->tensor_shape()[3]); // For NEON the batch size is on the fourth dimension of the input tensor
-
- return output_shape;
-}
-
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &convolved_dims)
{
//Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
@@ -60,12 +51,28 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
// Validate configured output
if(output->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), get_output_shape(input, convolved_dims));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_col2im_shape(*input, convolved_dims, false));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
}
return Status{};
}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const Size2D &convolved_dims)
+{
+ // Output auto inizialitation if not yet initialized
+ auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_col2im_shape(*input, convolved_dims, false)));
+
+ // Configure kernel window
+ Window win = calculate_max_window(*input, Steps());
+
+ // The NECol2ImKernel doesn't need padding so update_window_and_padding() can be skipped
+ Coordinates coord;
+ coord.set_num_dimensions(output->num_dimensions());
+ output->set_valid_region(ValidRegion(coord, output->tensor_shape()));
+
+ return std::make_pair(Status{}, win);
+}
} // namespace
template <typename T>
@@ -102,11 +109,6 @@ NECol2ImKernel::NECol2ImKernel()
void NECol2ImKernel::configure(const ITensor *input, ITensor *output, const Size2D &convolved_dims)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-
- // Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(get_output_shape(input->info(), convolved_dims)));
-
- // Perform validation step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), convolved_dims));
_input = input;
@@ -130,19 +132,15 @@ void NECol2ImKernel::configure(const ITensor *input, ITensor *output, const Size
}
// Configure kernel window
- Window win = calculate_max_window(*input->info(), Steps());
-
- // The NECol2ImKernel doesn't need padding so update_window_and_padding() can be skipped
- Coordinates coord;
- coord.set_num_dimensions(output->info()->num_dimensions());
- output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
-
- INEKernel::configure(win);
+ auto win_config = validate_and_configure_window(input->info(), output->info(), convolved_dims);
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ INEKernel::configure(win_config.second);
}
Status NECol2ImKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &convolved_dims)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, convolved_dims));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), convolved_dims).first);
return Status{};
}