aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/assembly/winograd.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/assembly/winograd.hpp')
-rw-r--r--src/core/NEON/kernels/assembly/winograd.hpp181
1 files changed, 106 insertions, 75 deletions
diff --git a/src/core/NEON/kernels/assembly/winograd.hpp b/src/core/NEON/kernels/assembly/winograd.hpp
index 50290757ec..dbf95d23cd 100644
--- a/src/core/NEON/kernels/assembly/winograd.hpp
+++ b/src/core/NEON/kernels/assembly/winograd.hpp
@@ -45,17 +45,24 @@ struct ConvolutionArgs
Shape2D kernel_shape;
arm_gemm::Activation activation;
- ConvolutionArgs(
- unsigned int n_batches,
- const Shape2D &input_shape,
- unsigned int n_input_channels,
- unsigned int pad_top, unsigned int pad_left,
- const Shape2D &output_shape,
- unsigned int n_output_channels,
- const Shape2D kernel_shape,
- const arm_gemm::Activation &activation = {})
- : n_batches(n_batches), input_shape(input_shape), n_input_channels(n_input_channels), pad_top(pad_top), pad_left(pad_left), output_shape(output_shape), n_output_channels(n_output_channels),
- kernel_shape(kernel_shape), activation(activation)
+ ConvolutionArgs(unsigned int n_batches,
+ const Shape2D &input_shape,
+ unsigned int n_input_channels,
+ unsigned int pad_top,
+ unsigned int pad_left,
+ const Shape2D &output_shape,
+ unsigned int n_output_channels,
+ const Shape2D kernel_shape,
+ const arm_gemm::Activation &activation = {})
+ : n_batches(n_batches),
+ input_shape(input_shape),
+ n_input_channels(n_input_channels),
+ pad_top(pad_top),
+ pad_left(pad_left),
+ output_shape(output_shape),
+ n_output_channels(n_output_channels),
+ kernel_shape(kernel_shape),
+ activation(activation)
{
}
};
@@ -105,23 +112,30 @@ public:
virtual unsigned int get_transformed_tile_rows(void) const = 0;
virtual unsigned int get_transformed_tile_cols(void) const = 0;
- void execute(
- const ConvolutionArgs &args,
- const void *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_input_channel,
- void *outptr, const WinogradDomainSpec &wds,
- unsigned int thread_id, unsigned int n_threads) const
+ void execute(const ConvolutionArgs &args,
+ const void *inptr,
+ size_t ld_in_row,
+ size_t ld_in_col,
+ size_t ld_input_channel,
+ void *outptr,
+ const WinogradDomainSpec &wds,
+ unsigned int thread_id,
+ unsigned int n_threads) const
{
- this->execute(
- args, inptr, ld_in_row, ld_in_col, ld_input_channel,
- outptr, wds.weight_ld_matrix, wds.weight_ld_row,
- thread_id, n_threads);
+ this->execute(args, inptr, ld_in_row, ld_in_col, ld_input_channel, outptr, wds.weight_ld_matrix,
+ wds.weight_ld_row, thread_id, n_threads);
}
- virtual void execute(
- const ConvolutionArgs &args,
- const void *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_input_channel,
- void *outptr, size_t ld_out_matrix, size_t ld_out_row,
- unsigned int thread_id, unsigned int n_threads) const = 0;
+ virtual void execute(const ConvolutionArgs &args,
+ const void *inptr,
+ size_t ld_in_row,
+ size_t ld_in_col,
+ size_t ld_input_channel,
+ void *outptr,
+ size_t ld_out_matrix,
+ size_t ld_out_row,
+ unsigned int thread_id,
+ unsigned int n_threads) const = 0;
};
} // namespace weight_transform
@@ -136,27 +150,35 @@ public:
virtual unsigned int get_input_rows(void) const = 0;
virtual unsigned int get_input_cols(void) const = 0;
- virtual size_t get_working_space_size(
- const ConvolutionArgs &args,
- unsigned int n_threads) const = 0;
-
- void execute(
- const ConvolutionArgs &args,
- const void *inptr, size_t ld_in_batch, size_t ld_in_row, size_t ld_in_col,
- void *outptr, const WinogradDomainSpec &wds,
- void *working_space, unsigned int thread_id, unsigned int n_threads) const
+ virtual size_t get_working_space_size(const ConvolutionArgs &args, unsigned int n_threads) const = 0;
+
+ void execute(const ConvolutionArgs &args,
+ const void *inptr,
+ size_t ld_in_batch,
+ size_t ld_in_row,
+ size_t ld_in_col,
+ void *outptr,
+ const WinogradDomainSpec &wds,
+ void *working_space,
+ unsigned int thread_id,
+ unsigned int n_threads) const
{
- this->execute(
- args, inptr, ld_in_batch, ld_in_row, ld_in_col,
- outptr, wds.input_ld_batch, wds.input_ld_matrix, wds.input_ld_row,
- working_space, thread_id, n_threads);
+ this->execute(args, inptr, ld_in_batch, ld_in_row, ld_in_col, outptr, wds.input_ld_batch, wds.input_ld_matrix,
+ wds.input_ld_row, working_space, thread_id, n_threads);
}
- virtual void execute(
- const ConvolutionArgs &args,
- const void *inptr, size_t ld_in_batch, size_t ld_in_row, size_t ld_in_col,
- void *outptr, size_t ld_out_batch, size_t ld_out_matrix, size_t ld_out_row,
- void *working_space, unsigned int thread_id, unsigned int n_threads) const = 0;
+ virtual void execute(const ConvolutionArgs &args,
+ const void *inptr,
+ size_t ld_in_batch,
+ size_t ld_in_row,
+ size_t ld_in_col,
+ void *outptr,
+ size_t ld_out_batch,
+ size_t ld_out_matrix,
+ size_t ld_out_row,
+ void *working_space,
+ unsigned int thread_id,
+ unsigned int n_threads) const = 0;
};
} // namespace input_transform
@@ -177,31 +199,37 @@ public:
virtual unsigned int get_kernel_rows(void) const = 0;
virtual unsigned int get_kernel_cols(void) const = 0;
- virtual size_t get_working_space_size(
- const ConvolutionArgs &args,
- unsigned int n_threads) const = 0;
-
- void execute(
- const ConvolutionArgs &args,
- const void *inptr, const WinogradDomainSpec &wds,
- const void *bias,
- void *outptr, size_t ld_out_batch, size_t ld_out_row, size_t ld_out_col,
- void *working_space, unsigned int thread_id, unsigned int n_threads) const
+ virtual size_t get_working_space_size(const ConvolutionArgs &args, unsigned int n_threads) const = 0;
+
+ void execute(const ConvolutionArgs &args,
+ const void *inptr,
+ const WinogradDomainSpec &wds,
+ const void *bias,
+ void *outptr,
+ size_t ld_out_batch,
+ size_t ld_out_row,
+ size_t ld_out_col,
+ void *working_space,
+ unsigned int thread_id,
+ unsigned int n_threads) const
{
- this->execute(
- args,
- inptr, wds.output_ld_batch, wds.output_ld_matrix, wds.output_ld_row,
- bias,
- outptr, ld_out_batch, ld_out_row, ld_out_col,
- working_space, thread_id, n_threads);
+ this->execute(args, inptr, wds.output_ld_batch, wds.output_ld_matrix, wds.output_ld_row, bias, outptr,
+ ld_out_batch, ld_out_row, ld_out_col, working_space, thread_id, n_threads);
}
- virtual void execute(
- const ConvolutionArgs &args,
- const void *inptr, size_t ld_in_batch, size_t ld_in_matrix, size_t ld_in_row,
- const void *bias,
- void *outptr, size_t ld_out_batch, size_t ld_out_row, size_t ld_out_col,
- void *working_space, unsigned int thread_id, unsigned int n_threads) const = 0;
+ virtual void execute(const ConvolutionArgs &args,
+ const void *inptr,
+ size_t ld_in_batch,
+ size_t ld_in_matrix,
+ size_t ld_in_row,
+ const void *bias,
+ void *outptr,
+ size_t ld_out_batch,
+ size_t ld_out_row,
+ size_t ld_out_col,
+ void *working_space,
+ unsigned int thread_id,
+ unsigned int n_threads) const = 0;
};
} // namespace output_transform
@@ -210,7 +238,7 @@ struct WinogradImpl
{
const output_transform::ITransform *output_transform = nullptr;
const weight_transform::ITransform *weight_transform = nullptr;
- const input_transform::ITransform *input_transform = nullptr;
+ const input_transform::ITransform *input_transform = nullptr;
std::unique_ptr<arm_gemm::GemmArgs> gemm_args;
WinogradDomainSpec winograd_spec;
};
@@ -220,15 +248,18 @@ struct WinogradImpl
* Assigns to the pointers in the `dest` struct and returns true or false to
* indicate whether the given problem can be executed or not.
*/
-template <typename TIn, typename TWeight = TIn, typename TOut = TIn, typename TWinogradIn = TIn, typename TWinogradOut = TOut>
-bool get_implementation(
- WinogradImpl &dest, // Destination for the selected implementation
- const CPUInfo *,
- const ConvolutionArgs &,
- int max_threads,
- bool fast_mode,
- const WinogradConfig *,
- const arm_gemm::GemmConfig *);
+template <typename TIn,
+ typename TWeight = TIn,
+ typename TOut = TIn,
+ typename TWinogradIn = TIn,
+ typename TWinogradOut = TOut>
+bool get_implementation(WinogradImpl &dest, // Destination for the selected implementation
+ const CPUInfo *,
+ const ConvolutionArgs &,
+ int max_threads,
+ bool fast_mode,
+ const WinogradConfig *,
+ const arm_gemm::GemmConfig *);
} // namespace winograd
} // namespace arm_conv