about summary refs log tree commit diff
path: root/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h')
-rw-r--r-- src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h 37
1 files changed, 20 insertions, 17 deletions
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
index 26d9e9999d..6e30148b5d 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
@@ -76,32 +76,34 @@ public:
* @param[in] transformed_a Reshaped tensor A.
* @param[in] block_walker Window representing the layout of the matrix's blocks.
* @param[in] params M, N, K sizes.
+ * @param[in] gemm_info GEMM meta-data
*
* @return A wrapped specialized transformA kernel
*/
virtual std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a,
ITensor *transformed_a,
const Window &block_walker,
- const INEGEMMWrapperKernel::Params &params) = 0;
+ const INEGEMMWrapperKernel::Params &params,
+ const GEMMInfo &gemm_info) = 0;
/** Instantiate and configure a prepareB Kernel
*
- * @param transformed_a Already reshaped tensor A.
- * @param transformed_b Already reshaped tensor B.
- * @param tmp_c Temporary buffer to be used to store intermediate results.
- * @param c Result tensor C.
- * @param block_walker Window containing iteration information for the M and batch dimensions.
- * @param block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
- * @param params M, N, K sizes.
- * @param alpha Alpha value
- * @param beta Beta value
- * @param pretranspose_b Is B also pretransposed ?
- * @param num_threads Maximum number of threads that might be used for the calculations.
+ * @param[in] transformed_a Already reshaped tensor A.
+ * @param[in] transformed_b Already reshaped tensor B.
+ * @param[in] tmp_c Temporary buffer to be used to store intermediate results.
+ * @param[in] c Result tensor C.
+ * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
+ * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
+ * @param[in] params M, N, K sizes.
+ * @param[in] alpha Alpha value
+ * @param[in] beta Beta value
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] num_threads Maximum number of threads that might be used for the calculations.
*
* @return A wrapped specialized MatrixMultiply kernel
*/
virtual std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c,
const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, float alpha, float beta, bool pretranspose_b,
+ const INEGEMMWrapperKernel::Params &params, float alpha, float beta, const GEMMInfo &gemm_info,
unsigned int num_threads) = 0;
/** Calculates the block sizes of a given strategy
*
@@ -138,19 +140,20 @@ public:
std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a,
ITensor *transformed_a,
const Window &block_walker,
- const INEGEMMWrapperKernel::Params &params) override
+ const INEGEMMWrapperKernel::Params &params,
+ const GEMMInfo &gemm_info) override
{
auto transform_a = support::cpp14::make_unique<NEGEMMInterleavedTransformAWrapperTemplate<strategy>>();
- transform_a->configure(a, transformed_a, false, block_walker, params);
+ transform_a->configure(a, transformed_a, false, gemm_info.reinterpret_input_as_3d(), block_walker, params);
return std::move(transform_a);
}
std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c,
const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, float alpha, float beta, bool pretranspose_b,
+ const INEGEMMWrapperKernel::Params &params, float alpha, float beta, const GEMMInfo &gemm_info,
unsigned int num_threads) override
{
auto matrix_multiply = support::cpp14::make_unique<NEGEMMInterleavedMatrixMultiplyWrapperTemplate<strategy>>();
- matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, num_threads);
+ matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, gemm_info, alpha, beta, num_threads);
return std::move(matrix_multiply);
}