From ad0c7388f6261989a268ffb2d042f2bd80736e3f Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Mon, 23 Apr 2018 16:16:21 +0100 Subject: COMPMID-1068 Create validate method to CLDepthWiseConvolution Change-Id: I3301b66a8a072c6ecd0d7f2dabef350017b55ac4 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128677 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- .../CL/kernels/CLDepthwiseVectorToTensorKernel.cpp | 47 +++++++++++++++++----- 1 file changed, 37 insertions(+), 10 deletions(-) (limited to 'src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp') diff --git a/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp b/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp index 83fc168f45..26336ebf79 100644 --- a/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp +++ b/src/core/CL/kernels/CLDepthwiseVectorToTensorKernel.cpp @@ -34,6 +34,34 @@ using namespace arm_compute; +namespace +{ +TensorShape compute_output_shape(const TensorShape &input, size_t conv_w, size_t conv_h) +{ + TensorShape output_shape(input); + output_shape.set(0, conv_w); + output_shape.set(1, conv_h); + output_shape.set(2, input.x() / (conv_w * conv_h)); + + return output_shape; +} + +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h) +{ + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32); + + if(output->total_size() != 0) + { + TensorShape output_shape = compute_output_shape(input->tensor_shape(), conv_w, conv_h); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); + } + + return Status{}; +} +} // namespace + CLDepthwiseVectorToTensorKernel::CLDepthwiseVectorToTensorKernel() : _input(nullptr), _output(nullptr) { @@ -41,20 +69,13 @@ CLDepthwiseVectorToTensorKernel::CLDepthwiseVectorToTensorKernel() void CLDepthwiseVectorToTensorKernel::configure(const ICLTensor *input, ICLTensor *output, size_t conv_w, size_t conv_h) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_NULLPTR(output); - - TensorShape output_shape = input->info()->tensor_shape(); - output_shape.set(0, conv_w); - output_shape.set(1, conv_h); - output_shape.set(2, input->info()->tensor_shape()[0] / (conv_w * conv_h)); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); // Output auto inizialitation if not yet initialized + TensorShape output_shape = compute_output_shape(input->info()->tensor_shape(), conv_w, conv_h); auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape)); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), conv_w, conv_h)); _input = input; _output = output; @@ -75,6 +96,12 @@ void CLDepthwiseVectorToTensorKernel::configure(const ICLTensor *input, ICLTenso ICLKernel::configure(win); } +Status CLDepthwiseVectorToTensorKernel::validate(const ITensorInfo *input, const ITensorInfo *output, size_t conv_w, size_t conv_h) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, conv_w, conv_h)); + return Status{}; +} + void CLDepthwiseVectorToTensorKernel::run(const Window &window, cl::CommandQueue &queue) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); -- cgit v1.2.1