From 8a164884dddf769643cf3b9f7f94e43cb4f3c20b Mon Sep 17 00:00:00 2001 From: ramelg01 Date: Thu, 7 Apr 2022 02:42:52 +0100 Subject: =?UTF-8?q?Update=20Neon=E2=84=A2=20depthwise=20kernel?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Reduce duplication and simplify overall structure. - Improve multi-threaded performance by sharing more data in lower-level caches. Partially Resolves: COMPMID-5054 Signed-off-by: Ramy Elgammal Change-Id: Iac747f39b21c540122fa75218762631c4d787911 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7449 Tested-by: Arm Jenkins Reviewed-by: Andrew Mundy Reviewed-by: Sheri Zhang Comments-Addressed: Arm Jenkins --- .../arm_conv/depthwise/depthwise_depthfirst.hpp | 736 +++++++++++++-------- 1 file changed, 468 insertions(+), 268 deletions(-) (limited to 'src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp') diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp index 57fa11151b..6905076357 100644 --- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp +++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,7 +24,9 @@ #pragma once -#include "src/core/NEON/kernels/arm_gemm/utils.hpp" +#include "src/core/NEON/kernels/arm_conv/addressing.hpp" +#include "depthwise_strategies_common.hpp" +#include "working_space.hpp" #ifdef CYCLE_PROFILING #include "profiler.hpp" @@ -35,349 +37,547 @@ namespace arm_conv { namespace depthwise { -struct IDepthwiseDepthfirstStrategy +template +class DepthwiseDepthfirstStrategyCommon + : public DepthfirstStrategy { - virtual arm_gemm::VLType get_vl_type() const = 0; + protected: + unsigned int m_output_rows, m_output_cols; + unsigned int m_kernel_rows, m_kernel_cols; + unsigned int m_stride_rows, m_stride_cols; - virtual unsigned int get_input_rows() const = 0; - virtual unsigned int get_input_cols() const = 0; + public: + DepthwiseDepthfirstStrategyCommon( + unsigned int output_rows, unsigned int output_cols, + unsigned int kernel_rows, unsigned int kernel_cols, + unsigned int stride_rows=1, unsigned int stride_cols=1 + ) : m_output_rows(output_rows), m_output_cols(output_cols), + m_kernel_rows(kernel_rows), m_kernel_cols(kernel_cols), + m_stride_rows(stride_rows), m_stride_cols(stride_cols) + { + } + + DepthwiseDepthfirstStrategyCommon(unsigned int output_size, unsigned int kernel_size, unsigned int stride=1) + : DepthwiseDepthfirstStrategyCommon(output_size, output_size, kernel_size, kernel_size, stride, stride) + { + } + + virtual ~DepthwiseDepthfirstStrategyCommon() {} + + unsigned int get_output_rows() const override { return m_output_rows; } + unsigned int get_output_cols() const override { return m_output_cols; } - virtual unsigned int get_output_rows() const = 0; - virtual unsigned int get_output_cols() const = 0; + unsigned int get_kernel_rows() const override { return m_kernel_rows; } + unsigned int get_kernel_cols() const override { return m_kernel_cols; } - virtual unsigned int get_kernel_rows() const = 0; - virtual unsigned int get_kernel_cols() const = 0; + unsigned int get_stride_rows() const override { return m_stride_rows; } + unsigned int get_stride_cols() const override { return m_stride_cols; } +}; + +template ::Type> +class DepthwiseDepthfirstStrategy : public DepthwiseDepthfirstStrategyCommon +{ + using Parent = DepthwiseDepthfirstStrategyCommon; - virtual unsigned int get_stride_rows() const = 0; - virtual unsigned int get_stride_cols() const = 0; + public: + using Parent::Parent; - virtual void indirect_kernel( - const void *const *const input_ptrs, - void *const *const output_ptrs, + typedef void (*IndirectKernelType)( + const TInput *const *input_ptrs, + TOutput *const *output_ptrs, const void *params, unsigned int n_channels, - const void *activation_min, - const void *activation_max - ) const = 0; + const TAccum activation_min, + const TAccum activation_max + ); + virtual IndirectKernelType get_indirect_kernel(void) const = 0; - virtual void direct_kernel( + typedef void (*DirectKernelType)( const unsigned int n_tile_rows, const unsigned int n_tile_cols, - const void *inptr, int64_t ld_input_row, int64_t ld_input_col, - void *outptr, int64_t ld_output_row, int64_t ld_output_col, + const TInput *inptr_base, int64_t ld_input_row, int64_t ld_input_col, + TOutput *outptr_base, int64_t ld_output_row, int64_t ld_output_col, const void *params, unsigned int n_channels, - const void *activation_min, - const void *activation_max - ) const = 0; - - virtual ~IDepthwiseDepthfirstStrategy() {} + const TAccum activation_min, + const TAccum activation_max + ); + virtual DirectKernelType get_direct_kernel(void) const = 0; }; -template -class DepthwiseDepthfirst : public DepthwiseCommon +template +class DepthwiseDepthfirstStrategy +: public DepthwiseDepthfirstStrategyCommon { - const std::unique_ptr m_strat; + using Parent = DepthwiseDepthfirstStrategyCommon; - size_t sizeof_inptr_array(void) const + protected: + interleaves::PackingArguments get_packing_args(void) const { - return sizeof(TInput *) * m_strat->get_input_rows() * m_strat->get_input_cols(); + return interleaves::PackingArguments( + this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight), + false, sizeof(int32_t), // Don't pack the bias + this->get_vl_type(), sizeof(int32_t), this->get_accumulator_depth_vl(), + [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool + { return this->get_kernel_packing_point(idx, x, y); } + ); } - size_t sizeof_input_buffer(unsigned int n_input_channels) const + public: + using Parent::Parent; + + typedef void (*KernelType)( + unsigned int, // n_channels, + const TInput *const *, // inptrs + const TWeight *, // weights + const int32_t *, // bias, + const arm_gemm::Requantize32 &, + const int32_t *, const int32_t *, // requant_muls and requant_shifts + TOutput *const * // outptrs + ); + virtual KernelType get_kernel() const = 0; + + size_t get_storage_size(const DepthwiseArgs &args) const override { - return sizeof(TInput) * n_input_channels; + return interleaves::get_storage_size_generic(get_packing_args(), args); } - size_t sizeof_outptr_array(void) const + void pack_parameters( + const DepthwiseArgs &args, void *buffer, + const void *biases, const arm_gemm::Requantize32 &, + const void *weights, size_t ld_weight_col, size_t ld_weight_row + ) const override { - return sizeof(TInput *) * m_strat->get_output_rows() * m_strat->get_output_cols(); + interleaves::pack_parameters_generic( + get_packing_args(), args, buffer, biases, weights, ld_weight_col, ld_weight_row); } +}; - size_t sizeof_output_buffer(unsigned int n_output_channels) const - { - return sizeof(TOutput) * n_output_channels; - } +template +class DepthwiseDepthfirstCommon : public DepthfirstDriver +{ + using StratType = DepthwiseDepthfirstStrategyCommon; + OutputStage m_os; + + protected: + inline OutputStage &get_output_stage(void) { return m_os; } + inline const OutputStage &get_output_stage(void) const { return m_os; } public: - DepthwiseDepthfirst( - IDepthwiseDepthfirstStrategy *const strat, - const DepthwiseArgs &args - ) : DepthwiseCommon(args), m_strat(strat) + DepthwiseDepthfirstCommon(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os) + : DepthfirstDriver(strat, args), m_os(os) { } - DepthwiseDepthfirst(DepthwiseDepthfirst &) = delete; - DepthwiseDepthfirst &operator=(DepthwiseDepthfirst &) = delete; + DepthwiseDepthfirstCommon(DepthwiseDepthfirstCommon &) = delete; + DepthwiseDepthfirstCommon &operator=(DepthwiseDepthfirstCommon &) = delete; size_t get_storage_size(void) const override { - // TODO What if we insert extra padding? Biases are a different size to the inputs, ... - const unsigned int vl = arm_gemm::utils::get_vector_length(m_strat->get_vl_type()); - const auto rounded_channels = arm_gemm::roundup(this->m_args.input_channels, vl); - return (1 + this->m_args.kernel_rows * this->m_args.kernel_cols) * rounded_channels * sizeof(TWeight); + return reinterpret_cast(this->m_strat.get())-> + get_storage_size(this->m_args); } - void pack_parameters(void *_buffer, const void *_biases, const void *_weights, size_t ld_weight_col, size_t ld_weight_row) override + void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override { - // TODO What if the kernel needs a different packing function? + reinterpret_cast(this->m_strat.get())-> + pack_parameters(this->m_args, buffer, biases, m_os, weights, ld_weight_col, ld_weight_row); + } +}; - // Cast the pointers - uint8_t *buffer = static_cast(_buffer); - const TAccum *biases = static_cast(_biases); - const TWeight *const weights = static_cast(_weights); +namespace depthwise_depthfirst { - const unsigned int vl = arm_gemm::utils::get_vector_length(m_strat->get_vl_type()); - ld_weight_col = (ld_weight_col == 0) ? this->m_args.input_channels : ld_weight_col; - ld_weight_row = (ld_weight_row == 0) ? this->m_args.kernel_cols * ld_weight_col : ld_weight_row; +/* Workspace Element for an array of input pointers as consumed by the + * specialised depthwise kernels. + */ +template +class InputArrayElement +{ + public: + struct Workspace + { + const T **inptr_array; + }; - for (unsigned int n = 0; n < this->m_args.input_channels; n += vl) - { - const unsigned int todo = std::min(vl, this->m_args.input_channels - n); + template + static size_t get_element_size(const WorkspaceArgs &args) + { + return sizeof(T **) * args.strategy->get_input_rows() * args.strategy->get_input_cols(); + } - // Copy across the correct amount of bias (or 0) - for (unsigned int i = 0; i < todo; i++) - { - reinterpret_cast(buffer)[i] = (biases == nullptr) ? 0 : biases[n + i]; - } - buffer += vl * sizeof(TAccum); + template + static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs &args) + { + ws->inptr_array = reinterpret_cast(buffer); + return reinterpret_cast(buffer) + get_element_size(args); + } +}; - // Copy each of the weights in turn - auto weights_row = weights + n; - for (unsigned int i = 0; i < this->m_args.kernel_rows; i++) - { - auto weights_col = weights_row; +template +struct WorkspaceFinalElement +{ + using Element = ActivationsElement; +}; - for (unsigned int j = 0; j < this->m_args.kernel_cols; j++) - { - for (unsigned int m = 0; m < todo; m++) - { - reinterpret_cast(buffer)[m] = weights_col[m]; - } - buffer += vl * sizeof(TWeight); +template <> +struct WorkspaceFinalElement +{ + using Element = RequantizationParametersElement; +}; - weights_col += ld_weight_col; - } +template +struct Invoke +{ + constexpr static bool supports_direct_kernel = true; - weights_row += ld_weight_row; - } - } + template + static inline void indirect(const Strat *strat, const Workspace *ws, const OutputStage &, const void *params, const TAccum *, unsigned int n_channels) + { + strat->get_indirect_kernel()( + ws->inptr_array, + ws->outptr_array, + params, n_channels, + ws->activation_min, ws->activation_max + ); } - size_t get_working_size(const unsigned int n_threads, const unsigned int n_channels) const override + template + static void direct( + const Strat *strat, const Workspace *ws, const OutputStage &, + unsigned int n_tile_rows, unsigned int n_tile_cols, + const TInput *inptr, size_t ld_in_row, size_t ld_in_col, + TOutput *outptr, size_t ld_out_row, size_t ld_out_col, + const void *params, unsigned int n_channels + ) { - const unsigned int n_output_channels = n_channels * this->m_args.channel_multiplier; - return n_threads * (sizeof_inptr_array() + sizeof_outptr_array() + - sizeof_output_buffer(n_output_channels) + - sizeof_input_buffer(n_channels)); + strat->get_direct_kernel()( + n_tile_rows, n_tile_cols, + inptr, ld_in_row, ld_in_col, + outptr, ld_out_row, ld_out_col, + params, n_channels, ws->activation_min, ws->activation_max + ); } +}; - using DepthwiseCommon::execute; - void execute( - const unsigned int batches, - const unsigned int input_height, - const unsigned int input_width, - const unsigned int input_channels, - const PaddingValues &padding, - const void *const _input, - const size_t ld_input_col, - const size_t ld_input_row, - const size_t ld_input_batch, - const void *const parameters, - const unsigned int output_height, - const unsigned int output_width, - void *const _output, - const size_t ld_output_col, - const size_t ld_output_row, - const size_t ld_output_batch, - void *const _working_space, - const unsigned int thread_id, - const unsigned int n_threads - ) const override +template +struct Invoke +{ + constexpr static bool supports_direct_kernel = false; + + template + static inline void indirect(const Strat *strat, const Workspace *ws, const arm_gemm::Requantize32 &qp, const void *params, const TAccum *, unsigned int n_channels) { -#ifdef CYCLE_PROFILING - arm_gemm::profiler prof; -#endif + strat->get_kernel()( + n_channels, ws->inptr_array, + reinterpret_cast(params), ws->bias, + qp, ws->requant_muls, ws->requant_shifts, + ws->outptr_array + ); + } + + template + static inline void direct( + const Strat *, const Workspace *, const arm_gemm::Requantize32 &, + unsigned int, unsigned int, // n_tile_rows, n_tile_cols + const TInput *, size_t, size_t, // Input pointer, row stride, column stride + TOutput *, size_t, size_t, // Output pointer, row stride, column stride + const void *, unsigned int // Parameters, number of channels + ) + { + // Do nothing - this should never be reached because entry to it is guarded + // by an `if` on a `constexpr static bool`. + } +}; - // Compute activation values - TAccum activation_min, activation_max; - std::tie(activation_min, activation_max) = get_default_activation_values(); +namespace +{ - switch (this->m_args.activation.type) - { - case arm_gemm::Activation::Type::BoundedReLU: - activation_max = static_cast(this->m_args.activation.param1); - // Fall through - case arm_gemm::Activation::Type::ReLU: - activation_min = static_cast(0); - break; - default: - break; - } +template +inline void stash_bias(OutputStage &, const void *) {} - // Determine what portion of the work to do. - const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads); - const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height); - const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height); +template <> +inline void stash_bias(arm_gemm::Requantize32 &qp, const void *bias) __attribute__ ((unused)); - // Cast input and output pointers into the right types - const TInput *const inptr = static_cast(_input); - TOutput *const outptr = static_cast(_output); +template <> +inline void stash_bias(arm_gemm::Requantize32 &qp, const void *bias) +{ + qp.bias = reinterpret_cast(bias); +} - // Allocate portions of the working space - uint8_t *working_space = static_cast(_working_space) + get_working_size(thread_id, input_channels); +} - const void **const inptr_array = reinterpret_cast(working_space); - working_space += sizeof_inptr_array(); +} // namespace depthwise_depthfirst - void **const outptr_array = reinterpret_cast(working_space); - working_space += sizeof_outptr_array(); +template ::Type, + typename OutputStage=typename DefaultOutputStage::Type> +class DepthwiseDepthfirst +: public DepthwiseDepthfirstCommon +{ + using StratType = DepthwiseDepthfirstStrategy; + using Parent = DepthwiseDepthfirstCommon; + using WorkspaceManager = Workspace< + OutputArrayElement, + depthwise_depthfirst::InputArrayElement, + InputBufferElement, + typename depthwise_depthfirst::WorkspaceFinalElement::Element + >; + using WorkingSpace = typename WorkspaceManager::WorkspaceType; + + // We keep a copy of the bias and output stage + const TAccum *m_bias; - TOutput *const output_buffer = reinterpret_cast(working_space); - working_space += sizeof_output_buffer(input_channels * this->m_args.channel_multiplier); + public: + DepthwiseDepthfirst(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os = {}) + : Parent(strat, args, os), m_bias(nullptr) + { + } - TInput *const input_buffer = reinterpret_cast(working_space); + DepthwiseDepthfirst(DepthwiseDepthfirst &) = delete; + DepthwiseDepthfirst &operator=(DepthwiseDepthfirst &) = delete; - // Initialise the input buffer - for (unsigned int c = 0; c < input_channels; c++) - { - input_buffer[c] = static_cast(0); - } + void pack_parameters(void *buffer, const void *biases, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override + { + reinterpret_cast(this->m_strat.get())->pack_parameters( + this->m_args, buffer, biases, this->get_output_stage(), + weights, ld_weight_col, ld_weight_row + ); + m_bias = reinterpret_cast(biases); + depthwise_depthfirst::stash_bias(this->get_output_stage(), biases); + } - // For each output tile, construct the requisite set of pointers and call - // into the kernel. - for (unsigned int batch = 0; batch < batches; batch++) - { - // Get batch pointers - const auto inptr_batch = inptr + batch * ld_input_batch; - const auto outptr_batch = outptr + batch * ld_output_batch; + size_t get_working_size_per_thread(const unsigned int n_input_channels) const override + { + DepthwiseArgs args(this->m_args); + args.input_channels = n_input_channels; + return WorkspaceManager::get_sizeof_workspace( + WorkspaceArgs(this->m_strat.get(), args, this->get_output_stage()) + ); + } - for (int start_out_i = start_out_height; - start_out_i < end_out_height; - start_out_i += static_cast(m_strat->get_output_rows())) - { - const int end_out_i = start_out_i + m_strat->get_output_rows(); - const int start_in_i = start_out_i * m_strat->get_stride_rows() - padding.top; - const int end_in_i = start_in_i + m_strat->get_input_rows(); - - // Compute top/bottom padding - const auto pad_top = static_cast(-std::min(start_in_i, 0)); - const auto pad_bottom = static_cast(-std::min(static_cast(input_height) - end_in_i, 0)); - const unsigned int valid_output_rows = std::min( - end_out_i - start_out_i, - static_cast(output_height) - start_out_i - ); + void initialise_working_space(void *buffer, unsigned int n_input_channels) const override + { + DepthwiseArgs args(this->m_args); + args.input_channels = n_input_channels; + WorkspaceManager::initialise( + buffer, WorkspaceArgs(this->m_strat.get(), args, this->get_output_stage()) + ); + } - // Fill the input pointer array with padding values - for (auto index = 0u; index < m_strat->get_input_rows() * m_strat->get_input_cols(); index++) - { - inptr_array[index] = input_buffer; - } + protected: + void compute_tile_padded( + unsigned int output_i, unsigned int output_j, + unsigned int output_channel_start, unsigned int output_channel_end, + const TensorSpec &input, + const TensorSpec &output, + const void *parameters, + void *working_space_raw + ) const override + { + // Get the working space + auto ws = reinterpret_cast(working_space_raw); + + // Compute the input pointer array + const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier; + + const int ii = static_cast(output_i * this->m_args.stride_rows) - this->m_args.padding.top; + const auto input_pad_top = static_cast(ii < 0 ? -ii : 0); + const auto input_i = static_cast(ii < 0 ? 0 : ii); + + const int ij = static_cast(output_j * this->m_args.stride_cols) - this->m_args.padding.left; + const auto input_pad_left = static_cast(ij < 0 ? -ij : 0); + const auto input_j = static_cast(ij < 0 ? 0 : ij); + + fill_pointer_array( + ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(), + input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start, + input.ld_row, input.ld_col, + ws->input_buffer, + input_pad_top, this->m_args.input_rows - input_i, + input_pad_left, this->m_args.input_cols - input_j + ); + + // Compute the output pointer array + fill_pointer_array( + ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(), + output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start, + output.ld_row, output.ld_col, + ws->output_buffer, + 0, this->m_args.output_rows - output_i, // Top padding, # valid rows + 0, this->m_args.output_cols - output_j // Left padding, # valid columns + ); + + // Execute the kernel + depthwise_depthfirst::Invoke::indirect( + reinterpret_cast(this->m_strat.get()), + ws, this->get_output_stage(), parameters, m_bias, output_channel_end - output_channel_start + ); + } - for (int start_out_j = 0; start_out_j < static_cast(output_width);) - { - const int start_in_j = start_out_j * m_strat->get_stride_cols() - this->m_args.padding.left; - int pad_left = std::min(0, start_in_j); + void compute_row_padded_tile_row( + const unsigned int output_i, unsigned int output_j, unsigned int n_tile_cols, + const unsigned int output_channel_start, const unsigned int output_channel_end, + const TensorSpec &input, + const TensorSpec &output, + const void *parameters, + void *working_space + ) const override + { + using Invoker = depthwise_depthfirst::Invoke; + auto ws = reinterpret_cast(working_space); + const auto strat = reinterpret_cast(this->m_strat.get()); + const auto os = this->get_output_stage(); + + // Compute top and bottom padding; hence fill in the initial pointer arrays. + const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier; + const int ii = static_cast(output_i * this->m_args.stride_rows) - this->m_args.padding.top; + const auto input_pad_top = static_cast(ii < 0 ? -ii : 0); + + const auto input_i = static_cast(ii < 0 ? 0 : ii); + const auto input_j = output_j * this->m_args.stride_cols - this->m_args.padding.left; + + const auto valid_input_rows = std::min(strat->get_input_rows(), this->m_args.input_rows - input_i); + const auto valid_output_rows = std::min(strat->get_output_rows(), this->m_args.output_rows - output_i); + + const auto input_point_stride = input.ld_col * this->m_strat->get_output_cols() * this->m_args.stride_cols; + const auto output_point_stride = output.ld_col * this->m_strat->get_output_cols(); + + fill_pointer_array( + ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(), + input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start, + input.ld_row, input.ld_col, + ws->input_buffer, + input_pad_top, this->m_args.input_rows - input_i, + 0, this->m_args.input_cols - input_j // No left padding + ); + + fill_pointer_array( + ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(), + output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start, + output.ld_row, output.ld_col, + ws->output_buffer, + 0, this->m_args.output_rows - output_i, // Top padding, # valid rows + 0, this->m_args.output_cols - output_j // Left padding, # valid columns + ); + + for (; n_tile_cols; n_tile_cols--) + { + // Execute the kernel + Invoker::indirect( + strat, ws, os, parameters, m_bias, output_channel_end - output_channel_start + ); - // Compute how many output tiles we can compute with the direct kernel. - int n_direct_tiles = 0; - if (!pad_top && !pad_bottom && !pad_left) + // Update all unpadded pointers + { + auto ptr = ws->inptr_array + strat->get_input_cols() * input_pad_top; + for (auto n = input_pad_top; n < valid_input_rows; n++) + { + for (auto m = 0u; m < strat->get_input_cols(); m++) { - // Determine the maximum number of tiles we could handle. - n_direct_tiles = (output_width - start_out_j) / m_strat->get_output_cols(); - - // Continue to reduce this number as required to avoid reading - // padding on the right edge. - int end_in_j = start_in_j + n_direct_tiles * m_strat->get_input_cols(); - int pad_right = std::max(0, end_in_j - static_cast(input_width)); - - while (pad_right && n_direct_tiles) - { - n_direct_tiles--; - end_in_j -= m_strat->get_input_cols(); - pad_right = std::max(0, end_in_j - static_cast(input_width)); - } + *(ptr++) += input_point_stride; } + } + } + { + auto ptr = ws->outptr_array; + for (auto n = 0u; n < valid_output_rows * strat->get_output_cols(); n++) + { + *(ptr++) += output_point_stride; + } + } + } + } - // Use the unpadded kernel if we can, otherwise use the padded one. - if (n_direct_tiles) - { - auto inptr = inptr_batch + start_in_i*ld_input_row + start_in_j*ld_input_col; - auto outptr = outptr_batch + start_out_i*ld_output_row + start_out_j*ld_output_col; - start_out_j += n_direct_tiles*m_strat->get_output_cols(); + void compute_tiles_unpadded( + unsigned int output_i, const unsigned int output_j, + unsigned int n_tile_rows, unsigned int n_tile_cols, + unsigned int output_channel_start, unsigned int output_channel_end, + const TensorSpec &input, + const TensorSpec &output, + const void *parameters, + void *working_space_raw + ) const override + { + using Invoker = depthwise_depthfirst::Invoke; + auto ws = reinterpret_cast(working_space_raw); + const auto strat = reinterpret_cast(this->m_strat.get()); + const auto os = this->get_output_stage(); -#ifdef CYCLE_PROFILING - auto p = prof.ScopedProfiler(PROFILE_KERNEL, 0); -#endif - m_strat->direct_kernel(1, n_direct_tiles, - inptr, ld_input_row, ld_input_col, - outptr, ld_output_row, ld_output_col, - parameters, this->m_args.input_channels, - &activation_min, &activation_max); - continue; - } + if (Invoker::supports_direct_kernel) + { + // If the direct kernel is supported, then use it. + // Compute the base pointers we'll use in the tile. + auto outptr = output.base + output_channel_start + output_i * output.ld_row + output_j * output.ld_col; + const int start_input_i = output_i * this->m_args.stride_rows - this->m_args.padding.top; + const int start_input_j = output_j * this->m_args.stride_cols - this->m_args.padding.left; + auto inptr = input.base + output_channel_start + start_input_i * input.ld_row + start_input_j * input.ld_col; + + // Execute the kernel + Invoker::direct( + strat, ws, os, + n_tile_rows, n_tile_cols, + inptr, input.ld_row, input.ld_col, + outptr, output.ld_row, output.ld_col, + parameters, output_channel_end - output_channel_start + ); + } + else + { + // Otherwise, we repeatedly call the padded kernel but use our knowledge + // of the tensor structure to avoid recomputing the pointer array. + const auto input_channel_start = output_channel_start / this->m_args.channel_multiplier; + + const auto n_input_pointers = this->m_strat->get_input_rows() * this->m_strat->get_input_cols(); + const auto input_point_stride = input.ld_col * this->m_strat->get_output_cols() * this->m_args.stride_cols; + const auto n_output_pointers = this->m_strat->get_output_rows() * this->m_strat->get_output_cols(); + const auto output_point_stride = output.ld_col * this->m_strat->get_output_cols(); + + // For each tile row, initialise the input and output pointer arrays. For + // each subsequent tile we simply update the pointers. + for (unsigned int tile_i = 0; tile_i < n_tile_rows; tile_i++) + { + const int input_i = static_cast(output_i * this->m_args.stride_rows) - this->m_args.padding.top; + const int input_j = static_cast(output_j * this->m_args.stride_cols) - this->m_args.padding.left; + + fill_pointer_array( + ws->inptr_array, this->m_strat->get_input_rows(), this->m_strat->get_input_cols(), + input.base + input_i*input.ld_row + input_j*input.ld_col + input_channel_start, + input.ld_row, input.ld_col, + ws->input_buffer, + 0, this->m_args.input_rows, + 0, this->m_args.input_cols + ); - const int end_out_j = start_out_j + m_strat->get_output_cols(); - const int end_in_j = start_in_j + m_strat->get_input_cols(); + // Compute the output pointer array + fill_pointer_array( + ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(), + output.base + output_i*output.ld_row + output_j*output.ld_col + output_channel_start, + output.ld_row, output.ld_col, + ws->output_buffer, + 0, this->m_args.output_rows, + 0, this->m_args.output_cols + ); - const auto pad_right = static_cast(-std::min(static_cast(input_width) - end_in_j, 0)); - const unsigned int valid_output_cols = std::min( - end_out_j - start_out_j, - static_cast(output_width) - start_out_j + for (unsigned int tile_j = 0; tile_j < n_tile_cols; tile_j++) + { + // Invoke the indirect kernel for this tile + depthwise_depthfirst::Invoke::indirect( + strat, ws, os, parameters, m_bias, output_channel_end - output_channel_start ); - pad_left *= -1; - // Construct the input pointer array - fill the array with pointers to - // the input buffer and then fill in the required values. - for (auto i = pad_top; i < m_strat->get_input_rows() - pad_bottom; i++) - { - // Can skip over the left padding because we will have either the - // same or less than the previous tile. - unsigned int j = pad_left; - const TInput *colptr = inptr_batch + (start_in_i + i) * ld_input_row + (start_in_j + j) * ld_input_col; - const void **ptrs = inptr_array + i * m_strat->get_input_cols() + j; - for (; j < m_strat->get_input_cols() - pad_right; j++) - { - *(ptrs++) = colptr; - colptr += ld_input_col; - } - for (; j < m_strat->get_input_cols(); j++) - { - *(ptrs++) = input_buffer; - } - } - // Construct the output pointer array. - void **outptr_pos = outptr_array; - for (auto i = 0u; i < valid_output_rows; i++) + // Progress the pointers + for (auto i = 0u; i < n_input_pointers; i++) { - unsigned int j = 0u; - TOutput *colptr = outptr_batch + (start_out_i + i) * ld_output_row + start_out_j * ld_output_col; - for (; j < valid_output_cols; j++) - { - *(outptr_pos++) = colptr; - colptr += ld_output_col; - } - for (; j < m_strat->get_output_cols(); j++) - { - *(outptr_pos++) = output_buffer; - } + ws->inptr_array[i] += input_point_stride; } - for (auto i = valid_output_rows; i < m_strat->get_output_rows(); i++) + for (auto i = 0u; i < n_output_pointers; i++) { - for (auto j = 0u; j < m_strat->get_output_cols(); j++) - { - *(outptr_pos++) = output_buffer; - } + ws->outptr_array[i] += output_point_stride; } - - start_out_j += m_strat->get_output_cols(); - -#ifdef CYCLE_PROFILING - // TODO Work number - auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(0)); -#endif - m_strat->indirect_kernel(inptr_array, outptr_array, parameters, - this->m_args.input_channels, - &activation_min, &activation_max); } + + output_i += this->m_strat->get_output_rows(); } } } -- cgit v1.2.1