diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h | 55 |
1 files changed, 32 insertions, 23 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h index e2b849aa3d..40b6f5da39 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h @@ -95,31 +95,32 @@ class NEGEMMInterleavedMatrixMultiplyWrapperTemplate : public NEGEMMInterleavedM public: /** Configure the matrix multiplication: C = alpha * A * B + beta * C * - * @param[in] prepared_a Already reshaped matrix A. - * @param[in] transformed_b Already reshaped matrix B. - * @param[out] tmp_c Temporary buffer to be used to store intermediate results. - * @param[in,out] c Result matrix C. - * @param[in] block_walker Window containing iteration information for the M and batch dimensions. - * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). - * @param[in] params M, N, K sizes. - * @param[in] is_pretransposed Is B also pretransposed ? - * @param[in] alpha Alpha value - * @param[in] beta Beta value - * @param[in] max_num_threads Maximum number of threads that might be used for the calculations. + * @param[in] prepared_a Already reshaped matrix A. + * @param[in] transformed_b Already reshaped matrix B. + * @param[out] tmp_c Temporary buffer to be used to store intermediate results. + * @param[in,out] c Result matrix C. + * @param[in] block_walker Window containing iteration information for the M and batch dimensions. + * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). + * @param[in] params M, N, K sizes. + * @param[in] gemm_info GEMM meta-data + * @param[in] alpha Alpha value + * @param[in] beta Beta value + * @param[in] max_num_threads Maximum number of threads that might be used for the calculations. */ void configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes, - const INEGEMMWrapperKernel::Params ¶ms, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads) + const INEGEMMWrapperKernel::Params ¶ms, const GEMMInfo &gemm_info, float alpha, float beta, unsigned int max_num_threads) { - _prepared_a = prepared_a; - _transformed_b = transformed_b; - _tmp_c = tmp_c; - _c = c; - _block_walker = block_walker; - _block_sizes = block_sizes; - _params = params; - _b_is_pretransposed = b_is_pretransposed; - _alpha = alpha; - _beta = beta; + _prepared_a = prepared_a; + _transformed_b = transformed_b; + _tmp_c = tmp_c; + _c = c; + _block_walker = block_walker; + _block_sizes = block_sizes; + _params = params; + _b_is_pretransposed = gemm_info.pretranpose_B(); + _reinterpret_c_as_3d = gemm_info.depth_output_gemm3d() != 0; + _alpha = alpha; + _beta = beta; auto_init_if_empty(*_tmp_c->info(), c->info()->clone()->set_tensor_shape(TensorShape{ _block_sizes.x_block * strategy::out_height(), max_num_threads })); } @@ -133,6 +134,14 @@ public: TensorAccessor<typename strategy::result_type> c(*_c); TensorAccessor<typename strategy::result_type> tmp_c(*_tmp_c); + // Handle 3d output re-interpretation + if(_reinterpret_c_as_3d) + { + Strides c_strides_as_3d = _c->info()->strides_in_bytes(); + c_strides_as_3d.remove(Window::DimZ); + c.set_strides(c_strides_as_3d); + } + int prev_batch = -1; typename strategy::operand_type *a_ptr = nullptr; auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id) @@ -216,9 +225,9 @@ private: INEGEMMWrapperKernel::Params _params{}; Window _block_walker{}; bool _b_is_pretransposed{ false }; + bool _reinterpret_c_as_3d{ false }; typename strategy::result_type _alpha{}; typename strategy::result_type _beta{}; }; - } // namespace arm_compute #endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDMATRIXMULTIPLYWRAPPER_H__ */ |