aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h')
-rw-r--r--arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h55
1 files changed, 32 insertions, 23 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
index e2b849aa3d..40b6f5da39 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
@@ -95,31 +95,32 @@ class NEGEMMInterleavedMatrixMultiplyWrapperTemplate : public NEGEMMInterleavedM
public:
/** Configure the matrix multiplication: C = alpha * A * B + beta * C
*
- * @param[in] prepared_a Already reshaped matrix A.
- * @param[in] transformed_b Already reshaped matrix B.
- * @param[out] tmp_c Temporary buffer to be used to store intermediate results.
- * @param[in,out] c Result matrix C.
- * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
- * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
- * @param[in] params M, N, K sizes.
- * @param[in] is_pretransposed Is B also pretransposed ?
- * @param[in] alpha Alpha value
- * @param[in] beta Beta value
- * @param[in] max_num_threads Maximum number of threads that might be used for the calculations.
+ * @param[in] prepared_a Already reshaped matrix A.
+ * @param[in] transformed_b Already reshaped matrix B.
+ * @param[out] tmp_c Temporary buffer to be used to store intermediate results.
+ * @param[in,out] c Result matrix C.
+ * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
+ * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
+ * @param[in] params M, N, K sizes.
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] alpha Alpha value
+ * @param[in] beta Beta value
+ * @param[in] max_num_threads Maximum number of threads that might be used for the calculations.
*/
void configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads)
+ const INEGEMMWrapperKernel::Params &params, const GEMMInfo &gemm_info, float alpha, float beta, unsigned int max_num_threads)
{
- _prepared_a = prepared_a;
- _transformed_b = transformed_b;
- _tmp_c = tmp_c;
- _c = c;
- _block_walker = block_walker;
- _block_sizes = block_sizes;
- _params = params;
- _b_is_pretransposed = b_is_pretransposed;
- _alpha = alpha;
- _beta = beta;
+ _prepared_a = prepared_a;
+ _transformed_b = transformed_b;
+ _tmp_c = tmp_c;
+ _c = c;
+ _block_walker = block_walker;
+ _block_sizes = block_sizes;
+ _params = params;
+ _b_is_pretransposed = gemm_info.pretranpose_B();
+ _reinterpret_c_as_3d = gemm_info.depth_output_gemm3d() != 0;
+ _alpha = alpha;
+ _beta = beta;
auto_init_if_empty(*_tmp_c->info(), c->info()->clone()->set_tensor_shape(TensorShape{ _block_sizes.x_block * strategy::out_height(), max_num_threads }));
}
@@ -133,6 +134,14 @@ public:
TensorAccessor<typename strategy::result_type> c(*_c);
TensorAccessor<typename strategy::result_type> tmp_c(*_tmp_c);
+ // Handle 3d output re-interpretation
+ if(_reinterpret_c_as_3d)
+ {
+ Strides c_strides_as_3d = _c->info()->strides_in_bytes();
+ c_strides_as_3d.remove(Window::DimZ);
+ c.set_strides(c_strides_as_3d);
+ }
+
int prev_batch = -1;
typename strategy::operand_type *a_ptr = nullptr;
auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id)
@@ -216,9 +225,9 @@ private:
INEGEMMWrapperKernel::Params _params{};
Window _block_walker{};
bool _b_is_pretransposed{ false };
+ bool _reinterpret_c_as_3d{ false };
typename strategy::result_type _alpha{};
typename strategy::result_type _beta{};
};
-
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDMATRIXMULTIPLYWRAPPER_H__ */