diff options
Diffstat (limited to 'src/core/CL/kernels/CLActivationLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLActivationLayerKernel.cpp | 41 |
1 files changed, 28 insertions, 13 deletions
diff --git a/src/core/CL/kernels/CLActivationLayerKernel.cpp b/src/core/CL/kernels/CLActivationLayerKernel.cpp index fda69b0b94..18202c1c5b 100644 --- a/src/core/CL/kernels/CLActivationLayerKernel.cpp +++ b/src/core/CL/kernels/CLActivationLayerKernel.cpp @@ -26,6 +26,7 @@ #include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/CLKernelLibrary.h" #include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/FixedPoint.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/IAccessWindow.h" #include "arm_compute/core/TensorInfo.h" @@ -33,6 +34,10 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/Window.h" +#include "support/ToolchainSupport.h" + +#include <cmath> + using namespace arm_compute; CLActivationLayerKernel::CLActivationLayerKernel() @@ -42,7 +47,7 @@ CLActivationLayerKernel::CLActivationLayerKernel() void CLActivationLayerKernel::configure(ICLTensor *input, ICLTensor *output, ActivationLayerInfo act_info) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); if(output != nullptr) { @@ -54,20 +59,33 @@ void CLActivationLayerKernel::configure(ICLTensor *input, ICLTensor *output, Act ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); } + const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size(); + const int fixed_point_position = input->info()->fixed_point_position(); + float a_const = act_info.a(); + float b_const = act_info.b(); + if(is_data_type_fixed_point(input->info()->data_type())) + { + a_const = static_cast<int>(lround(a_const * (1 << fixed_point_position))); + b_const = static_cast<int>(lround(b_const * (1 << fixed_point_position))); + } + // Set build options std::set<std::string> build_opts; - build_opts.insert(("-D" + string_from_activation_func(act_info.activation()))); - build_opts.insert(("-D" + ((is_data_type_float(input->info()->data_type())) ? std::string("TYPE_FP") : std::string("TYPE_INT")))); - build_opts.insert(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()))); - build_opts.insert(("-DA=" + support::cpp11::to_string(act_info.a()))); - build_opts.insert(("-DB=" + support::cpp11::to_string(act_info.b()))); - build_opts.insert(output == nullptr ? "-DIN_PLACE" : ""); + build_opts.emplace(("-DACT=" + lower_string(string_from_activation_func(act_info.activation())))); + build_opts.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()))); + build_opts.emplace(("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration))); + build_opts.emplace(("-DA_VAL=" + support::cpp11::to_string(a_const))); + build_opts.emplace(("-DB_VAL=" + support::cpp11::to_string(b_const))); + build_opts.emplace(output == nullptr ? "-DIN_PLACE" : ""); + if(is_data_type_fixed_point(input->info()->data_type())) + { + build_opts.emplace(("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(fixed_point_position))); + } // Create kernel _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("activation_layer", build_opts)); // Make sure _kernel is initialized before calling the parent's configure - constexpr unsigned int num_elems_processed_per_iteration = 16; _input = input; _output = output; @@ -77,12 +95,9 @@ void CLActivationLayerKernel::configure(ICLTensor *input, ICLTensor *output, Act if(output != nullptr) { + AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration); AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, - AccessWindowHorizontal(input->info(), 0, num_elems_processed_per_iteration), - output_access); - + update_window_and_padding(win, input_access, output_access); output_access.set_valid_region(win, input->info()->valid_region()); } else |