diff options
Diffstat (limited to 'src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp')
-rw-r--r-- | src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp b/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp index 73bf7dcb8a..5360abf5ac 100644 --- a/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp +++ b/src/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022 Arm Limited. + * Copyright (c) 2021-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -62,6 +62,9 @@ void create_arm_dwc(const ITensorInfo *src, const ITensorInfo *weights, ITensorI unsigned int stride_rows{}; std::tie(stride_cols, stride_rows) = info.pad_stride_info.stride(); + unsigned int dilation_cols = info.dilation.x(); + unsigned int dilation_rows = info.dilation.y(); + const arm_conv::PaddingValues padding = assembly_utils::map_to_arm_conv_padding(info.pad_stride_info); const unsigned int n_batches = src->dimension(idx_batches); @@ -76,7 +79,7 @@ void create_arm_dwc(const ITensorInfo *src, const ITensorInfo *weights, ITensorI const arm_gemm::Activation activation = assembly_utils::map_to_arm_gemm_activation(info.act_info); - arm_conv::depthwise::DepthwiseArgs args(&cpu_info, kernel_rows, kernel_cols, stride_rows, stride_cols, + arm_conv::depthwise::DepthwiseArgs args(&cpu_info, kernel_rows, kernel_cols, stride_rows, stride_cols, dilation_rows, dilation_cols, n_batches, src_rows, src_cols, n_channels, dst_rows, dst_cols, info.depth_multiplier, padding, activation, nullptr); @@ -103,6 +106,9 @@ void create_arm_dwc_quant(const ITensorInfo *src, const ITensorInfo *weights, IT unsigned int stride_rows{}; std::tie(stride_cols, stride_rows) = info.pad_stride_info.stride(); + unsigned int dilation_cols = info.dilation.x(); + unsigned int dilation_rows = info.dilation.y(); + const arm_conv::PaddingValues padding = assembly_utils::map_to_arm_conv_padding(info.pad_stride_info); const unsigned int n_batches = src->dimension(idx_batches); @@ -117,7 +123,7 @@ void create_arm_dwc_quant(const ITensorInfo *src, const ITensorInfo *weights, IT const arm_gemm::Activation activation = assembly_utils::map_to_arm_gemm_activation(info.act_info); - arm_conv::depthwise::DepthwiseArgs args(&cpu_info, kernel_rows, kernel_cols, stride_rows, stride_cols, + arm_conv::depthwise::DepthwiseArgs args(&cpu_info, kernel_rows, kernel_cols, stride_rows, stride_cols, dilation_rows, dilation_cols, n_batches, src_rows, src_cols, n_channels, dst_rows, dst_cols, info.depth_multiplier, padding, activation, nullptr); @@ -265,7 +271,6 @@ Status CpuDepthwiseConv2dAssemblyWrapperKernel::validate(const ITensorInfo *src, ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_layout() != DataLayout::NHWC, "Only NHWC is supported by assembly kernels"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.dilation != Size2D(1, 1), "Assembly kernels do not support dilation != (1, 1)"); if(is_data_type_quantized_per_channel(weights->data_type())) { |