aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
diff options
context:
space:
mode:
authorramelg01 <ramy.elgammal@arm.com>2022-04-07 02:42:52 +0100
committerRamy Elgammal <ramy.elgammal@arm.com>2022-04-26 15:51:22 +0000
commit8a164884dddf769643cf3b9f7f94e43cb4f3c20b (patch)
tree35958dd48b6df1a851c880dad2b2ce285671b611 /src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
parentc827e99fc46521f43719b0c2d1b6f05d66abf68c (diff)
downloadComputeLibrary-8a164884dddf769643cf3b9f7f94e43cb4f3c20b.tar.gz
Update Neon™ depthwise kernel
- Reduce duplication and simplify overall structure. - Improve multi-threaded performance by sharing more data in lower-level caches. Partially Resolves: COMPMID-5054 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Change-Id: Iac747f39b21c540122fa75218762631c4d787911 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7449 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Andrew Mundy Reviewed-by: Sheri Zhang <sheri.zhang@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp')
-rw-r--r--src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp736
1 files changed, 468 insertions, 268 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
index 57fa11151b..6905076357 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,7 +24,9 @@
#pragma once
-#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
+#include "src/core/NEON/kernels/arm_conv/addressing.hpp"
+#include "depthwise_strategies_common.hpp"
+#include "working_space.hpp"
#ifdef CYCLE_PROFILING
#include "profiler.hpp"
@@ -35,349 +37,547 @@
namespace arm_conv {
namespace depthwise {
-struct IDepthwiseDepthfirstStrategy
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum,
+ typename OutputStage>
+class DepthwiseDepthfirstStrategyCommon
+ : public DepthfirstStrategy<TInput, TWeight, TOutput, TAccum, OutputStage>
{
- virtual arm_gemm::VLType get_vl_type() const = 0;
+ protected:
+ unsigned int m_output_rows, m_output_cols;
+ unsigned int m_kernel_rows, m_kernel_cols;
+ unsigned int m_stride_rows, m_stride_cols;
- virtual unsigned int get_input_rows() const = 0;
- virtual unsigned int get_input_cols() const = 0;
+ public:
+ DepthwiseDepthfirstStrategyCommon(
+ unsigned int output_rows, unsigned int output_cols,
+ unsigned int kernel_rows, unsigned int kernel_cols,
+ unsigned int stride_rows=1, unsigned int stride_cols=1
+ ) : m_output_rows(output_rows), m_output_cols(output_cols),
+ m_kernel_rows(kernel_rows), m_kernel_cols(kernel_cols),
+ m_stride_rows(stride_rows), m_stride_cols(stride_cols)
+ {
+ }
+
+ DepthwiseDepthfirstStrategyCommon(unsigned int output_size, unsigned int kernel_size, unsigned int stride=1)
+ : DepthwiseDepthfirstStrategyCommon(output_size, output_size, kernel_size, kernel_size, stride, stride)
+ {
+ }
+
+ virtual ~DepthwiseDepthfirstStrategyCommon() {}
+
+ unsigned int get_output_rows() const override { return m_output_rows; }
+ unsigned int get_output_cols() const override { return m_output_cols; }
- virtual unsigned int get_output_rows() const = 0;
- virtual unsigned int get_output_cols() const = 0;
+ unsigned int get_kernel_rows() const override { return m_kernel_rows; }
+ unsigned int get_kernel_cols() const override { return m_kernel_cols; }
- virtual unsigned int get_kernel_rows() const = 0;
- virtual unsigned int get_kernel_cols() const = 0;
+ unsigned int get_stride_rows() const override { return m_stride_rows; }
+ unsigned int get_stride_cols() const override { return m_stride_cols; }
+};
+
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirstStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+ using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>;
- virtual unsigned int get_stride_rows() const = 0;
- virtual unsigned int get_stride_cols() const = 0;
+ public:
+ using Parent::Parent;
- virtual void indirect_kernel(
- const void *const *const input_ptrs,
- void *const *const output_ptrs,
+ typedef void (*IndirectKernelType)(
+ const TInput *const *input_ptrs,
+ TOutput *const *output_ptrs,
const void *params,
unsigned int n_channels,
- const void *activation_min,
- const void *activation_max
- ) const = 0;
+ const TAccum activation_min,
+ const TAccum activation_max
+ );
+ virtual IndirectKernelType get_indirect_kernel(void) const = 0;
- virtual void direct_kernel(
+ typedef void (*DirectKernelType)(
const unsigned int n_tile_rows, const unsigned int n_tile_cols,
- const void *inptr, int64_t ld_input_row, int64_t ld_input_col,
- void *outptr, int64_t ld_output_row, int64_t ld_output_col,
+ const TInput *inptr_base, int64_t ld_input_row, int64_t ld_input_col,
+ TOutput *outptr_base, int64_t ld_output_row, int64_t ld_output_col,
const void *params, unsigned int n_channels,
- const void *activation_min,
- const void *activation_max
- ) const = 0;
-
- virtual ~IDepthwiseDepthfirstStrategy() {}
+ const TAccum activation_min,
+ const TAccum activation_max
+ );
+ virtual DirectKernelType get_direct_kernel(void) const = 0;
};
-template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
-class DepthwiseDepthfirst : public DepthwiseCommon<TInput, TWeight, TOutput>
+template <typename TInput, typename TWeight, typename TOutput>
+class DepthwiseDepthfirstStrategy<TInput, TWeight, TOutput, int32_t>
+: public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>
{
- const std::unique_ptr<IDepthwiseDepthfirstStrategy> m_strat;
+ using Parent = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, int32_t, arm_gemm::Requantize32>;
- size_t sizeof_inptr_array(void) const
+ protected:
+ interleaves::PackingArguments get_packing_args(void) const
{
- return sizeof(TInput *) * m_strat->get_input_rows() * m_strat->get_input_cols();
+ return interleaves::PackingArguments(
+ this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
+ false, sizeof(int32_t), // Don't pack the bias
+ this->get_vl_type(), sizeof(int32_t), this->get_accumulator_depth_vl(),
+ [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
+ { return this->get_kernel_packing_point(idx, x, y); }
+ );
}
- size_t sizeof_input_buffer(unsigned int n_input_channels) const
+ public:
+ using Parent::Parent;
+
+ typedef void (*KernelType)(
+ unsigned int, // n_channels,
+ const TInput *const *, // inptrs
+ const TWeight *, // weights
+ const int32_t *, // bias,
+ const arm_gemm::Requantize32 &,
+ const int32_t *, const int32_t *, // requant_muls and requant_shifts
+ TOutput *const * // outptrs
+ );
+ virtual KernelType get_kernel() const = 0;
+
+ size_t get_storage_size(const DepthwiseArgs &args) const override
{
- return sizeof(TInput) * n_input_channels;
+ return interleaves::get_storage_size_generic(get_packing_args(), args);
}
- size_t sizeof_outptr_array(void) const
+ void pack_parameters(
+ const DepthwiseArgs &args, void *buffer,
+ const void *biases, const arm_gemm::Requantize32 &,
+ const void *weights, size_t ld_weight_col, size_t ld_weight_row
+ ) const override
{
- return sizeof(TInput *) * m_strat->get_output_rows() * m_strat->get_output_cols();
+ interleaves::pack_parameters_generic(
+ get_packing_args(), args, buffer, biases, weights, ld_weight_col, ld_weight_row);
}
+};
- size_t sizeof_output_buffer(unsigned int n_output_channels) const
- {
- return sizeof(TOutput) * n_output_channels;
- }
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+class DepthwiseDepthfirstCommon : public DepthfirstDriver<TInput, TWeight, TOutput>
+{
+ using StratType = DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>;
+ OutputStage m_os;
+
+ protected:
+ inline OutputStage &get_output_stage(void) { return m_os; }
+ inline const OutputStage &get_output_stage(void) const { return m_os; }
public:
- DepthwiseDepthfirst(
- IDepthwiseDepthfirstStrategy *const strat,
- const DepthwiseArgs &args
- ) : DepthwiseCommon<TInput, TWeight, TOutput>(args), m_strat(strat)
+ DepthwiseDepthfirstCommon(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os)
+ : DepthfirstDriver<TInput, TWeight, TOutput>(strat, args), m_os(os)
{
}
- DepthwiseDepthfirst(DepthwiseDepthfirst &) = delete;
- DepthwiseDepthfirst &operator=(DepthwiseDepthfirst &) = delete;
+ DepthwiseDepthfirstCommon(DepthwiseDepthfirstCommon &) = delete;
+ DepthwiseDepthfirstCommon &operator=(DepthwiseDepthfirstCommon &) = delete;
size_t get_storage_size(void) const override
{
- // TODO What if we insert extra padding? Biases are a different size to the inputs, ...
- const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(m_strat->get_vl_type());
- const auto rounded_channels = arm_gemm::roundup(this->m_args.input_channels, vl);
- return (1 + this->m_args.kernel_rows * this->m_args.kernel_cols) * rounded_channels * sizeof(TWeight);
+ return reinterpret_cast<const StratType *>(this->m_strat.get())->
+ get_storage_size(this->m_args);
}
- void pack_parameters(void *_buffer, const void *_biases, const void *_weights, size_t ld_weight_col, size_t ld_weight_row) override
+ void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
{
- // TODO What if the kernel needs a different packing function?
+ reinterpret_cast<const StratType *>(this->m_strat.get())->
+ pack_parameters(this->m_args, buffer, biases, m_os, weights, ld_weight_col, ld_weight_row);
+ }
+};
- // Cast the pointers
- uint8_t *buffer = static_cast<uint8_t *>(_buffer);
- const TAccum *biases = static_cast<const TAccum *>(_biases);
- const TWeight *const weights = static_cast<const TWeight *>(_weights);
+namespace depthwise_depthfirst {
- const unsigned int vl = arm_gemm::utils::get_vector_length<TAccum>(m_strat->get_vl_type());
- ld_weight_col = (ld_weight_col == 0) ? this->m_args.input_channels : ld_weight_col;
- ld_weight_row = (ld_weight_row == 0) ? this->m_args.kernel_cols * ld_weight_col : ld_weight_row;
+/* Workspace Element for an array of input pointers as consumed by the
+ * specialised depthwise kernels.
+ */
+template <typename T>
+class InputArrayElement
+{
+ public:
+ struct Workspace
+ {
+ const T **inptr_array;
+ };
- for (unsigned int n = 0; n < this->m_args.input_channels; n += vl)
- {
- const unsigned int todo = std::min(vl, this->m_args.input_channels - n);
+ template <class OutputStage>
+ static size_t get_element_size(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+ {
+ return sizeof(T **) * args.strategy->get_input_rows() * args.strategy->get_input_cols();
+ }
- // Copy across the correct amount of bias (or 0)
- for (unsigned int i = 0; i < todo; i++)
- {
- reinterpret_cast<TAccum *>(buffer)[i] = (biases == nullptr) ? 0 : biases[n + i];
- }
- buffer += vl * sizeof(TAccum);
+ template <class WorkspaceType, class OutputStage>
+ static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
+ {
+ ws->inptr_array = reinterpret_cast<const T**>(buffer);
+ return reinterpret_cast<char *>(buffer) + get_element_size(args);
+ }
+};
- // Copy each of the weights in turn
- auto weights_row = weights + n;
- for (unsigned int i = 0; i < this->m_args.kernel_rows; i++)
- {
- auto weights_col = weights_row;
+template <typename TAccum, typename OutputStage, bool IsDot=false>
+struct WorkspaceFinalElement
+{
+ using Element = ActivationsElement<TAccum, OutputStage>;
+};
- for (unsigned int j = 0; j < this->m_args.kernel_cols; j++)
- {
- for (unsigned int m = 0; m < todo; m++)
- {
- reinterpret_cast<TWeight *>(buffer)[m] = weights_col[m];
- }
- buffer += vl * sizeof(TWeight);
+template <>
+struct WorkspaceFinalElement<int32_t, arm_gemm::Requantize32, false>
+{
+ using Element = RequantizationParametersElement;
+};
- weights_col += ld_weight_col;
- }
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum, typename OutputStage>
+struct Invoke
+{
+ constexpr static bool supports_direct_kernel = true;
- weights_row += ld_weight_row;
- }
- }
+ template <typename Strat, typename Workspace>
+ static inline void indirect(const Strat *strat, const Workspace *ws, const OutputStage &, const void *params, const TAccum *, unsigned int n_channels)
+ {
+ strat->get_indirect_kernel()(
+ ws->inptr_array,
+ ws->outptr_array,
+ params, n_channels,
+ ws->activation_min, ws->activation_max
+ );
}
- size_t get_working_size(const unsigned int n_threads, const unsigned int n_channels) const override
+ template <typename Strat, typename Workspace>
+ static void direct(
+ const Strat *strat, const Workspace *ws, const OutputStage &,
+ unsigned int n_tile_rows, unsigned int n_tile_cols,
+ const TInput *inptr, size_t ld_in_row, size_t ld_in_col,
+ TOutput *outptr, size_t ld_out_row, size_t ld_out_col,
+ const void *params, unsigned int n_channels
+ )
{
- const unsigned int n_output_channels = n_channels * this->m_args.channel_multiplier;
- return n_threads * (sizeof_inptr_array() + sizeof_outptr_array() +
- sizeof_output_buffer(n_output_channels) +
- sizeof_input_buffer(n_channels));
+ strat->get_direct_kernel()(
+ n_tile_rows, n_tile_cols,
+ inptr, ld_in_row, ld_in_col,
+ outptr, ld_out_row, ld_out_col,
+ params, n_channels, ws->activation_min, ws->activation_max
+ );
}
+};
- using DepthwiseCommon<TInput, TWeight, TOutput>::execute;
- void execute(
- const unsigned int batches,
- const unsigned int input_height,
- const unsigned int input_width,
- const unsigned int input_channels,
- const PaddingValues &padding,
- const void *const _input,
- const size_t ld_input_col,
- const size_t ld_input_row,
- const size_t ld_input_batch,
- const void *const parameters,
- const unsigned int output_height,
- const unsigned int output_width,
- void *const _output,
- const size_t ld_output_col,
- const size_t ld_output_row,
- const size_t ld_output_batch,
- void *const _working_space,
- const unsigned int thread_id,
- const unsigned int n_threads
- ) const override
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+struct Invoke<TInput, TWeight, TOutput, TAccum, arm_gemm::Requantize32>
+{
+ constexpr static bool supports_direct_kernel = false;
+
+ template <typename Strat, typename Workspace>
+ static inline void indirect(const Strat *strat, const Workspace *ws, const arm_gemm::Requantize32 &qp, const void *params, const TAccum *, unsigned int n_channels)
{
-#ifdef CYCLE_PROFILING
- arm_gemm::profiler prof;
-#endif
+ strat->get_kernel()(
+ n_channels, ws->inptr_array,
+ reinterpret_cast<const TWeight *>(params), ws->bias,
+ qp, ws->requant_muls, ws->requant_shifts,
+ ws->outptr_array
+ );
+ }
+
+ template <typename Strat, typename Workspace>
+ static inline void direct(
+ const Strat *, const Workspace *, const arm_gemm::Requantize32 &,
+ unsigned int, unsigned int, // n_tile_rows, n_tile_cols
+ const TInput *, size_t, size_t, // Input pointer, row stride, column stride
+ TOutput *, size_t, size_t, // Output pointer, row stride, column stride
+ const void *, unsigned int // Parameters, number of channels
+ )
+ {
+ // Do nothing - this should never be reached because entry to it is guarded
+ // by an `if` on a `constexpr static bool`.
+ }
+};
- // Compute activation values
- TAccum activation_min, activation_max;
- std::tie(activation_min, activation_max) = get_default_activation_values<TAccum>();
+namespace
+{
- switch (this->m_args.activation.type)
- {
- case arm_gemm::Activation::Type::BoundedReLU:
- activation_max = static_cast<TAccum>(this->m_args.activation.param1);
- // Fall through
- case arm_gemm::Activation::Type::ReLU:
- activation_min = static_cast<TAccum>(0);
- break;
- default:
- break;
- }
+template <typename OutputStage>
+inline void stash_bias(OutputStage &, const void *) {}
- // Determine what portion of the work to do.
- const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
- const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
- const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
+template <>
+inline void stash_bias(arm_gemm::Requantize32 &qp, const void *bias) __attribute__ ((unused));
- // Cast input and output pointers into the right types
- const TInput *const inptr = static_cast<const TInput *>(_input);
- TOutput *const outptr = static_cast<TOutput *>(_output);
+template <>
+inline void stash_bias(arm_gemm::Requantize32 &qp, const void *bias)
+{
+ qp.bias = reinterpret_cast<const int32_t *>(bias);
+}
- // Allocate portions of the working space
- uint8_t *working_space = static_cast<uint8_t *>(_working_space) + get_working_size(thread_id, input_channels);
+}
- const void **const inptr_array = reinterpret_cast<const void **>(working_space);
- working_space += sizeof_inptr_array();
+} // namespace depthwise_depthfirst
- void **const outptr_array = reinterpret_cast<void **>(working_space);
- working_space += sizeof_outptr_array();
+template <typename TInput,
+ typename TWeight=TInput,
+ typename TOutput=TInput,
+ typename TAccum=typename DefaultTAccum<TInput>::Type,
+ typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirst
+: public DepthwiseDepthfirstCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+ using StratType = DepthwiseDepthfirstStrategy<TInput, TWeight, TOutput, TAccum>;
+ using Parent = DepthwiseDepthfirstCommon<TInput, TWeight, TOutput, TAccum, OutputStage>;
+ using WorkspaceManager = Workspace<
+ OutputArrayElement<TOutput>,
+ depthwise_depthfirst::InputArrayElement<TInput>,
+ InputBufferElement<TInput>,
+ typename depthwise_depthfirst::WorkspaceFinalElement<TAccum, OutputStage>::Element
+ >;
+ using WorkingSpace = typename WorkspaceManager::WorkspaceType;
+
+ // We keep a copy of the bias and output stage
+ const TAccum *m_bias;
- TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space);
- working_space += sizeof_output_buffer(input_channels * this->m_args.channel_multiplier);
+ public:
+ DepthwiseDepthfirst(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os = {})
+ : Parent(strat, args, os), m_bias(nullptr)
+ {
+ }
- TInput *const input_buffer = reinterpret_cast<TInput *>(working_space);
+ DepthwiseDepthfirst(DepthwiseDepthfirst &) = delete;
+ DepthwiseDepthfirst &operator=(DepthwiseDepthfirst &) = delete;
- // Initialise the input buffer
- for (unsigned int c = 0; c < input_channels; c++)
- {
- input_buffer[c] = static_cast<TInput>(0);
- }
+ void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
+ {
+ reinterpret_cast<const StratType *>(this->m_strat.get())->pack_parameters(
+ this->m_args, buffer, biases, this->get_output_stage(),
+ weights, ld_weight_col, ld_weight_row
+ );
+ m_bias = reinterpret_cast<const TAccum *>(biases);
+ depthwise_depthfirst::stash_bias(this->get_output_stage(), biases);
+ }
- // For each output tile, construct the requisite set of pointers and call
- // into the kernel.
- for (unsigned int batch = 0; batch < batches; batch++)
- {
- // Get batch pointers
- const auto inptr_batch = inptr + batch * ld_input_batch;
- const auto outptr_batch = outptr + batch * ld_output_batch;
+ size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
+ {
+ DepthwiseArgs args(this->m_args);
+ args.input_channels = n_input_channels;
+ return WorkspaceManager::get_sizeof_workspace(
+ WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage())
+ );
+ }
- for (int start_out_i = start_out_height;
- start_out_i < end_out_height;
- start_out_i += static_cast<int>(m_strat->get_output_rows()))
- {
- const int end_out_i = start_out_i + m_strat->get_output_rows();
- const int start_in_i = start_out_i * m_strat->get_stride_rows() - padding.top;
- const int end_in_i = start_in_i + m_strat->get_input_rows();
-
- // Compute top/bottom padding
- const auto pad_top = static_cast<unsigned int>(-std::min(start_in_i, 0));
- const auto pad_bottom = static_cast<unsigned int>(-std::min(static_cast<int>(input_height) - end_in_i, 0));
- const unsigned int valid_output_rows = std::min(
- end_out_i - start_out_i,
- static_cast<int>(output_height) - start_out_i
- );
+ void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
+ {
+ DepthwiseArgs args(this->m_args);
+ args.input_channels = n_input_channels;
+ WorkspaceManager::initialise(
+ buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage())
+ );
+ }
- // Fill the input pointer array with padding values
- for (auto index = 0u; index < m_strat->get_input_rows() * m_strat->get_input_cols(); index++)
- {
- inptr_array[index] = input_buffer;
- }
+ protected:
+ void compute_tile_padded(
+ unsigned int output_i, unsigned int output_j,
+ unsigned int output_channel_start, unsigned int output_channel_end,
+ const TensorSpec<const TInput *> &input,
+ const TensorSpec<TOutput *> &output,
+ const void *parameters,
+ void *working_space_raw
+ ) const override
+ {
+ // Get the working space
+ auto ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
+
+ // Compute the input pointer array
+ const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier;
+
+ const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+ const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
+ const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
+
+ const int ij = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;
+ const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);
+ const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);
+
+ fill_pointer_array<const TInput>(
+ ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
+ input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
+ input.ld_row, input.ld_col,
+ ws->input_buffer,
+ input_pad_top, this->m_args.input_rows - input_i,
+ input_pad_left, this->m_args.input_cols - input_j
+ );
+
+ // Compute the output pointer array
+ fill_pointer_array(
+ ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+ output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+ output.ld_row, output.ld_col,
+ ws->output_buffer,
+ 0, this->m_args.output_rows - output_i, // Top padding, # valid rows
+ 0, this->m_args.output_cols - output_j // Left padding, # valid columns
+ );
+
+ // Execute the kernel
+ depthwise_depthfirst::Invoke<TInput, TWeight, TOutput, TAccum, OutputStage>::indirect(
+ reinterpret_cast<const StratType *>(this->m_strat.get()),
+ ws, this->get_output_stage(), parameters, m_bias, output_channel_end - output_channel_start
+ );
+ }
- for (int start_out_j = 0; start_out_j < static_cast<int>(output_width);)
- {
- const int start_in_j = start_out_j * m_strat->get_stride_cols() - this->m_args.padding.left;
- int pad_left = std::min(0, start_in_j);
+ void compute_row_padded_tile_row(
+ const unsigned int output_i, unsigned int output_j, unsigned int n_tile_cols,
+ const unsigned int output_channel_start, const unsigned int output_channel_end,
+ const TensorSpec<const TInput *> &input,
+ const TensorSpec<TOutput *> &output,
+ const void *parameters,
+ void *working_space
+ ) const override
+ {
+ using Invoker = depthwise_depthfirst::Invoke<TInput, TWeight, TOutput, TAccum, OutputStage>;
+ auto ws = reinterpret_cast<WorkingSpace *>(working_space);
+ const auto strat = reinterpret_cast<const StratType *>(this->m_strat.get());
+ const auto os = this->get_output_stage();
+
+ // Compute top and bottom padding; hence fill in the initial pointer arrays.
+ const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier;
+ const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+ const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
+
+ const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
+ const auto input_j = output_j * this->m_args.stride_cols - this->m_args.padding.left;
+
+ const auto valid_input_rows = std::min(strat->get_input_rows(), this->m_args.input_rows - input_i);
+ const auto valid_output_rows = std::min(strat->get_output_rows(), this->m_args.output_rows - output_i);
+
+ const auto input_point_stride = input.ld_col * this->m_strat->get_output_cols() * this->m_args.stride_cols;
+ const auto output_point_stride = output.ld_col * this->m_strat->get_output_cols();
+
+ fill_pointer_array<const TInput>(
+ ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
+ input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
+ input.ld_row, input.ld_col,
+ ws->input_buffer,
+ input_pad_top, this->m_args.input_rows - input_i,
+ 0, this->m_args.input_cols - input_j // No left padding
+ );
+
+ fill_pointer_array(
+ ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+ output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+ output.ld_row, output.ld_col,
+ ws->output_buffer,
+ 0, this->m_args.output_rows - output_i, // Top padding, # valid rows
+ 0, this->m_args.output_cols - output_j // Left padding, # valid columns
+ );
+
+ for (; n_tile_cols; n_tile_cols--)
+ {
+ // Execute the kernel
+ Invoker::indirect(
+ strat, ws, os, parameters, m_bias, output_channel_end - output_channel_start
+ );
- // Compute how many output tiles we can compute with the direct kernel.
- int n_direct_tiles = 0;
- if (!pad_top && !pad_bottom && !pad_left)
+ // Update all unpadded pointers
+ {
+ auto ptr = ws->inptr_array + strat->get_input_cols() * input_pad_top;
+ for (auto n = input_pad_top; n < valid_input_rows; n++)
+ {
+ for (auto m = 0u; m < strat->get_input_cols(); m++)
{
- // Determine the maximum number of tiles we could handle.
- n_direct_tiles = (output_width - start_out_j) / m_strat->get_output_cols();
-
- // Continue to reduce this number as required to avoid reading
- // padding on the right edge.
- int end_in_j = start_in_j + n_direct_tiles * m_strat->get_input_cols();
- int pad_right = std::max(0, end_in_j - static_cast<int>(input_width));
-
- while (pad_right && n_direct_tiles)
- {
- n_direct_tiles--;
- end_in_j -= m_strat->get_input_cols();
- pad_right = std::max(0, end_in_j - static_cast<int>(input_width));
- }
+ *(ptr++) += input_point_stride;
}
+ }
+ }
+ {
+ auto ptr = ws->outptr_array;
+ for (auto n = 0u; n < valid_output_rows * strat->get_output_cols(); n++)
+ {
+ *(ptr++) += output_point_stride;
+ }
+ }
+ }
+ }
- // Use the unpadded kernel if we can, otherwise use the padded one.
- if (n_direct_tiles)
- {
- auto inptr = inptr_batch + start_in_i*ld_input_row + start_in_j*ld_input_col;
- auto outptr = outptr_batch + start_out_i*ld_output_row + start_out_j*ld_output_col;
- start_out_j += n_direct_tiles*m_strat->get_output_cols();
+ void compute_tiles_unpadded(
+ unsigned int output_i, const unsigned int output_j,
+ unsigned int n_tile_rows, unsigned int n_tile_cols,
+ unsigned int output_channel_start, unsigned int output_channel_end,
+ const TensorSpec<const TInput *> &input,
+ const TensorSpec<TOutput *> &output,
+ const void *parameters,
+ void *working_space_raw
+ ) const override
+ {
+ using Invoker = depthwise_depthfirst::Invoke<TInput, TWeight, TOutput, TAccum, OutputStage>;
+ auto ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
+ const auto strat = reinterpret_cast<const StratType *>(this->m_strat.get());
+ const auto os = this->get_output_stage();
-#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, 0);
-#endif
- m_strat->direct_kernel(1, n_direct_tiles,
- inptr, ld_input_row, ld_input_col,
- outptr, ld_output_row, ld_output_col,
- parameters, this->m_args.input_channels,
- &activation_min, &activation_max);
- continue;
- }
+ if (Invoker::supports_direct_kernel)
+ {
+ // If the direct kernel is supported, then use it.
+ // Compute the base pointers we'll use in the tile.
+ auto outptr = output.base + output_channel_start + output_i * output.ld_row + output_j * output.ld_col;
+ const int start_input_i = output_i * this->m_args.stride_rows - this->m_args.padding.top;
+ const int start_input_j = output_j * this->m_args.stride_cols - this->m_args.padding.left;
+ auto inptr = input.base + output_channel_start + start_input_i * input.ld_row + start_input_j * input.ld_col;
+
+ // Execute the kernel
+ Invoker::direct(
+ strat, ws, os,
+ n_tile_rows, n_tile_cols,
+ inptr, input.ld_row, input.ld_col,
+ outptr, output.ld_row, output.ld_col,
+ parameters, output_channel_end - output_channel_start
+ );
+ }
+ else
+ {
+ // Otherwise, we repeatedly call the padded kernel but use our knowledge
+ // of the tensor structure to avoid recomputing the pointer array.
+ const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier;
+
+ const auto n_input_pointers = this->m_strat->get_input_rows() * this->m_strat->get_input_cols();
+ const auto input_point_stride = input.ld_col * this->m_strat->get_output_cols() * this->m_args.stride_cols;
+ const auto n_output_pointers = this->m_strat->get_output_rows() * this->m_strat->get_output_cols();
+ const auto output_point_stride = output.ld_col * this->m_strat->get_output_cols();
+
+ // For each tile row, initialise the input and output pointer arrays. For
+ // each subsequent tile we simply update the pointers.
+ for (unsigned int tile_i = 0; tile_i < n_tile_rows; tile_i++)
+ {
+ const int input_i = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+ const int input_j = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;
+
+ fill_pointer_array<const TInput>(
+ ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(),
+ input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start,
+ input.ld_row, input.ld_col,
+ ws->input_buffer,
+ 0, this->m_args.input_rows,
+ 0, this->m_args.input_cols
+ );
- const int end_out_j = start_out_j + m_strat->get_output_cols();
- const int end_in_j = start_in_j + m_strat->get_input_cols();
+ // Compute the output pointer array
+ fill_pointer_array(
+ ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+ output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start,
+ output.ld_row, output.ld_col,
+ ws->output_buffer,
+ 0, this->m_args.output_rows,
+ 0, this->m_args.output_cols
+ );
- const auto pad_right = static_cast<unsigned int>(-std::min(static_cast<int>(input_width) - end_in_j, 0));
- const unsigned int valid_output_cols = std::min(
- end_out_j - start_out_j,
- static_cast<int>(output_width) - start_out_j
+ for (unsigned int tile_j = 0; tile_j < n_tile_cols; tile_j++)
+ {
+ // Invoke the indirect kernel for this tile
+ depthwise_depthfirst::Invoke<TInput, TWeight, TOutput, TAccum, OutputStage>::indirect(
+ strat, ws, os, parameters, m_bias, output_channel_end - output_channel_start
);
- pad_left *= -1;
- // Construct the input pointer array - fill the array with pointers to
- // the input buffer and then fill in the required values.
- for (auto i = pad_top; i < m_strat->get_input_rows() - pad_bottom; i++)
- {
- // Can skip over the left padding because we will have either the
- // same or less than the previous tile.
- unsigned int j = pad_left;
- const TInput *colptr = inptr_batch + (start_in_i + i) * ld_input_row + (start_in_j + j) * ld_input_col;
- const void **ptrs = inptr_array + i * m_strat->get_input_cols() + j;
- for (; j < m_strat->get_input_cols() - pad_right; j++)
- {
- *(ptrs++) = colptr;
- colptr += ld_input_col;
- }
- for (; j < m_strat->get_input_cols(); j++)
- {
- *(ptrs++) = input_buffer;
- }
- }
- // Construct the output pointer array.
- void **outptr_pos = outptr_array;
- for (auto i = 0u; i < valid_output_rows; i++)
+ // Progress the pointers
+ for (auto i = 0u; i < n_input_pointers; i++)
{
- unsigned int j = 0u;
- TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col;
- for (; j < valid_output_cols; j++)
- {
- *(outptr_pos++) = colptr;
- colptr += ld_output_col;
- }
- for (; j < m_strat->get_output_cols(); j++)
- {
- *(outptr_pos++) = output_buffer;
- }
+ ws->inptr_array[i] += input_point_stride;
}
- for (auto i = valid_output_rows; i < m_strat->get_output_rows(); i++)
+ for (auto i = 0u; i < n_output_pointers; i++)
{
- for (auto j = 0u; j < m_strat->get_output_cols(); j++)
- {
- *(outptr_pos++) = output_buffer;
- }
+ ws->outptr_array[i] += output_point_stride;
}
-
- start_out_j += m_strat->get_output_cols();
-
-#ifdef CYCLE_PROFILING
- // TODO Work number
- auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(0));
-#endif
- m_strat->indirect_kernel(inptr_array, outptr_array, parameters,
- this->m_args.input_channels,
- &activation_min, &activation_max);
}
+
+ output_i += this->m_strat->get_output_rows();
}
}
}