aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLWinogradInputTransformKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLWinogradInputTransformKernel.cpp7
1 files changed, 4 insertions, 3 deletions
diff --git a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
index 72adb5f358..3b9350f9ba 100644
--- a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
@@ -44,11 +44,11 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_dims.width != 3 || kernel_dims.height != 3, "Winograd input transform only supports 3x3 kernels");
ARM_COMPUTE_UNUSED(kernel_dims);
- const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, conv_info, Size2D(3U, 3U));
-
// Validate configured output
if(output->total_size() != 0)
{
+ const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, conv_info, kernel_dims);
+
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
}
@@ -151,7 +151,8 @@ void CLWinogradInputTransformKernel::configure(const ICLTensor *input, ICLTensor
Status CLWinogradInputTransformKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const PadStrideInfo &conv_info, const Size2D &kernel_dims)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON(validate_arguments(input, output, conv_info, kernel_dims));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, conv_info, kernel_dims));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), conv_info, kernel_dims).first);
return Status{};
}