aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp15
1 files changed, 6 insertions, 9 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 674937eff0..cc9ae27bed 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -107,7 +107,7 @@ inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *i
}
inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output,
- bool is_interleaved_transposed, GPUTarget gpu_target,
+ bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info, GPUTarget gpu_target,
ElementsProcessed &num_elements_processed)
{
bool window_changed = false;
@@ -117,6 +117,9 @@ inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *inpu
unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1];
+ // Output tensor auto inizialitation if not yet initialized
+ auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info)));
+
if(is_interleaved_transposed)
{
// Configure kernel window
@@ -182,13 +185,6 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
- // Output tensor auto inizialitation if not yet initialized
- TensorShape tensor_shape{ input0->info()->tensor_shape() };
- tensor_shape.set(0, is_interleaved_transposed ? reshape_info.n() : input1->info()->dimension(0));
- tensor_shape.set(1, is_interleaved_transposed ? reshape_info.m() : input0->info()->dimension(1));
-
- auto_init_if_empty(*output->info(), input0->info()->clone()->set_tensor_shape(tensor_shape));
-
// Perform validate step
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
@@ -246,7 +242,7 @@ void CLGEMMMatrixMultiplyKernel::configure(const ICLTensor *input0, const ICLTen
ElementsProcessed num_elements_processed{};
// Configure kernel window
- auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, gpu_target, num_elements_processed);
+ auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info, gpu_target, num_elements_processed);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure(win_config.second);
@@ -354,6 +350,7 @@ Status CLGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITe
input1->clone().get(),
output->clone().get(),
is_interleaved_transposed,
+ reshape_info,
gpu_target,
num_elements_processed)
.first);