aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
diff options
context:
space:
mode:
authorramelg01 <ramy.elgammal@arm.com>2022-04-07 02:42:52 +0100
committerRamy Elgammal <ramy.elgammal@arm.com>2022-04-26 15:51:22 +0000
commit8a164884dddf769643cf3b9f7f94e43cb4f3c20b (patch)
tree35958dd48b6df1a851c880dad2b2ce285671b611 /src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
parentc827e99fc46521f43719b0c2d1b6f05d66abf68c (diff)
downloadComputeLibrary-8a164884dddf769643cf3b9f7f94e43cb4f3c20b.tar.gz
Update Neon™ depthwise kernel
- Reduce duplication and simplify overall structure. - Improve multi-threaded performance by sharing more data in lower-level caches. Partially Resolves: COMPMID-5054 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Change-Id: Iac747f39b21c540122fa75218762631c4d787911 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7449 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Andrew Mundy Reviewed-by: Sheri Zhang <sheri.zhang@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp')
-rw-r--r--src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp954
1 files changed, 511 insertions, 443 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
index 2862361b82..e58467b0f4 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_multiplier.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,7 +24,8 @@
#pragma once
-#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+#include "depthwise_depthfirst.hpp"
+#include "interleaves/generic_quantized_dot_product.hpp"
#ifdef CYCLE_PROFILING
#include "profiler.hpp"
@@ -35,492 +36,559 @@
namespace arm_conv {
namespace depthwise {
-namespace common
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+class DepthfirstMultiplierStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, Nothing>
{
- template <typename strategy, typename F>
- void depthwise_multiplier_execute(
- const F execute_tile,
- typename strategy::input_type pad_value,
- const DepthwiseArgs &args,
- const unsigned int batches,
- const unsigned int input_height,
- const unsigned int input_width,
- const unsigned int input_channels,
- const PaddingValues &padding,
- const void *const _input,
- const size_t ld_input_col,
- const size_t ld_input_row,
- const size_t ld_input_batch,
- const void *const parameters,
- const size_t param_stride,
- const unsigned int output_height,
- const unsigned int output_width,
- void *const _output,
- const size_t ld_output_col,
- const size_t ld_output_row,
- const size_t ld_output_batch,
- void *const _working_space,
- const unsigned int thread_id,
- const unsigned int n_threads
- )
- {
- using TInput = typename strategy::input_type;
- using TOutput = typename strategy::return_type;
-
- // Determine what portion of the work to do.
- const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
- const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
- const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
-
- // Cast input and output pointers into the right types
- const TInput *const inptr = static_cast<const TInput *>(_input);
- TOutput *const outptr = static_cast<TOutput *>(_output);
-
- // To simplify the kernel, we process padded or non-NCHW-ordered input into
- // a form which can be consumed by the kernel. This data is stored here and
- // passed into the kernel as an array of N pointers (one per row of the
- // input).
- TInput rearranged_input[strategy::input_rows][strategy::input_col_quads*(16 / sizeof(TInput))];
- const TInput *inptrs[strategy::input_rows];
-
- // Create an array for the output pointers
- TOutput * _outptr_array[strategy::output_rows * strategy::output_cols];
- TOutput **const outptr_array = _outptr_array;
-
- // Allocate portions of the working space
- uint8_t *const working_space = static_cast<uint8_t *>(_working_space);
- TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space);
-
- // For each output tile, construct the requisite set of pointers and call
- // into the kernel.
- for (unsigned int batch = 0; batch < batches; batch++)
- {
- // Get batch pointers
- const auto inptr_batch = inptr + batch * ld_input_batch;
- const auto outptr_batch = outptr + batch * ld_output_batch;
+ using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, Nothing>;
- for (int start_out_i = start_out_height;
- start_out_i < end_out_height;
- start_out_i += static_cast<int>(strategy::output_rows))
+ protected:
+ virtual interleaves::PackingArguments get_packing_args(const DepthwiseArgs &args) const
+ {
+ return interleaves::PackingArguments(
+ args.kernel_rows, args.kernel_cols, sizeof(TWeight),
+ true, sizeof(TAccum),
+ this->get_vl_type(),
+ sizeof(TAccum), 1,
+ [args] (unsigned int pos, unsigned int &x, unsigned int &y) -> bool
{
- const int end_out_i = start_out_i + strategy::output_rows;
- const int start_in_i = start_out_i * strategy::stride_rows - padding.top;
- const int end_in_i = start_in_i + strategy::input_rows;
-
- // Compute top/bottom padding
- const auto pad_top = static_cast<unsigned int>(-std::min(start_in_i, 0));
- const auto pad_bottom = static_cast<unsigned int>(-std::min(static_cast<int>(input_height) - end_in_i, 0));
- const unsigned int valid_output_rows = std::min(
- end_out_i - start_out_i,
- static_cast<int>(output_height) - start_out_i
- );
-
- for (int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
+ if (pos < args.kernel_rows * args.kernel_cols)
{
- const int start_in_j = start_out_j * strategy::stride_cols - args.padding.left;
- const int pad_left = -std::min(0, start_in_j);
-
- const int end_out_j = start_out_j + strategy::output_cols;
- const int end_in_j = start_in_j + strategy::input_cols;
-
- const auto pad_right = static_cast<unsigned int>(-std::min(static_cast<int>(input_width) - end_in_j, 0));
- const unsigned int valid_output_cols = std::min(
- end_out_j - start_out_j,
- static_cast<int>(output_width) - start_out_j
- );
-
- // Construct the output pointer array.
- TOutput **outptr_pos = outptr_array;
- for (auto i = 0u; i < valid_output_rows; i++)
- {
- unsigned int j = 0u;
- TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
- for (; j < valid_output_cols; j++)
- {
- *(outptr_pos++) = colptr;
- colptr += ld_output_col;
- }
- for (; j < strategy::output_cols; j++)
- {
- *(outptr_pos++) = output_buffer;
- }
- }
- for (auto i = valid_output_rows; i < strategy::output_rows; i++)
- {
- for (auto j = 0u; j < strategy::output_cols; j++)
- {
- *(outptr_pos++) = output_buffer;
- }
- }
-
- start_out_j += strategy::output_cols;
-
- const uint8_t *params = static_cast<const uint8_t *>(parameters);
-
- // Loop over the input channels
- for (unsigned int in_c = 0; in_c < input_channels; in_c++)
- {
- // Construct the input array - first fill with padding values and
- // then fill in correct values.
- for (unsigned int i = 0; i < strategy::input_rows; i++)
- {
- for (unsigned int j = 0;
- j < (16 / sizeof(TInput)) * strategy::input_col_quads; j++)
- {
- rearranged_input[i][j] = pad_value;
- }
- inptrs[i] = rearranged_input[i];
- }
-
- auto inptr_row = inptr_batch + in_c +
- (start_in_i + pad_top) * ld_input_row +
- (start_in_j + pad_left) * ld_input_col;
- if (ld_input_col == 1 && !pad_left &&
- start_in_j + (16 / sizeof(TInput)) * strategy::input_col_quads < input_width)
- {
- // The input tensor is already in NCHW format, and we're reading
- // an unpadded section of it - allow the kernel to read it
- // directly.
- for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
- {
- inptrs[i] = inptr_row;
- inptr_row += ld_input_row;
- }
- }
- else
- {
- // Either the input tensor isn't in NCHW format, or we're reading
- // a padded section. Copy the relevant portion of the input here
- // and allow the kernel to read this.
- for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
- {
- auto inptr_col = inptr_row;
- for (unsigned int j = pad_left; j < strategy::input_cols - pad_right; j++)
- {
- rearranged_input[i][j] = *inptr_col;
- inptr_col += ld_input_col;
- }
- inptr_row += ld_input_row;
- }
- }
-
- execute_tile(inptrs, outptr_array, params);
-
- // Progress the output pointers
- TOutput **outptr_pos = outptr_array;
- for (auto i = 0u; i < strategy::output_rows * strategy::output_cols; i++)
- {
- outptr_pos[i] += args.channel_multiplier;
- }
-
- // Progress the pointer into the parameters
- params += param_stride;
- }
+ y = pos % args.kernel_cols;
+ x = pos / args.kernel_cols;
+ return true;
}
+ return false;
}
- }
+ );
}
-}
-template <class strategy>
-class DepthwiseDepthfirstWithMultiplier :
- public DepthwiseCommon<typename strategy::input_type,
- typename strategy::weight_type,
- typename strategy::return_type>
-{
- using TInput = typename strategy::input_type;
- using TWeight = typename strategy::weight_type;
- using TOutput = typename strategy::return_type;
- using TAccum = typename strategy::bias_type;
+ public:
+ using Parent::Parent;
+
+ size_t get_storage_size(const DepthwiseArgs &args) const override
+ {
+ return interleaves::get_storage_size_generic(this->get_packing_args(args), args);
+ }
- size_t sizeof_output_buffer(unsigned int n_channels) const
+ void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const Nothing &, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
{
- const unsigned int vl = arm_gemm::utils::get_vector_length<TOutput>(strategy::vl_type);
- const auto rounded_channels = arm_gemm::roundup(n_channels, vl);
- return sizeof(TOutput) * rounded_channels;
+ interleaves::pack_parameters_generic(
+ this->get_packing_args(args), args,
+ buffer, biases, weights, ld_weight_col, ld_weight_row
+ );
}
+ using KernelType = std::function<void(
+ const TInput *const *, // Input pointers
+ TOutput *const *, // Output pointers
+ const void *, // Ravelled bias, weights, and quantization parameters
+ unsigned int, // # output channels
+ TAccum, TAccum // Min and max activation clamps
+ )>;
+ virtual KernelType get_kernel(void) const = 0;
+};
+
+
+template <typename TInput, typename TWeight, typename TOutput>
+class DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t> : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+ using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>;
+
public:
- DepthwiseDepthfirstWithMultiplier(const DepthwiseArgs &args) : DepthwiseCommon<TInput, TWeight, TOutput>(args)
+ using Parent::Parent;
+
+ size_t get_storage_size(const DepthwiseArgs &args) const override
{
+ return interleaves::quantized::get_storage_size(args, this->get_vl_type(), this->get_accumulator_depth_vl());
}
- DepthwiseDepthfirstWithMultiplier(DepthwiseDepthfirstWithMultiplier &) = delete;
- DepthwiseDepthfirstWithMultiplier &operator=(DepthwiseDepthfirstWithMultiplier &) = delete;
+ void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const arm_gemm::Requantize32 &qp, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
+ {
+ interleaves::quantized::pack_parameters<TWeight>(
+ buffer, reinterpret_cast<const int32_t *>(biases),
+ reinterpret_cast<const TWeight *>(weights), ld_weight_col, ld_weight_row,
+ args, qp, this->get_vl_type(), this->get_accumulator_depth_vl()
+ );
+ }
- size_t get_storage_size(void) const override
+ using KernelType = std::function<void(
+ const TInput *const *, // Input pointers
+ TOutput *const *, // Output pointers
+ const void *, // Ravelled bias, weights, and quantization parameters
+ unsigned int, // # output channels
+ const arm_gemm::Requantize32 &
+ )>;
+ virtual KernelType get_kernel(void) const = 0;
+};
+
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+class GenericDepthfirstMultiplierKernelStrategy
+{
+ const arm_gemm::VLType m_vl_type;
+ const unsigned int m_output_rows, m_output_cols;
+
+ public:
+ GenericDepthfirstMultiplierKernelStrategy(unsigned int output_rows, unsigned int output_cols, arm_gemm::VLType vl_type)
+ : m_vl_type(vl_type), m_output_rows(output_rows), m_output_cols(output_cols)
{
- // TODO What if we insert extra padding? Biases are a different size to the inputs, ...
- const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(strategy::vl_type);
- const auto rounded_channels = this->m_args.input_channels * arm_gemm::roundup(this->m_args.channel_multiplier, vl);
- return (1 + this->m_args.kernel_rows * this->m_args.kernel_cols) * rounded_channels * sizeof(TWeight);
}
- void pack_parameters(void *_buffer, const void *_biases, const void *_weights, size_t ld_weight_col, size_t ld_weight_row) override
+ virtual ~GenericDepthfirstMultiplierKernelStrategy() = default;
+
+ arm_gemm::VLType get_vl_type(void) const { return m_vl_type; }
+ unsigned int get_output_rows(void) const { return m_output_rows; }
+ unsigned int get_output_cols(void) const { return m_output_cols; }
+
+ using KernelType = std::function<void(
+ const TInput *const *, // Input pointers
+ TOutput *const *, // Output pointers
+ const TWeight *, // Ravelled weight parameters
+ const TAccum *, // Bias,
+ unsigned int, unsigned int, // Number of kernel points, number of output channels
+ TAccum, TAccum // Activation minimum and maximum
+ )>;
+ virtual KernelType get_kernel(void) const = 0;
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+class GenericDepthfirstMultiplierKernelStrategy<TInput, TWeight, TOutput, int32_t>
+{
+ const arm_gemm::VLType m_vl_type;
+ const unsigned int m_output_rows, m_output_cols;
+
+ public:
+ GenericDepthfirstMultiplierKernelStrategy(unsigned int output_rows, unsigned int output_cols, arm_gemm::VLType vl_type)
+ : m_vl_type(vl_type), m_output_rows(output_rows), m_output_cols(output_cols)
{
- // TODO What if the kernel needs a different packing function?
+ }
- // Cast the pointers
- float *buffer = static_cast<float *>(_buffer);
- const float *biases = static_cast<const float *>(_biases);
- const float *const weights = static_cast<const float *>(_weights);
+ virtual ~GenericDepthfirstMultiplierKernelStrategy() = default;
+
+ arm_gemm::VLType get_vl_type(void) const { return m_vl_type; }
+ unsigned int get_output_rows(void) const { return m_output_rows; }
+ unsigned int get_output_cols(void) const { return m_output_cols; }
+
+ using KernelType = std::function<void(
+ const TInput *const *, // Input pointers
+ TOutput *const *, // Output pointers
+ const TWeight *, // Ravelled weight parameters
+ const int32_t *, // Bias,
+ unsigned int, unsigned int, // Number of kernel points, number of output channels
+ const int32_t *, const int32_t *, const int32_t *, // Per-channel left-shifts, multipliers, right-shifts (need to account for start channel)
+ const arm_gemm::Requantize32 &
+ )>;
+ virtual KernelType get_kernel(void) const = 0;
+};
- const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(strategy::vl_type);
- ld_weight_col = (ld_weight_col == 0) ? this->m_args.channel_multiplier * this->m_args.input_channels : ld_weight_col;
- ld_weight_row = (ld_weight_row == 0) ? this->m_args.kernel_cols * ld_weight_col : ld_weight_row;
+template <typename TInput,
+ typename TWeight=TInput,
+ typename TOutput=TInput,
+ typename TAccum=typename DefaultTAccum<TInput>::Type,
+ typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class GenericDepthfirstMultiplierStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+ using KernelStrategyType = GenericDepthfirstMultiplierKernelStrategy<TInput, TWeight, TOutput, TAccum>;
+ std::unique_ptr<KernelStrategyType> m_kern;
- for (unsigned int in_c = 0; in_c < this->m_args.input_channels; in_c++)
- {
- for (unsigned int n = 0; n < this->m_args.channel_multiplier; n += vl)
+ protected:
+ virtual interleaves::PackingArguments get_packing_args(const DepthwiseArgs &args) const
+ {
+ return interleaves::PackingArguments(
+ args.kernel_rows, args.kernel_cols, sizeof(TWeight),
+ false, sizeof(TAccum),
+ this->get_vl_type(),
+ sizeof(TAccum), 1,
+ [args] (unsigned int pos, unsigned int &x, unsigned int &y) -> bool
{
- const unsigned int out_c = in_c * this->m_args.channel_multiplier + n;
- const unsigned int todo = std::min(vl, this->m_args.channel_multiplier - n);
-
- // Copy across the correct amount of bias (or 0)
- for (unsigned int i = 0; i < todo; i++)
+ if (pos < args.kernel_rows * args.kernel_cols)
{
- buffer[i] = (biases == nullptr) ? 0 : biases[out_c + i];
+ y = pos % args.kernel_cols;
+ x = pos / args.kernel_cols;
+ return true;
}
- buffer += vl;
+ return false;
+ }
+ );
+ }
- // Copy each of the weights in turn
- auto weights_row = weights + out_c;
- for (unsigned int i = 0; i < this->m_args.kernel_rows; i++)
- {
- auto weights_col = weights_row;
+ public:
+ GenericDepthfirstMultiplierStrategy(KernelStrategyType *kern, const DepthwiseArgs &args)
+ : DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>(
+ kern->get_output_rows(), kern->get_output_cols(),
+ args.kernel_rows, args.kernel_cols,
+ args.stride_rows, args.stride_cols
+ ),
+ m_kern(kern)
+ {
+ };
- for (unsigned int j = 0; j < this->m_args.kernel_cols; j++)
- {
- for (unsigned int m = 0; m < todo; m++)
- {
- buffer[m] = weights_col[m];
- }
- buffer += vl;
+ arm_gemm::VLType get_vl_type(void) const override { return m_kern->get_vl_type(); }
+ const typename KernelStrategyType::KernelType get_kernel(void) const { return m_kern->get_kernel(); }
- weights_col += ld_weight_col;
- }
+ size_t get_storage_size(const DepthwiseArgs &args) const override
+ {
+ return interleaves::get_storage_size_generic(this->get_packing_args(args), args);
+ }
- weights_row += ld_weight_row;
- }
- }
- }
+ void pack_parameters(const DepthwiseArgs &args, void *buffer, const void *biases, const OutputStage &, const void *weights, size_t ld_weight_col, size_t ld_weight_row) const override
+ {
+ interleaves::pack_parameters_generic(
+ this->get_packing_args(args), args,
+ buffer, biases, weights, ld_weight_col, ld_weight_row
+ );
}
+};
+
+// Specialise elements of the wrapper based on the type of kernel.
+namespace depthfirst_multiplier {
- size_t get_working_size(const unsigned int n_threads, const unsigned int n_channels) const override
+/* Working space element which contains a pointer for each row of input, a row
+ * of padding, and a space which can be used to construct an NCHW-ordered patch
+ * of input.
+ */
+template <typename T, bool IsGeneric=false, typename OutputStage=Nothing>
+class InputPatchElement
+{
+ public:
+ struct Workspace
{
- const unsigned int n_output_channels = n_channels * this->m_args.channel_multiplier;
- return n_threads * sizeof_output_buffer(n_output_channels);
+ constexpr static bool InputPatchIsGeneric = IsGeneric;
+ const T **input_rows;
+ T *input_padding;
+ T *input_patch;
+ };
+
+ static size_t get_element_size(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+ {
+ return sizeof_input_rows(args) + sizeof_input_padding(args) + sizeof_input_patch(args);
}
-
- using DepthwiseCommon<typename strategy::input_type, typename strategy::weight_type, typename strategy::return_type>::execute;
- void execute(
- const unsigned int batches,
- const unsigned int input_height,
- const unsigned int input_width,
- const unsigned int input_channels,
- const PaddingValues &padding,
- const void *const _input,
- const size_t ld_input_col,
- const size_t ld_input_row,
- const size_t ld_input_batch,
- const void *const parameters,
- const unsigned int output_height,
- const unsigned int output_width,
- void *const _output,
- const size_t ld_output_col,
- const size_t ld_output_row,
- const size_t ld_output_batch,
- void *const _working_space,
- const unsigned int thread_id,
- const unsigned int n_threads
- ) const override
+
+ template <class WorkspaceType>
+ static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
{
- strategy strat(this->m_args.cpu_info);
-#ifdef CYCLE_PROFILING
- arm_gemm::profiler prof;
-#endif
+ auto buffer_bytes = reinterpret_cast<char *>(buffer);
+
+ ws->input_rows = reinterpret_cast<const T **>(buffer_bytes);
+ buffer_bytes += sizeof_input_rows(args);
+
+ ws->input_padding = reinterpret_cast<T*>(buffer_bytes);
+ buffer_bytes += sizeof_input_padding(args);
+
+ ws->input_patch = reinterpret_cast<T*>(buffer_bytes);
+ buffer_bytes += sizeof_input_patch(args);
- // Compute activation values
- TAccum activation_min = std::numeric_limits<TAccum>::has_infinity ? -std::numeric_limits<TAccum>::infinity() : std::numeric_limits<TAccum>::min();
- TAccum activation_max = std::numeric_limits<TAccum>::has_infinity ? std::numeric_limits<TAccum>::infinity() : std::numeric_limits<TAccum>::max();
+ // Initialise the padding
+ memset(ws->input_padding,
+ get_input_buffer_fill_value(args.output_stage),
+ sizeof_input_padding(args));
- switch (this->m_args.activation.type)
+ return buffer_bytes;
+ }
+
+ protected:
+ static size_t sizeof_input_rows(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+ {
+ if (IsGeneric)
+ {
+ return sizeof(T *) * args.strategy->get_output_rows() * args.depthwise_args.kernel_rows * args.depthwise_args.kernel_cols;
+ }
+ else
{
- case arm_gemm::Activation::Type::BoundedReLU:
- activation_max = static_cast<TAccum>(this->m_args.activation.param1);
- // Fall through
- case arm_gemm::Activation::Type::ReLU:
- activation_min = static_cast<TAccum>(0);
- break;
- default:
- break;
+ return sizeof(T *) * args.strategy->get_input_rows();
}
+ }
+
+ static size_t sizeof_input_padding(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+ {
+ // Round-up the number of columns to be a whole number of QUADS
+ auto input_cols = arm_gemm::roundup<size_t>(args.strategy->get_input_cols(), 16 / sizeof(T));
+ return sizeof(T) * input_cols;
+ }
- // Determine what portion of the work to do.
- const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
- const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
- const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
-
- // Need a stride over blocks of parameters
- const unsigned int vl = arm_gemm::utils::get_vector_length<TOutput>(strategy::vl_type);
- const unsigned int param_stride =
- arm_gemm::roundup(this->m_args.channel_multiplier, vl) *
- (sizeof(TAccum) + sizeof(TWeight) * strategy::kernel_rows * strategy::kernel_cols);
-
- // Cast input and output pointers into the right types
- const TInput *const inptr = static_cast<const TInput *>(_input);
- TOutput *const outptr = static_cast<TOutput *>(_output);
-
- // To simplify the kernel, we process padded or non-NCHW-ordered input into
- // a form which can be consumed by the kernel. This data is stored here and
- // passed into the kernel as an array of N pointers (one per row of the
- // input).
- TInput rearranged_input[strategy::input_rows][strategy::input_col_quads*4];
- const TInput *inptrs[strategy::input_rows];
-
- // Create an array for the output pointers
- TOutput * _outptr_array[strategy::output_rows * strategy::output_cols];
- TOutput **const outptr_array = _outptr_array;
-
- // Allocate portions of the working space
- uint8_t *const working_space = static_cast<uint8_t *>(_working_space) + get_working_size(thread_id, input_channels);
- TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space);
-
- // For each output tile, construct the requisite set of pointers and call
- // into the kernel.
- for (unsigned int batch = 0; batch < batches; batch++)
+ static size_t sizeof_input_patch(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+ {
+ if (IsGeneric)
+ {
+ // Round-up the number of columns to be a whole number of QUADS
+ auto output_cols = arm_gemm::roundup<size_t>(args.strategy->get_output_cols(), 16 / sizeof(T));
+ const auto kernel_points = args.depthwise_args.kernel_rows * args.depthwise_args.kernel_cols;
+ return sizeof(T) * kernel_points * args.strategy->get_output_rows() * output_cols;
+ }
+ else
{
- // Get batch pointers
- const auto inptr_batch = inptr + batch * ld_input_batch;
- const auto outptr_batch = outptr + batch * ld_output_batch;
+ // Round-up the number of columns to be a whole number of QUADS
+ auto input_cols = arm_gemm::roundup<size_t>(args.strategy->get_input_cols(), 16 / sizeof(T));
+ return sizeof(T) * args.strategy->get_input_rows() * input_cols;
+ }
+ }
+};
+
+template <bool IsGeneric, typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+struct StrategyType
+{
+ using Type = DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, TAccum>;
+
+ template <typename WorkspaceType>
+ static void execute(
+ const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+ const OutputStage &, const unsigned int,
+ const void *parameters, const void *
+ )
+ {
+ strat->get_kernel()(
+ ws->input_rows,
+ ws->outptr_array,
+ parameters, args.channel_multiplier,
+ ws->activation_min, ws->activation_max
+ );
+ }
+};
- for (int start_out_i = start_out_height;
- start_out_i < end_out_height;
- start_out_i += static_cast<int>(strategy::output_rows))
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+struct StrategyType<true, TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+ using Type = GenericDepthfirstMultiplierStrategy<TInput, TWeight, TOutput, TAccum, OutputStage>;
+
+ template <typename WorkspaceType>
+ static void execute(
+ const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+ const OutputStage &, const unsigned int start_output_channel,
+ const void *parameters, const void *bias
+ )
+ {
+ strat->get_kernel()(
+ ws->input_rows, ws->outptr_array,
+ reinterpret_cast<const TWeight *>(parameters),
+ bias == nullptr ? nullptr : reinterpret_cast<const TAccum *>(bias) + start_output_channel,
+ strat->get_kernel_rows() * strat->get_kernel_cols(),
+ args.channel_multiplier,
+ ws->activation_min, ws->activation_max
+ );
+ }
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+struct StrategyType<false, TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+ using Type = DepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t>;
+
+ template <typename WorkspaceType>
+ static void execute(
+ const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+ const arm_gemm::Requantize32 &qp, const unsigned int,
+ const void *parameters, const void *
+ )
+ {
+ strat->get_kernel()(
+ ws->input_rows,
+ ws->outptr_array,
+ parameters, args.channel_multiplier,
+ qp
+ );
+ }
+};
+
+template <typename TInput, typename TWeight, typename TOutput>
+struct StrategyType<true, TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
+{
+ using Type = GenericDepthfirstMultiplierStrategy<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>;
+
+ template <typename WorkspaceType>
+ static void execute(
+ const DepthwiseArgs &args, const WorkspaceType *ws, const Type *strat,
+ const arm_gemm::Requantize32 &qp, const unsigned int start_output_channel,
+ const void *parameters, const void *
+ )
+ {
+ auto get_ptr = [start_output_channel] (const int32_t *ptr) -> const int32_t *
+ {
+ return ptr == nullptr ? nullptr : ptr + start_output_channel;
+ };
+
+ strat->get_kernel()(
+ ws->input_rows, ws->outptr_array,
+ reinterpret_cast<const TWeight *>(parameters),
+ get_ptr(qp.bias),
+ strat->get_kernel_rows() * strat->get_kernel_cols(),
+ args.channel_multiplier,
+ get_ptr(qp.per_channel_left_shifts),
+ get_ptr(qp.per_channel_muls),
+ get_ptr(qp.per_channel_right_shifts),
+ qp
+ );
+ }
+};
+
+template <bool IsGeneric> struct PrepareInputSample;
+
+template <> struct PrepareInputSample<false>
+{
+ template <typename WorkspaceType, typename StrategyType, typename T>
+ static void execute(
+ const DepthwiseArgs &, WorkspaceType *ws, const StrategyType *strat,
+ T *base_ptr, size_t ld_row, size_t ld_col,
+ const unsigned int input_pad_top, const unsigned int valid_rows,
+ const unsigned int input_pad_left, const unsigned int valid_cols
+ )
+ {
+ fill_nchw_patch_array(
+ ws->input_rows, ws->input_patch, strat->get_input_rows(), strat->get_input_cols(),
+ base_ptr, ld_row, ld_col,
+ ws->input_padding,
+ input_pad_top, valid_rows,
+ input_pad_left, valid_cols
+ );
+ }
+};
+
+template <> struct PrepareInputSample<true>
+{
+ template <typename WorkspaceType, typename StrategyType, typename T>
+ static void execute(
+ const DepthwiseArgs &args, WorkspaceType *ws, const StrategyType *strat,
+ T *base_ptr, size_t ld_row, size_t ld_col,
+ const unsigned int input_pad_top, const unsigned int valid_rows,
+ const unsigned int input_pad_left, const unsigned int valid_cols
+ )
+ {
+ fill_patch_array_generic_kernel(
+ ws->input_rows, ws->input_patch,
+ strat->get_output_rows(), strat->get_output_cols(),
+ args.kernel_rows, args.kernel_cols,
+ args.stride_rows, args.stride_cols,
+ base_ptr, ld_row, ld_col,
+ ws->input_padding,
+ input_pad_top, valid_rows,
+ input_pad_left, valid_cols
+ );
+ }
+};
+
+} // namespace depthfirst_multiplier
+
+template <typename TInput,
+ typename TWeight=TInput,
+ typename TOutput=TInput,
+ typename TAccum=typename DefaultTAccum<TInput>::Type,
+ bool is_generic=false,
+ typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirstMultiplier : public DepthfirstDriver<TInput, TWeight, TOutput>
+{
+ protected:
+ using StratType = typename depthfirst_multiplier::StrategyType<is_generic, TInput, TWeight, TOutput, TAccum, OutputStage>::Type;
+ using WorkspaceManager = Workspace<
+ OutputArrayElement<TOutput>,
+ depthfirst_multiplier::InputPatchElement<TInput, is_generic, OutputStage>,
+ ActivationsElement<TOutput, OutputStage>
+ >;
+ using WorkingSpace = typename WorkspaceManager::WorkspaceType;
+
+ OutputStage m_os; // Copy of the output parameters
+ const void *m_bias = nullptr; // Copy of the bias (should we need it)
+
+ public:
+ DepthwiseDepthfirstMultiplier(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os = {})
+ : DepthfirstDriver<TInput, TWeight, TOutput>(strat, args), m_os(os)
+ {
+ }
+
+ DepthwiseDepthfirstMultiplier(DepthwiseDepthfirstMultiplier &) = delete;
+ DepthwiseDepthfirstMultiplier &operator=(DepthwiseDepthfirstMultiplier &) = delete;
+
+ size_t get_storage_size(void) const override
+ {
+ return reinterpret_cast<const StratType *>(this->m_strat.get())
+ ->get_storage_size(this->m_args);
+ }
+
+ void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
+ {
+ reinterpret_cast<const StratType *>(this->m_strat.get())
+ ->pack_parameters(this->m_args, buffer, biases, m_os, weights, ld_weight_col, ld_weight_row);
+ m_bias = biases;
+ depthwise_depthfirst::stash_bias(m_os, biases);
+ }
+
+ size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
+ {
+ DepthwiseArgs args(this->m_args);
+ args.input_channels = n_input_channels;
+ return WorkspaceManager::get_sizeof_workspace(WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, m_os));
+ }
+
+ void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
+ {
+ DepthwiseArgs args(this->m_args);
+ args.input_channels = n_input_channels;
+ return WorkspaceManager::initialise(buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, m_os));
+ }
+
+ void compute_tile_padded(
+ unsigned int output_i, unsigned int output_j,
+ unsigned int output_channel_start, unsigned int output_channel_end,
+ const TensorSpec<const TInput *> &input,
+ const TensorSpec<TOutput *> &output,
+ const void *parameters,
+ void *working_space_raw
+ ) const override
+ {
+ // Get the working space
+ auto ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
+
+ const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+ const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
+ const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
+
+ const int ij = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;
+ const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);
+ const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);
+
+ // Compute the output pointer array. We'll update this array after every
+ // invocation of the kernel.
+ fill_pointer_array(
+ ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+ output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+ output.ld_row, output.ld_col,
+ ws->output_buffer,
+ 0, this->m_args.output_rows - output_i, // Top padding, # valid rows
+ 0, this->m_args.output_cols - output_j // Left padding, # valid columns
+ );
+
+ // Compute the parameter stride
+ DepthwiseArgs single_iter(this->m_args);
+ single_iter.input_channels = 1;
+ const size_t parameter_stride = reinterpret_cast<const StratType *>(this->m_strat.get())
+ ->get_storage_size(single_iter);
+
+ for (; output_channel_start < output_channel_end;
+ output_channel_start += this->m_args.channel_multiplier)
+ {
+ // Compute the input pointer array
+ const auto input_channel = output_channel_start / this->m_args.channel_multiplier;
+
+ // Construct the input patch
+ depthfirst_multiplier::PrepareInputSample<is_generic>::execute(
+ this->m_args, ws, this->m_strat.get(),
+ input.base + input_channel + input_i*input.ld_row + input_j*input.ld_col, input.ld_row, input.ld_col,
+ input_pad_top, this->m_args.input_rows - input_i,
+ input_pad_left, this->m_args.input_cols - input_j
+ );
+
+ // Execute the kernel
+ depthfirst_multiplier::StrategyType<is_generic, TInput, TWeight, TOutput, TAccum, OutputStage>::execute(
+ this->m_args, ws, reinterpret_cast<const StratType *>(this->m_strat.get()), m_os, output_channel_start,
+ parameters, m_bias
+ );
+
+ // Update the output pointers
+ for (unsigned int n = 0; n < this->m_strat->get_output_rows() * this->m_strat->get_output_cols(); n++)
{
- const int end_out_i = start_out_i + strategy::output_rows;
- const int start_in_i = start_out_i * strategy::stride_rows - padding.top;
- const int end_in_i = start_in_i + strategy::input_rows;
-
- // Compute top/bottom padding
- const auto pad_top = static_cast<unsigned int>(-std::min(start_in_i, 0));
- const auto pad_bottom = static_cast<unsigned int>(-std::min(static_cast<int>(input_height) - end_in_i, 0));
- const unsigned int valid_output_rows = std::min(
- end_out_i - start_out_i,
- static_cast<int>(output_height) - start_out_i
- );
-
- for (int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
- {
- const int start_in_j = start_out_j * strategy::stride_cols - this->m_args.padding.left;
- const int pad_left = -std::min(0, start_in_j);
-
- const int end_out_j = start_out_j + strategy::output_cols;
- const int end_in_j = start_in_j + strategy::input_cols;
-
- const auto pad_right = static_cast<unsigned int>(-std::min(static_cast<int>(input_width) - end_in_j, 0));
- const unsigned int valid_output_cols = std::min(
- end_out_j - start_out_j,
- static_cast<int>(output_width) - start_out_j
- );
-
- // Construct the output pointer array.
- TOutput **outptr_pos = outptr_array;
- for (auto i = 0u; i < valid_output_rows; i++)
- {
- unsigned int j = 0u;
- TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
- for (; j < valid_output_cols; j++)
- {
- *(outptr_pos++) = colptr;
- colptr += ld_output_col;
- }
- for (; j < strategy::output_cols; j++)
- {
- *(outptr_pos++) = output_buffer;
- }
- }
- for (auto i = valid_output_rows; i < strategy::output_rows; i++)
- {
- for (auto j = 0u; j < strategy::output_cols; j++)
- {
- *(outptr_pos++) = output_buffer;
- }
- }
-
- start_out_j += strategy::output_cols;
-
- const uint8_t *params = static_cast<const uint8_t *>(parameters);
-
- // Loop over the input channels
- for (unsigned int in_c = 0; in_c < input_channels; in_c++)
- {
- // Construct the input array - first fill with padding values and
- // then fill in correct values.
- for (unsigned int i = 0; i < strategy::input_rows; i++)
- {
- for (unsigned int j = 0; j < 4 * strategy::input_col_quads; j++)
- {
- rearranged_input[i][j] = static_cast<TInput>(0);
- }
- inptrs[i] = rearranged_input[i];
- }
-
- auto inptr_row = inptr_batch + in_c +
- (start_in_i + pad_top) * ld_input_row +
- (start_in_j + pad_left) * ld_input_col;
- if (ld_input_col == 1 && !pad_left &&
- start_in_j + 4 * strategy::input_col_quads < input_width)
- {
- // The input tensor is already in NCHW format, and we're reading
- // an unpadded section of it - allow the kernel to read it
- // directly.
- for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
- {
- inptrs[i] = inptr_row;
- inptr_row += ld_input_row;
- }
- }
- else
- {
- // Either the input tensor isn't in NCHW format, or we're reading
- // a padded section. Copy the relevant portion of the input here
- // and allow the kernel to read this.
- for (unsigned int i = pad_top; i < strategy::input_rows - pad_bottom; i++)
- {
- auto inptr_col = inptr_row;
- for (unsigned int j = pad_left; j < strategy::input_cols - pad_right; j++)
- {
- rearranged_input[i][j] = *inptr_col;
- inptr_col += ld_input_col;
- }
- inptr_row += ld_input_row;
- }
- }
-
- {
-#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(strategy::output_rows * strategy::output_cols * this->m_args.channel_multiplier * strategy::kernel_rows * strategy::kernel_cols));
-#endif
- strat.kernel(
- inptrs, outptr_array, params,
- this->m_args.channel_multiplier,
- activation_min, activation_max
- );
- }
-
- // Progress the output pointers
- TOutput **outptr_pos = outptr_array;
- for (auto i = 0u; i < strategy::output_rows * strategy::output_cols; i++)
- {
- outptr_pos[i] += this->m_args.channel_multiplier;
- }
-
- // Progress the pointer into the parameters
- params += param_stride;
- }
- }
+ ws->outptr_array[n] += this->m_args.channel_multiplier;
}
+
+ // Progress the parameters
+ parameters = reinterpret_cast<const char *>(parameters) + parameter_stride;
}
}
};