aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index 52b461e255..b76cf6aa10 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -167,10 +167,10 @@ Status NEGEMMConvolutionLayer::validate_gemm3d(DataType data_type, int gemm_3d_d
}
void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info)
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
-
+ ARM_COMPUTE_UNUSED(num_groups);
ARM_COMPUTE_ERROR_THROW_ON(NEGEMMConvolutionLayer::validate(input->info(),
weights->info(),
biases != nullptr ? biases->info() : nullptr,
@@ -178,7 +178,8 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
conv_info,
weights_info,
dilation,
- act_info));
+ act_info,
+ num_groups));
const DataType data_type = input->info()->data_type();
const DataLayout data_layout = input->info()->data_layout();
@@ -346,13 +347,14 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
}
Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
- const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info)
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported on NEON");
const DataLayout data_layout = input->data_layout();
const DataType data_type = input->data_type();