aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
diff options
context:
space:
mode:
authorramelg01 <ramy.elgammal@arm.com>2022-04-07 02:42:52 +0100
committerRamy Elgammal <ramy.elgammal@arm.com>2022-04-26 15:51:22 +0000
commit8a164884dddf769643cf3b9f7f94e43cb4f3c20b (patch)
tree35958dd48b6df1a851c880dad2b2ce285671b611 /src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
parentc827e99fc46521f43719b0c2d1b6f05d66abf68c (diff)
downloadComputeLibrary-8a164884dddf769643cf3b9f7f94e43cb4f3c20b.tar.gz
Update Neon™ depthwise kernel
- Reduce duplication and simplify overall structure. - Improve multi-threaded performance by sharing more data in lower-level caches. Partially Resolves: COMPMID-5054 Signed-off-by: Ramy Elgammal <ramy.elgammal@arm.com> Change-Id: Iac747f39b21c540122fa75218762631c4d787911 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7449 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Andrew Mundy Reviewed-by: Sheri Zhang <sheri.zhang@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp')
-rw-r--r--src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp520
1 files changed, 221 insertions, 299 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
index f04f7751db..9f53f7cc6f 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,355 +24,277 @@
#pragma once
-#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
-
-#ifdef CYCLE_PROFILING
-#include "profiler.hpp"
-#endif
-
-#include <limits>
+#include "depthwise_depthfirst.hpp"
namespace arm_conv {
namespace depthwise {
-template <class Strategy, unsigned OutputRows, unsigned int OutputCols>
-class DepthwiseDepthfirstGenericBase :
- public DepthwiseCommon<typename Strategy::input_type,
- typename Strategy::weight_type,
- typename Strategy::return_type>
+template <typename TInput, typename TOutput, typename TAccum>
+struct GenericDepthfirstKernelStrategyFunctionType
{
- protected:
+ using KernelType = std::function<void(const TInput *const *const, TOutput *const *const, const void *, const void *, const unsigned int, const unsigned int, const TAccum, const TAccum)>;
+};
- using TInput = typename Strategy::input_type;
- using TWeight = typename Strategy::weight_type;
- using TOutput = typename Strategy::return_type;
- using TAccum = typename Strategy::bias_type;
+template <typename TInput, typename TOutput>
+struct GenericDepthfirstKernelStrategyFunctionType<TInput, TOutput, int32_t>
+{
+ using KernelType = std::function<void(const TInput *const *const, TOutput *const *const, const void *, const arm_gemm::Requantize32 &, unsigned int, unsigned int)>;
+};
- size_t sizeof_input_ptr_array(void) const
- {
- return sizeof(TInput *) * this->m_args.kernel_rows * this->m_args.kernel_cols * Strategy::n_output_points;
- }
+template <typename TInput, typename TWeight, typename TOutput, typename TAccum>
+class GenericDepthfirstKernelStrategy
+{
+ unsigned int m_n_output_points;
+ arm_gemm::VLType m_vl_type;
+ unsigned int m_accumulator_depth_vl;
- size_t sizeof_input_buffer(unsigned int n_channels) const
+ public:
+ GenericDepthfirstKernelStrategy(unsigned int n_output_points, arm_gemm::VLType vl_type, unsigned int accumulator_depth_vl=1)
+ : m_n_output_points(n_output_points), m_vl_type(vl_type), m_accumulator_depth_vl(accumulator_depth_vl)
{
- const unsigned int vl = arm_gemm::utils::get_vector_length<TInput>(Strategy::vl_type);
- const auto rounded_channels = arm_gemm::roundup(n_channels, vl);
- return sizeof(TInput) * rounded_channels;
}
- size_t sizeof_output_buffer(unsigned int n_channels) const
+ virtual ~GenericDepthfirstKernelStrategy() = default;
+
+ virtual arm_gemm::VLType get_vl_type() const { return m_vl_type; }
+ virtual unsigned int get_accumulator_depth_vl() const { return m_accumulator_depth_vl; }
+ virtual unsigned int get_n_output_points() const { return m_n_output_points; }
+
+ using KernelType = typename GenericDepthfirstKernelStrategyFunctionType<TInput, TOutput, TAccum>::KernelType;
+ virtual KernelType get_kernel(void) const = 0;
+};
+
+template <typename TInput,
+ typename TWeight=TInput,
+ typename TOutput=TInput,
+ typename TAccum=typename DefaultTAccum<TInput>::Type,
+ typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class GenericDepthfirstStrategy : public DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
+{
+ protected:
+ using KernelStrategyType = GenericDepthfirstKernelStrategy<TInput, TWeight, TOutput, TAccum>;
+ std::unique_ptr<KernelStrategyType> m_strategy;
+
+ public:
+ GenericDepthfirstStrategy(
+ KernelStrategyType *strat, unsigned int n_output_rows, unsigned int n_output_cols,
+ const DepthwiseArgs &args
+ )
+ : DepthwiseDepthfirstStrategyCommon<TInput, TWeight, TOutput, TAccum, OutputStage>(
+ n_output_rows, n_output_cols,
+ args.kernel_rows, args.kernel_cols,
+ args.stride_rows, args.stride_cols
+ ),
+ m_strategy(strat)
{
- const unsigned int vl = arm_gemm::utils::get_vector_length<TOutput>(Strategy::vl_type);
- const auto rounded_channels = arm_gemm::roundup(n_channels, vl);
- return sizeof(TOutput) * rounded_channels;
}
- unsigned int input_rows(void) const
+ GenericDepthfirstStrategy(GenericDepthfirstStrategy &) = delete;
+ GenericDepthfirstStrategy operator=(GenericDepthfirstStrategy &) = delete;
+
+ arm_gemm::VLType get_vl_type(void) const override { return m_strategy->get_vl_type(); }
+ unsigned int get_accumulator_depth_vl(void) const override { return m_strategy->get_accumulator_depth_vl(); }
+
+ size_t get_storage_size(const DepthwiseArgs &args) const override
{
- return this->m_args.kernel_rows + (OutputRows - 1)*this->m_args.stride_rows;
+ interleaves::PackingArguments packing_args(
+ this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
+ false, sizeof(TAccum), // Don't pack the bias
+ this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
+ [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
+ { return this->get_kernel_packing_point(idx, x, y); }
+ );
+ return interleaves::get_storage_size_generic(packing_args, args);
}
- unsigned int input_cols(void) const
+ void pack_parameters(
+ const DepthwiseArgs &args, void *buffer,
+ const void *biases, const OutputStage &,
+ const void *weights, size_t ld_weight_col, size_t ld_weight_row
+ ) const override
{
- return this->m_args.kernel_cols + (OutputCols - 1)*this->m_args.stride_cols;
+ interleaves::PackingArguments packing_args(
+ this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight),
+ false, sizeof(TAccum), // Don't pack the bias
+ this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(),
+ [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool
+ { return this->get_kernel_packing_point(idx, x, y); }
+ );
+ interleaves::pack_parameters_generic(
+ packing_args, args, buffer, biases, weights, ld_weight_col, ld_weight_row);
}
- void execute_tiles(
- std::function<void(const TInput *const *, TOutput *const *)> tile_fn,
- std::function<void(TInput *, unsigned int)> initialise_input_buffer,
- const unsigned int batches,
- const unsigned int input_height,
- const unsigned int input_width,
- const unsigned int input_channels,
- const PaddingValues &padding,
- const void *const _input,
- const size_t ld_input_col,
- const size_t ld_input_row,
- const size_t ld_input_batch,
- const unsigned int output_height,
- const unsigned int output_width,
- void *const _output,
- const size_t ld_output_col,
- const size_t ld_output_row,
- const size_t ld_output_batch,
- void *const _working_space,
- const unsigned int thread_id,
- const unsigned int n_threads
- ) const
+ const typename KernelStrategyType::KernelType get_kernel() const { return m_strategy->get_kernel(); }
+};
+
+// Use a templated function to marshal arguments when executing the kernel.
+template <typename OutputStage> struct DepthwiseDepthfirstGenericKernelCall;
+
+template <>
+struct DepthwiseDepthfirstGenericKernelCall<Nothing>
+{
+ template <typename StratType, typename WorkspaceType, typename TAccum>
+ static void execute(
+ const StratType *strat, const WorkspaceType *ws, const Nothing &,
+ const TAccum *bias, const void *params,
+ const unsigned int n_kernel_points, const unsigned int n_output_channels
+ )
{
- static_assert(OutputRows * OutputCols <= Strategy::n_output_points,
- "Too many output points for kernel.");
-
- // Determine what portion of the work to do.
- const unsigned int n_rows_per_thread = arm_gemm::iceildiv(output_height, n_threads);
- const int start_out_height = std::min(thread_id * n_rows_per_thread, output_height);
- const int end_out_height = std::min(start_out_height + n_rows_per_thread, output_height);
-
- // Cast input and output pointers into the right types
- const TInput *const inptr = static_cast<const TInput *>(_input);
- TOutput *const outptr = static_cast<TOutput *>(_output);
-
- // Allocate portions of the working space
- uint8_t *const working_space = static_cast<uint8_t *>(_working_space) + this->get_working_size(thread_id, input_channels);
- const TInput **const inptr_array = reinterpret_cast<const TInput **>(working_space);
- TOutput *const output_buffer = reinterpret_cast<TOutput *>(working_space + this->sizeof_input_ptr_array());
- TInput *const input_buffer = reinterpret_cast<TInput *>(working_space + this->sizeof_input_ptr_array() + this->sizeof_output_buffer(input_channels * this->m_args.channel_multiplier));
-
- // Create an array for the output pointers
- TOutput * _outptr_array[Strategy::n_output_points];
- TOutput **const outptr_array = _outptr_array;
-
- // Initialise the input buffer
- initialise_input_buffer(input_buffer, input_channels);
-
- // For each output tile, construct the requisite set of pointers and call
- // into the kernel.
- for (unsigned int batch = 0; batch < batches; batch++)
- {
- // Get batch pointers
- const auto inptr_batch = inptr + batch * ld_input_batch;
- const auto outptr_batch = outptr + batch * ld_output_batch;
-
- for (int start_out_i = start_out_height;
- start_out_i < end_out_height;
- start_out_i += static_cast<int>(OutputRows))
- {
- const int end_out_i = std::min(start_out_i + OutputRows,
- output_height);
-
- for (int start_out_j = 0;
- start_out_j < static_cast<int>(output_width);
- start_out_j += static_cast<int>(OutputCols))
- {
- const int end_out_j = std::min(start_out_j + OutputCols,
- output_width);
-
- // Fill the pointer arrays with pointers to the input/output buffers.
- for (auto index = 0u;
- index < (Strategy::n_output_points * this->m_args.kernel_rows * this->m_args.kernel_cols);
- index++)
- {
- inptr_array[index] = input_buffer;
- }
- for (auto index = 0u; index < Strategy::n_output_points; index++)
- {
- outptr_array[index] = output_buffer;
- }
-
- // Construct the pointer arrays together. Note that the input pointer
- // array is striped. Since the array has already been filled with
- // pointers to the padding array we merely fill in the valid points
- // as we get to them.
- unsigned int output_index = 0;
- auto outptr_row = outptr_batch + start_out_i * ld_output_row + start_out_j * ld_output_col;
- for (auto out_i = start_out_i; out_i < end_out_i; out_i++)
- {
- auto outptr_col = outptr_row;
-
- // Compute the padding for this row of tiles.
- const int start_in_i = out_i * this->m_args.stride_rows - padding.top;
- const int end_in_i = start_in_i + this->m_args.kernel_rows;
- const auto pad_top = static_cast<unsigned int>(std::max<int>(0, 0 - start_in_i));
- const auto pad_bottom = static_cast<unsigned int>(std::max<int>(0, end_in_i - input_height));
- const unsigned int valid_rows = this->m_args.kernel_rows - pad_top - pad_bottom;
-
- for (auto out_j = start_out_j; out_j < end_out_j; out_j++, output_index++)
- {
- // Compute the output pointer.
- outptr_array[output_index] = outptr_col;
- outptr_col += ld_output_col;
-
- // Compute the padding for this tile.
- const int start_in_j = out_j * this->m_args.stride_cols - padding.left;
- const int end_in_j = start_in_j + this->m_args.kernel_cols;
- const auto pad_left = static_cast<unsigned int>(std::max<int>(0, 0 - start_in_j));
- const auto pad_right = static_cast<unsigned int>(std::max<int>(0, end_in_j - input_width));
- const unsigned int valid_cols = this->m_args.kernel_cols - pad_left - pad_right;
-
- // Hence compute the input pointers.
- auto input_index = output_index + Strategy::n_output_points * (pad_top * this->m_args.kernel_cols + pad_left);
- auto inptr_row = inptr_batch + (start_in_i + pad_top) * ld_input_row + (start_in_j + pad_left) * ld_input_col;
- for (auto in_i = 0u; in_i < valid_rows; in_i++)
- {
- auto inptr_col = inptr_row;
- auto input_index_col = input_index;
-
- for (auto in_j = 0u; in_j < valid_cols; in_j++)
- {
- inptr_array[input_index_col] = inptr_col;
- inptr_col += ld_input_col;
- input_index_col += Strategy::n_output_points;
- }
-
- inptr_row += ld_input_row;
- input_index += Strategy::n_output_points * this->m_args.kernel_cols;
- }
- }
-
- outptr_row += ld_output_row;
- }
-
- tile_fn(inptr_array, outptr_array);
- }
- }
- }
+ strat->get_kernel()(
+ ws->inptr_array,
+ ws->outptr_array,
+ params, bias,
+ n_kernel_points, n_output_channels,
+ ws->activation_min, ws->activation_max
+ );
}
+};
- public:
- DepthwiseDepthfirstGenericBase(const DepthwiseArgs &args) : DepthwiseCommon<TInput, TWeight, TOutput>(args)
+template <>
+struct DepthwiseDepthfirstGenericKernelCall<arm_gemm::Requantize32>
+{
+ template <typename StratType, typename WorkspaceType>
+ static void execute(
+ const StratType *strat, const WorkspaceType *ws, const arm_gemm::Requantize32 &qp,
+ const int32_t *, const void *params,
+ const unsigned int n_kernel_points, const unsigned int n_output_channels
+ )
{
+ strat->get_kernel()(
+ ws->inptr_array,
+ ws->outptr_array,
+ params, qp,
+ n_kernel_points, n_output_channels
+ );
}
+};
- DepthwiseDepthfirstGenericBase(DepthwiseDepthfirstGenericBase &) = delete;
- DepthwiseDepthfirstGenericBase &operator=(DepthwiseDepthfirstGenericBase &) = delete;
- size_t get_storage_size(void) const override
+/* Workspace Element for an array of input pointers as consumed by the
+ * "Generic" depthwise kernels.
+ */
+template <typename T>
+class GenericInputArrayElement
+{
+ public:
+ struct Workspace
{
- const unsigned int vl = arm_gemm::utils::get_vector_length<TAccum>(Strategy::vl_type);
- const auto rounded_channels = arm_gemm::roundup(this->m_args.input_channels, vl);
- return (this->m_args.kernel_rows * this->m_args.kernel_cols) * rounded_channels * sizeof(TWeight);
- }
+ const T **inptr_array;
+ };
- void pack_parameters(void *_buffer, const void *, const void *_weights, size_t ld_weight_col, size_t ld_weight_row) override
+ template <class OutputStage>
+ static size_t get_element_size(const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
{
- // Cast the pointers
- TWeight *buffer = static_cast<TWeight *>(_buffer);
- const TWeight *const weights = static_cast<const TWeight *>(_weights);
-
- const unsigned int vl = arm_gemm::utils::get_vector_length<TAccum>(Strategy::vl_type);
- ld_weight_col = (ld_weight_col == 0) ? this->m_args.input_channels : ld_weight_col;
- ld_weight_row = (ld_weight_row == 0) ? this->m_args.kernel_cols * ld_weight_col : ld_weight_row;
-
- for (unsigned int n = 0; n < this->m_args.input_channels; n += vl)
- {
- const unsigned int todo = std::min(vl, this->m_args.input_channels - n);
-
- // Copy each of the weights in turn
- auto weights_row = weights + n;
- for (unsigned int i = 0; i < this->m_args.kernel_rows; i++)
- {
- auto weights_col = weights_row;
-
- for (unsigned int j = 0; j < this->m_args.kernel_cols; j++)
- {
- for (unsigned int m = 0; m < todo; m++)
- {
- buffer[m] = weights_col[m];
- }
- buffer += vl;
-
- weights_col += ld_weight_col;
- }
-
- weights_row += ld_weight_row;
- }
- }
+ const auto kernel_points = args.depthwise_args.kernel_rows * args.depthwise_args.kernel_cols;
+ return sizeof(T **) * args.strategy->get_input_rows() * args.strategy->get_input_cols() * kernel_points;
}
- size_t get_working_size(const unsigned int n_threads, const unsigned int n_channels) const override
+ template <class WorkspaceType, class OutputStage>
+ static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<IDepthfirstStrategy, OutputStage> &args)
{
- const unsigned int n_output_channels = n_channels * this->m_args.channel_multiplier;
- return n_threads * (sizeof_input_ptr_array() +
- sizeof_output_buffer(n_output_channels) +
- sizeof_input_buffer(n_channels));
+ ws->inptr_array = reinterpret_cast<const T**>(buffer);
+ return reinterpret_cast<char *>(buffer) + get_element_size(args);
}
};
-template <class Strategy, unsigned OutputRows, unsigned int OutputCols>
-class DepthwiseDepthfirstGeneric : public DepthwiseDepthfirstGenericBase<Strategy, OutputRows, OutputCols>
+template <typename TInput, typename TWeight=TInput, typename TOutput=TInput,
+ typename TAccum=typename DefaultTAccum<TInput>::Type,
+ typename OutputStage=typename DefaultOutputStage<TOutput>::Type>
+class DepthwiseDepthfirstGeneric : public DepthwiseDepthfirstCommon<TInput, TWeight, TOutput, TAccum, OutputStage>
{
- using Parent = DepthwiseDepthfirstGenericBase<Strategy, OutputRows, OutputCols>;
- using TInput = typename Parent::TInput;
- using TWeight = typename Parent::TWeight;
- using TAccum = typename Parent::TAccum;
- using TOutput = typename Parent::TOutput;
-
+ using StratType = GenericDepthfirstStrategy<TInput, TWeight, TOutput, TAccum, OutputStage>;
+ using Parent = DepthwiseDepthfirstCommon<TInput, TWeight, TOutput, TAccum, OutputStage>;
+ using WorkspaceManager = Workspace<
+ OutputArrayElement<TOutput>,
+ GenericInputArrayElement<TInput>,
+ InputBufferElement<TInput>,
+ ActivationsElement<TAccum, OutputStage>
+ >;
+ using WorkingSpace = typename WorkspaceManager::WorkspaceType;
const TAccum *m_bias = nullptr;
public:
- DepthwiseDepthfirstGeneric(const DepthwiseArgs &args) : Parent(args)
+ DepthwiseDepthfirstGeneric(StratType *const strat, const DepthwiseArgs &args, const OutputStage &os={})
+ : Parent(strat, args, os)
{
}
DepthwiseDepthfirstGeneric(DepthwiseDepthfirstGeneric &) = delete;
DepthwiseDepthfirstGeneric &operator=(DepthwiseDepthfirstGeneric &) = delete;
- void pack_parameters(void *buffer, const void *bias, const void *weights, size_t ld_weight_col, size_t ld_weight_row) override
+ void pack_parameters(
+ void *buffer, const void *biases,
+ const void *weights, size_t ld_weight_col, size_t ld_weight_row
+ ) override
+ {
+ Parent::pack_parameters(buffer, biases, weights, ld_weight_col, ld_weight_row);
+ m_bias = reinterpret_cast<const TAccum *>(biases); // Get a copy of the biases
+ depthwise_depthfirst::stash_bias(this->get_output_stage(), m_bias);
+ }
+
+ size_t get_working_size_per_thread(const unsigned int n_input_channels) const override
+ {
+ DepthwiseArgs args(this->m_args);
+ args.input_channels = n_input_channels;
+ return WorkspaceManager::get_sizeof_workspace(WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage()));
+ }
+
+ void initialise_working_space(void *buffer, unsigned int n_input_channels) const override
{
- m_bias = static_cast<const TAccum *>(bias);
- Parent::pack_parameters(buffer, bias, weights, ld_weight_col, ld_weight_row);
+ DepthwiseArgs args(this->m_args);
+ args.input_channels = n_input_channels;
+ return WorkspaceManager::initialise(buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage()));
}
- using DepthwiseDepthfirstGenericBase<Strategy, OutputRows, OutputCols>::execute;
- void execute(
- const unsigned int batches,
- const unsigned int input_height,
- const unsigned int input_width,
- const unsigned int input_channels,
- const PaddingValues &padding,
- const void *const _input,
- const size_t ld_input_col,
- const size_t ld_input_row,
- const size_t ld_input_batch,
- const void *const parameters,
- const unsigned int output_height,
- const unsigned int output_width,
- void *const _output,
- const size_t ld_output_col,
- const size_t ld_output_row,
- const size_t ld_output_batch,
- void *const _working_space,
- const unsigned int thread_id,
- const unsigned int n_threads
+ protected:
+ void compute_tile_padded(
+ unsigned int output_i, unsigned int output_j,
+ unsigned int channel_start, unsigned int channel_end,
+ const TensorSpec<const TInput *> &input,
+ const TensorSpec<TOutput *> &output,
+ const void *parameters,
+ void *working_space_raw
) const override
{
- Strategy strat(this->m_args.cpu_info);
-#ifdef CYCLE_PROFILING
- arm_gemm::profiler prof;
-#endif
-
- // Compute activation values
- TAccum activation_min, activation_max;
- std::tie(activation_min, activation_max) = get_default_activation_values<TAccum>();
-
- switch (this->m_args.activation.type)
- {
- case arm_gemm::Activation::Type::BoundedReLU:
- activation_max = static_cast<TAccum>(this->m_args.activation.param1);
- // Fall through
- case arm_gemm::Activation::Type::ReLU:
- activation_min = static_cast<TAccum>(0);
- break;
- default:
- break;
- }
-
- // Create a function to initialise the input buffer
- const auto initialise_input_buffer = [] (TInput *const buffer, const unsigned int n) {
- std::memset(buffer, 0, n * sizeof(TInput));
- };
-
- // Create a function to execute a tile of work
- const auto tile_fn = [&] (const TInput *const *const inptrs, TOutput *const * const outptrs) {
-#ifdef CYCLE_PROFILING
- auto p = prof.ScopedProfiler(
- PROFILE_KERNEL,
- (unsigned long) (OutputRows * OutputCols * this->m_args.kernel_rows* this->m_args.kernel_cols)
- );
-#endif
- strat.kernel(inptrs, outptrs, parameters, m_bias,
- this->m_args.kernel_rows * this->m_args.kernel_cols,
- this->m_args.input_channels, activation_min, activation_max);
- };
-
- // Call into a parent utility function to do the actual work.
- Parent::execute_tiles(
- tile_fn, initialise_input_buffer,
- batches, input_height, input_width, input_channels, padding,
- _input, ld_input_col, ld_input_row, ld_input_batch,
- output_height, output_width,
- _output, ld_output_col, ld_output_row, ld_output_batch,
- _working_space, thread_id, n_threads
+ // Get the working space
+ WorkingSpace *ws = reinterpret_cast<WorkingSpace *>(working_space_raw);
+
+ const int ii = static_cast<int>(output_i * this->m_args.stride_rows) - this->m_args.padding.top;
+ const auto input_pad_top = static_cast<unsigned int>(ii < 0 ? -ii : 0);
+ const auto input_i = static_cast<unsigned int>(ii < 0 ? 0 : ii);
+
+ const int ij = static_cast<int>(output_j * this->m_args.stride_cols) - this->m_args.padding.left;
+ const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0);
+ const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij);
+
+ fill_pointer_array_generic_kernel<const TInput>(
+ ws->inptr_array,
+ this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+ this->m_args.kernel_rows, this->m_args.kernel_cols,
+ this->m_args.stride_rows, this->m_args.stride_cols,
+ input.base + input_i*input.ld_row + input_j*input.ld_col + channel_start,
+ input.ld_row, input.ld_col,
+ ws->input_buffer,
+ input_pad_top, this->m_args.input_rows - input_i,
+ input_pad_left, this->m_args.input_cols - input_j
+ );
+
+ // Compute the output pointer array
+ fill_pointer_array<TOutput>(
+ ws->outptr_array, this->m_strat->get_output_rows(), this->m_strat->get_output_cols(),
+ output.base + output_i*output.ld_row + output_j*output.ld_col + channel_start,
+ output.ld_row, output.ld_col,
+ ws->output_buffer,
+ 0, this->m_args.output_rows - output_i, // Top padding, # valid rows
+ 0, this->m_args.output_cols - output_j // Left padding, # valid columns
+ );
+
+ // Execute the kernel
+ DepthwiseDepthfirstGenericKernelCall<OutputStage>::execute(
+ reinterpret_cast<const StratType *>(this->m_strat.get()), ws,
+ this->get_output_stage(), m_bias, parameters,
+ this->m_args.kernel_rows * this->m_args.kernel_cols,
+ channel_end - channel_start
);
}
};