diff options
Diffstat (limited to 'src/core/NEON/kernels/NEIm2ColKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEIm2ColKernel.cpp | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp index 8cb4f4b889..98b1488a9d 100644 --- a/src/core/NEON/kernels/NEIm2ColKernel.cpp +++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp @@ -45,12 +45,13 @@ using namespace arm_compute; namespace { Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, - bool has_bias, bool is_fully_connected, bool is_flatten, const Size2D &dilation) + bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten) { ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::QASYMM8 && has_bias); ARM_COMPUTE_RETURN_ERROR_ON((dilation.x() < 1) || (dilation.y() < 1)); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Number of groups greater than one are not supported on NEON"); if(output->total_size() > 0) { @@ -290,13 +291,14 @@ NEIm2ColKernel::NEIm2ColKernel() } void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, - bool has_bias, bool is_fully_connected, bool is_flatten, const Size2D &dilation) + bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); // Perform validation step ARM_COMPUTE_UNUSED(is_fully_connected, is_flatten); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten, dilation)); + ARM_COMPUTE_UNUSED(num_groups); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, dilation, num_groups, is_fully_connected, is_flatten)); const DataLayout data_layout = input->info()->data_layout(); const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); @@ -378,9 +380,9 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size } Status NEIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, - bool has_bias, bool is_fully_connected, bool is_flatten, const Size2D &dilation) + bool has_bias, const Size2D &dilation, unsigned int num_groups, bool is_fully_connected, bool is_flatten) { - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten, dilation)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, dilation, num_groups, is_fully_connected, is_flatten)); return Status{}; } |