diff options
Diffstat (limited to 'src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp index 1a7d95cc2c..ad2f3a4892 100644 --- a/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp @@ -90,13 +90,14 @@ void CLLocallyConnectedMatrixMultiplyKernel::configure(const ICLTensor *input0, _input1 = input1; _output = output; + cl::NDRange lws_hint; if(output->info()->dimension(1) == 196) { - _lws_hint = cl::NDRange(1, 7); + lws_hint = cl::NDRange(1, 7); } else { - _lws_hint = cl::NDRange(8, 8); + lws_hint = cl::NDRange(8, 8); } std::ostringstream mm_arguments; @@ -114,7 +115,7 @@ void CLLocallyConnectedMatrixMultiplyKernel::configure(const ICLTensor *input0, ARM_COMPUTE_ERROR_THROW_ON(std::get<0>(win_config)); - ICLKernel::configure(std::get<1>(win_config)); + ICLKernel::configure_internal(std::get<1>(win_config), lws_hint); } Status CLLocallyConnectedMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output) @@ -142,7 +143,7 @@ void CLLocallyConnectedMatrixMultiplyKernel::run(const Window &window, cl::Comma add_2D_tensor_argument(idx, _input0, slice); add_3D_tensor_argument(idx, _input1, slice_matrix_b); add_2D_tensor_argument(idx, _output, slice); - enqueue(queue, *this, slice, _lws_hint); + enqueue(queue, *this, slice, lws_hint()); } while(window.slide_window_slice_2D(slice)); } |