aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
diff options
context:
space:
mode:
authorsteniu01 <steven.niu@arm.com>2017-07-13 14:24:23 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:42 +0100
commit0d523cccb549e4ff9dd231d033d612391ca31c85 (patch)
tree9361140af9bb4596db94f62fb07f7151a3c20ccf /src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
parent00394ae1015c1eaa73f4d98fad31b7771063cd3a (diff)
downloadComputeLibrary-0d523cccb549e4ff9dd231d033d612391ca31c85.tar.gz
COMPMID-443 Change CLSoftMaxLayerKernel to use 3D tensor and collapse the higer dimension
Change-Id: I730ef45d855113d8baa7d89818441e168ea43c63 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/80573 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLSoftmaxLayerKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLSoftmaxLayerKernel.cpp28
1 files changed, 15 insertions, 13 deletions
diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
index ccaf7453d1..0e81fc7aa4 100644
--- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
+++ b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
@@ -79,7 +79,7 @@ void CLLogits1DMaxKernel::configure(const ICLTensor *input, ICLTensor *output)
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("softmax_layer_max", build_opts));
// Set fixed arguments
- unsigned int idx = 2 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
+ unsigned int idx = 2 * num_arguments_per_3D_tensor(); //Skip the input and output parameters
_kernel.setArg<cl_uint>(idx++, input->info()->dimension(0));
// Configure kernel window
@@ -141,7 +141,7 @@ void CLLogits1DShiftExpSumKernel::configure(const ICLTensor *input, const ICLTen
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("softmax_layer_shift_exp_sum", build_opts));
// Set fixed arguments
- unsigned int idx = 4 * num_arguments_per_2D_tensor(); //Skip the input and output parameters
+ unsigned int idx = 4 * num_arguments_per_3D_tensor(); //Skip the input and output parameters
_kernel.setArg<cl_uint>(idx++, input->info()->dimension(0));
// Configure window
@@ -165,19 +165,20 @@ void CLLogits1DShiftExpSumKernel::run(const Window &window, cl::CommandQueue &qu
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
- Window slice = window.first_slice_window_2D();
+ Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+ Window slice = window_collapsed.first_slice_window_3D();
do
{
unsigned int idx = 0;
// Set inputs
- add_2D_tensor_argument(idx, _input, slice);
- add_2D_tensor_argument(idx, _max, slice);
- add_2D_tensor_argument(idx, _output, slice);
- add_2D_tensor_argument(idx, _sum, slice);
+ add_3D_tensor_argument(idx, _input, slice);
+ add_3D_tensor_argument(idx, _max, slice);
+ add_3D_tensor_argument(idx, _output, slice);
+ add_3D_tensor_argument(idx, _sum, slice);
enqueue(queue, *this, slice);
}
- while(window.slide_window_slice_2D(slice));
+ while(window_collapsed.slide_window_slice_3D(slice));
}
CLLogits1DNormKernel::CLLogits1DNormKernel()
@@ -233,7 +234,8 @@ void CLLogits1DNormKernel::run(const Window &window, cl::CommandQueue &queue)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
- Window slice = window.first_slice_window_2D();
+ Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+ Window slice = window_collapsed.first_slice_window_3D();
do
{
@@ -242,10 +244,10 @@ void CLLogits1DNormKernel::run(const Window &window, cl::CommandQueue &queue)
unsigned int idx = 0;
// Set inputs
- add_2D_tensor_argument(idx, _input, slice);
- add_2D_tensor_argument(idx, _sum, sum_slice);
- add_2D_tensor_argument(idx, _output, slice);
+ add_3D_tensor_argument(idx, _input, slice);
+ add_3D_tensor_argument(idx, _sum, sum_slice);
+ add_3D_tensor_argument(idx, _output, slice);
enqueue(queue, *this, slice);
}
- while(window.slide_window_slice_2D(slice));
+ while(window_collapsed.slide_window_slice_3D(slice));
}