aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp')
-rw-r--r--src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp20
1 files changed, 11 insertions, 9 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
index 056f08d037..dc505a013d 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022 Arm Limited.
+ * Copyright (c) 2022-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -32,11 +32,11 @@ namespace interleaves {
PackingArguments::PackingArguments(
unsigned int kernel_rows, unsigned int kernel_cols, size_t weight_element_size,
- bool include_bias, size_t bias_element_size,
+ bool include_bias, size_t bias_element_size, bool premultiply,
arm_gemm::VLType vl_type, size_t accumulator_element_size, unsigned int accumulator_depth_vl,
std::function<bool(unsigned int, unsigned int &, unsigned int &)> get_weight_pos
) : kernel_rows(kernel_rows), kernel_cols(kernel_cols), weight_element_size(weight_element_size),
- include_bias(include_bias), bias_element_size(bias_element_size),
+ include_bias(include_bias), bias_element_size(bias_element_size), premultiply(premultiply),
vl_type(vl_type), accumulator_element_size(accumulator_element_size), accumulator_depth_vl(accumulator_depth_vl),
get_weight_pos(get_weight_pos)
{
@@ -46,7 +46,7 @@ size_t get_storage_size_generic(const PackingArguments &packing_args, const Dept
{
// If the channel multiplier is greater than one, then we treat this as a
// repeated packing of `channel_multiplier`-sized problems.
- if (args.channel_multiplier > 1)
+ if (args.channel_multiplier > 1 && !packing_args.premultiply)
{
DepthwiseArgs args_per_input_channel(args);
args_per_input_channel.input_channels = args.channel_multiplier;
@@ -58,7 +58,7 @@ size_t get_storage_size_generic(const PackingArguments &packing_args, const Dept
const unsigned int vl =
packing_args.accumulator_depth_vl *
arm_gemm::utils::get_vector_length<uint8_t>(packing_args.vl_type) / packing_args.accumulator_element_size;
- const unsigned int n_packs = arm_gemm::iceildiv(args.input_channels, vl);
+ const unsigned int n_packs = arm_gemm::iceildiv(args.input_channels * args.channel_multiplier, vl);
const auto pack_size = (packing_args.include_bias ? packing_args.bias_element_size : 0) +
packing_args.kernel_points() * packing_args.weight_element_size;
return n_packs * pack_size * vl;
@@ -81,7 +81,7 @@ void pack_parameters_generic(
// If the channel multiplier is greater than one, then we treat this as a
// repeated packing of `channel_multiplier`-sized problems.
- if (args.channel_multiplier > 1)
+ if (args.channel_multiplier > 1 && !packing_args.premultiply)
{
// Get a modified copy of the depthwise arguments
DepthwiseArgs args_per_input_channel(args);
@@ -107,17 +107,19 @@ void pack_parameters_generic(
return;
}
+ auto input_channels = args.input_channels * args.channel_multiplier;
+
// Finalise the weight strides
- ld_weight_col = (ld_weight_col == 0) ? args.input_channels : ld_weight_col;
+ ld_weight_col = (ld_weight_col == 0) ? input_channels : ld_weight_col;
ld_weight_row = (ld_weight_row == 0) ? packing_args.kernel_cols * ld_weight_col : ld_weight_row;
const unsigned int vl =
packing_args.accumulator_depth_vl *
arm_gemm::utils::get_vector_length<uint8_t>(packing_args.vl_type) / packing_args.accumulator_element_size;
- for (unsigned int n = 0; n < args.input_channels; n += vl)
+ for (unsigned int n = 0; n < input_channels; n += vl)
{
- const unsigned int todo = std::min(vl, args.input_channels - n);
+ const unsigned int todo = std::min(vl, input_channels - n);
if (packing_args.include_bias)
{