diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp')
-rw-r--r-- | src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp | 45 |
1 files changed, 28 insertions, 17 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp index b058ce26f2..ca5026b6e0 100644 --- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp +++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_depthfirst_generic.hpp @@ -99,7 +99,7 @@ class GenericDepthfirstStrategy : public DepthwiseDepthfirstStrategyCommon<TInpu { interleaves::PackingArguments packing_args( this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight), - false, sizeof(TAccum), // Don't pack the bias + false, sizeof(TAccum), this->uses_premultiply(), // Don't pack the bias this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(), [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool { return this->get_kernel_packing_point(idx, x, y); } @@ -115,7 +115,7 @@ class GenericDepthfirstStrategy : public DepthwiseDepthfirstStrategyCommon<TInpu { interleaves::PackingArguments packing_args( this->get_kernel_rows(), this->get_kernel_cols(), sizeof(TWeight), - false, sizeof(TAccum), // Don't pack the bias + false, sizeof(TAccum), this->uses_premultiply(), // Don't pack the bias this->get_vl_type(), sizeof(TAccum), this->get_accumulator_depth_vl(), [this] (unsigned int idx, unsigned int &x, unsigned int &y) -> bool { return this->get_kernel_packing_point(idx, x, y); } @@ -208,6 +208,7 @@ class DepthwiseDepthfirstGeneric : public DepthwiseDepthfirstCommon<TInput, TWei OutputArrayElement<TOutput>, GenericInputArrayElement<TInput>, InputBufferElement<TInput>, + IntermediateBufferElement<TInput>, ActivationsElement<TAccum, OutputStage> >; using WorkingSpace = typename WorkspaceManager::WorkspaceType; @@ -232,21 +233,38 @@ class DepthwiseDepthfirstGeneric : public DepthwiseDepthfirstCommon<TInput, TWei depthwise_depthfirst::stash_bias(this->get_output_stage(), m_bias); } - size_t get_working_size_per_thread(const unsigned int n_input_channels) const override + size_t get_working_size_per_thread() const override { DepthwiseArgs args(this->m_args); - args.input_channels = n_input_channels; return WorkspaceManager::get_sizeof_workspace(WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage())); } - void initialise_working_space(void *buffer, unsigned int n_input_channels) const override + void initialise_working_space(void *buffer) const override { DepthwiseArgs args(this->m_args); - args.input_channels = n_input_channels; return WorkspaceManager::initialise(buffer, WorkspaceArgs<IDepthfirstStrategy, OutputStage>(this->m_strat.get(), args, this->get_output_stage())); } protected: + void fill_inptr_array(const DepthwiseArgs &args, + const TensorSpec<const TInput *> &input, + const TInput **inptr_array, TInput *input_buffer, + const unsigned int input_i, const unsigned int input_j, + const unsigned int input_pad_top, const unsigned int input_pad_left) const override + { + fill_pointer_array_generic_kernel<const TInput>( + inptr_array, + this->m_strat->get_output_rows(), this->m_strat->get_output_cols(), + args.kernel_rows, args.kernel_cols, + args.stride_rows, args.stride_cols, + input.base, + input.ld_row, input.ld_col, + input_buffer, + input_pad_top, args.input_rows - input_i, + input_pad_left, args.input_cols - input_j + ); + } + void compute_tile_padded( const DepthwiseArgs &args, unsigned int output_i, unsigned int output_j, @@ -268,17 +286,10 @@ class DepthwiseDepthfirstGeneric : public DepthwiseDepthfirstCommon<TInput, TWei const auto input_pad_left = static_cast<unsigned int>(ij < 0 ? -ij : 0); const auto input_j = static_cast<unsigned int>(ij < 0 ? 0 : ij); - fill_pointer_array_generic_kernel<const TInput>( - ws->inptr_array, - this->m_strat->get_output_rows(), this->m_strat->get_output_cols(), - args.kernel_rows, args.kernel_cols, - args.stride_rows, args.stride_cols, - input.base + input_i*input.ld_row + input_j*input.ld_col + channel_start, - input.ld_row, input.ld_col, - ws->input_buffer, - input_pad_top, args.input_rows - input_i, - input_pad_left, args.input_cols - input_j - ); + Tile<TInput> multiplied_input; + this->initialise_inptr_array(args, channel_start, channel_end, input, + ws->inptr_array, ws->input_buffer, ws->intermediate_buffer, + input_i, input_j, input_pad_top, input_pad_left, multiplied_input); // Compute the output pointer array fill_pointer_array<TOutput>( |