From 7784c837afd5844fb6dc4d166ff253d983abfd2d Mon Sep 17 00:00:00 2001 From: Abe Mbise Date: Thu, 31 May 2018 16:48:41 +0100 Subject: COMPMID-1167: Validation for NEDepthwiseConvolutionLayer Change-Id: I9689e1a0627dc015dd2ce98417e4c97bb55581bb Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/131327 Reviewed-by: Anthony Barbier Tested-by: Jenkins --- .../kernels/NEDepthwiseVectorToTensorKernel.cpp | 38 +++++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) (limited to 'src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp') diff --git a/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp b/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp index 86a6d1c1a8..fe141bef56 100644 --- a/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp +++ b/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp @@ -34,8 +34,27 @@ #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" using namespace arm_compute; +using namespace arm_compute::misc::shape_calculator; + +namespace +{ +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h) +{ + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32); + + if(output->total_size() != 0) + { + TensorShape output_shape = compute_vector_to_tensor_output_shape(input->tensor_shape(), conv_w, conv_h, output->data_layout()); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + } + + return Status{}; +} +} // namespace template void NEDepthwiseVectorToTensorKernel::vector_to_tensor(const Window &window) @@ -76,19 +95,13 @@ NEDepthwiseVectorToTensorKernel::NEDepthwiseVectorToTensorKernel() void NEDepthwiseVectorToTensorKernel::configure(const ITensor *input, ITensor *output, size_t conv_w, size_t conv_h) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_NULLPTR(output); - - TensorShape output_shape = input->info()->tensor_shape(); - output_shape.set(0, conv_w); - output_shape.set(1, conv_h); - output_shape.set(2, input->info()->tensor_shape()[0] / (conv_w * conv_h)); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); // Output auto inizialitation if not yet initialized + TensorShape output_shape = compute_vector_to_tensor_output_shape(input->info()->tensor_shape(), conv_w, conv_h, output->info()->data_layout()); auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape)); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), conv_w, conv_h)); _input = input; _output = output; @@ -121,6 +134,13 @@ void NEDepthwiseVectorToTensorKernel::configure(const ITensor *input, ITensor *o INEKernel::configure(win); } +Status NEDepthwiseVectorToTensorKernel::validate(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, conv_w, conv_h)); + return Status{}; +} + void NEDepthwiseVectorToTensorKernel::run(const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); -- cgit v1.2.1