aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp21
1 files changed, 14 insertions, 7 deletions
diff --git a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
index 241dd8549d..d12255ff24 100644
--- a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
+++ b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
@@ -80,8 +80,12 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
output_access.set_valid_region(win, input->valid_region());
}
+ // Collapse along the Z direction
+ // This collapse needs to be here in order to tune the Z dimension of LWS
+ Window collapsed = win.collapse(win, Window::DimZ);
+
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
+ return std::make_pair(err, collapsed);
}
} // namespace
@@ -136,6 +140,10 @@ void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *out
_config_id += support::cpp11::to_string(output->info()->dimension(0));
_config_id += "_";
_config_id += support::cpp11::to_string(output->info()->dimension(1));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(2));
+ _config_id += "_";
+ _config_id += support::cpp11::to_string(output->info()->dimension(3));
}
Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
@@ -160,15 +168,14 @@ void CLGEMMInterleave4x4Kernel::run(const Window &window, cl::CommandQueue &queu
*
* After this operation, the output matrix will have the following shape: [ height * 4, width / 4 ]
*/
- Window in_slice = window.first_slice_window_2D();
- Window out_slice = window.first_slice_window_2D();
+ Window slice = window.first_slice_window_3D();
do
{
unsigned int idx = 0;
- add_2D_tensor_argument(idx, _input, in_slice);
- add_2D_tensor_argument(idx, _output, out_slice);
- enqueue(queue, *this, in_slice, _lws_hint);
+ add_3D_tensor_argument(idx, _input, slice);
+ add_3D_tensor_argument(idx, _output, slice);
+ enqueue(queue, *this, slice, _lws_hint);
}
- while(window.slide_window_slice_2D(in_slice) && window.slide_window_slice_2D(out_slice));
+ while(window.slide_window_slice_3D(slice));
}