about summary refs log tree commit diff
path: root/src/core/GLES_COMPUTE/kernels/GCTensorShiftKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/GLES_COMPUTE/kernels/GCTensorShiftKernel.cpp')
-rw-r--r-- src/core/GLES_COMPUTE/kernels/GCTensorShiftKernel.cpp 26
1 files changed, 12 insertions, 14 deletions
diff --git a/src/core/GLES_COMPUTE/kernels/GCTensorShiftKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCTensorShiftKernel.cpp
index c2182171a6..21946b7f8d 100644
--- a/src/core/GLES_COMPUTE/kernels/GCTensorShiftKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCTensorShiftKernel.cpp
@@ -39,7 +39,7 @@ using namespace arm_compute;
using namespace arm_compute::gles_compute;
GCTensorShiftKernel::GCTensorShiftKernel()
- : _input(nullptr), _lws(gles::NDRange(1U, 1U, 1U))
+ : _input(nullptr), _lws(gles::NDRange(1U, 1U, 1U)), _left_padding(0)
{
}
@@ -59,18 +59,18 @@ void GCTensorShiftKernel::configure(IGCTensor *input)
options.emplace(("#define " + dt_name));
unsigned int num_elems_written_per_iteration_x = input->info()->dimension(0) + input->info()->padding().left + input->info()->padding().right;
- unsigned int num_elems_written_per_iteration_y = 1;
- unsigned int num_elems_written_per_iteration_z = 1;
std::stringstream kernel_name;
kernel_name << "tensorshift";
_kernel = static_cast<GCKernel>(GCKernelLibrary::get().create_kernel(kernel_name.str(), options));
- Window win = calculate_max_enlarged_window(*input->info(), Steps(num_elems_written_per_iteration_x, num_elems_written_per_iteration_y, num_elems_written_per_iteration_z));
- AccessWindowHorizontal input_access(input->info(), 0, num_elems_written_per_iteration_x);
+ Window win;
+ win.set(Window::DimX, Window::Dimension(0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_x));
+ win.use_tensor_dimensions(input->info()->tensor_shape(), Window::DimY);
+ win.use_tensor_dimensions(input->info()->tensor_shape(), Window::DimZ);
- update_window_and_padding(win, input_access);
+ _left_padding = _input->info()->padding().left;
IGCKernel::configure(win);
}
@@ -80,6 +80,11 @@ void GCTensorShiftKernel::run(const Window &window)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
+ if(int(_left_padding) == 0 || !_input->needs_shifting())
+ {
+ return;
+ }
+
_kernel.use();
// Get initial windows
@@ -92,14 +97,7 @@ void GCTensorShiftKernel::run(const Window &window)
add_3D_tensor_argument(idx, _input, 1, slice);
- const PaddingSize &padding1 = _input->info()->padding();
-
- if(int(padding1.left) == 0)
- {
- break;
- }
-
- _kernel.set_argument(idx++, static_cast<unsigned int>(padding1.left));
+ _kernel.set_argument(idx++, static_cast<unsigned int>(_left_padding));
_kernel.update_shader_params();
enqueue(*this, slice, _lws);