aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp19
1 files changed, 10 insertions, 9 deletions
diff --git a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
index b924d9f8bd..cfd8eacfdd 100644
--- a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
@@ -85,7 +85,7 @@ void NEDepthwiseIm2ColKernel::run_generic(const Window &window)
const int src_y = -pad_top + src_pixel_linear / max_initial_x * stride_y;
// Get pointers
- const uint8_t *const input_ptr = in.ptr() + id.z() * input_stride_z;
+ const uint8_t *const input_ptr = in.ptr() + id.z() / _depth_multiplier * input_stride_z;
auto output_ptr = reinterpret_cast<T *>(out.ptr());
const int height = src_y + _kernel_dims.height;
const int width = src_x + _kernel_dims.width;
@@ -114,24 +114,25 @@ void NEDepthwiseIm2ColKernel::run_generic(const Window &window)
}
NEDepthwiseIm2ColKernel::NEDepthwiseIm2ColKernel()
- : _func(nullptr), _input(nullptr), _output(nullptr), _kernel_dims(), _conv_info(), _has_bias()
+ : _func(nullptr), _input(nullptr), _output(nullptr), _kernel_dims(), _conv_info(), _has_bias(), _depth_multiplier(1)
{
}
-void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, unsigned int depth_multiplier)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input->info()->data_type()) && has_bias);
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(2));
+ ARM_COMPUTE_ERROR_ON((input->info()->dimension(2) * depth_multiplier) != output->info()->dimension(2));
ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0)));
- _input = input;
- _output = output;
- _kernel_dims = kernel_dims;
- _conv_info = conv_info;
- _has_bias = has_bias;
+ _input = input;
+ _output = output;
+ _kernel_dims = kernel_dims;
+ _conv_info = conv_info;
+ _has_bias = has_bias;
+ _depth_multiplier = depth_multiplier;
// Configure kernel window
Window win = calculate_max_window(*input->info(), Steps());