about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
diff options
context:
space:
mode:
author Moritz Pflanzer <moritz.pflanzer@arm.com> 2017-07-27 12:45:30 +0100
committer Anthony Barbier <anthony.barbier@arm.com> 2018-09-17 14:16:42 +0100
commit ba9b3f58583331100346cf4fad16d51aa11a3677 (patch)
tree 4f93ec2cd9e4ed90c46a3ea0b37cb07ad014e3de /src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
parent 637540ac6d25193f9ae1ecb57abfad40f47edd75 (diff)
download ComputeLibrary-ba9b3f58583331100346cf4fad16d51aa11a3677.tar.gz
COMPMID-417: Fix F16 CLSoftmaxLayer
Change-Id: I231b1fcaea8bfb11f8306bc71fdde78fadeed60d
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/81832
Reviewed-by: Steven Niu <steven.niu@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLSoftmaxLayerKernel.cpp')
-rw-r--r-- src/core/CL/kernels/CLSoftmaxLayerKernel.cpp 8
1 files changed, 8 insertions, 0 deletions
diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
index 0e81fc7aa4..da3b9423d5 100644
--- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
+++ b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
@@ -68,6 +68,10 @@ void CLLogits1DMaxKernel::configure(const ICLTensor *input, ICLTensor *output)
{
build_opts.emplace(("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position())));
}
+ else if(input->info()->data_type() == DataType::F16)
+ {
+ build_opts.emplace("-DUSE_F16");
+ }
// Tell the kernel that the width is not a multiple of 16
if((input->info()->dimension(0) % max_cl_vector_width) != 0)
@@ -130,6 +134,10 @@ void CLLogits1DShiftExpSumKernel::configure(const ICLTensor *input, const ICLTen
{
build_opts.emplace(("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position())));
}
+ else if(input->info()->data_type() == DataType::F16)
+ {
+ build_opts.emplace("-DUSE_F16");
+ }
// Tell the kernel that the width is not a multiple of 16
if((input->info()->dimension(0) % max_cl_vector_width) != 0)