From d2fab7315bac3a586f2f1b1c8d64f2441f89ca64 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Fri, 2 Mar 2018 11:18:12 +0000 Subject: COMPMID-935 - Implementing Convolution with Winograd on OpenCL (part 4) Implemented Winograd Output Transform (2x2,3x3) on OpenCL Implemented CLWinogradConvolutionLayer on OpenCL Change-Id: I6a113fc5f052ca07f878d2b800d2ab003f84af65 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125148 Reviewed-by: Georgios Pinitas Tested-by: Jenkins --- src/core/CL/kernels/CLWinogradInputTransformKernel.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'src/core/CL/kernels/CLWinogradInputTransformKernel.cpp') diff --git a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp index 72adb5f358..3b9350f9ba 100644 --- a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp +++ b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp @@ -44,11 +44,11 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_dims.width != 3 || kernel_dims.height != 3, "Winograd input transform only supports 3x3 kernels"); ARM_COMPUTE_UNUSED(kernel_dims); - const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, conv_info, Size2D(3U, 3U)); - // Validate configured output if(output->total_size() != 0) { + const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, conv_info, kernel_dims); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); } @@ -151,7 +151,8 @@ void CLWinogradInputTransformKernel::configure(const ICLTensor *input, ICLTensor Status CLWinogradInputTransformKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const PadStrideInfo &conv_info, const Size2D &kernel_dims) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ERROR_ON(validate_arguments(input, output, conv_info, kernel_dims)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, conv_info, kernel_dims)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), conv_info, kernel_dims).first); return Status{}; } -- cgit v1.2.1