aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLSoftmaxLayerKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLSoftmaxLayerKernel.cpp169
1 files changed, 130 insertions, 39 deletions
diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
index 3eae9e5749..53a78f7c99 100644
--- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
+++ b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
@@ -39,6 +39,7 @@
#include <string>
using namespace arm_compute;
+
namespace
{
/** Calculates softmax parameters from the quantized input scale and scaling factor for the exponent and places them as build options.
@@ -81,8 +82,7 @@ CLBuildOptions prepare_quantized_softmax_build_options(float input_scale, float
void CLLogits1DMaxKernel::configure(const ICLTensor *input, ICLTensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Softmax across the x dimension
TensorShape output_shape{ input->info()->tensor_shape() };
@@ -96,9 +96,8 @@ void CLLogits1DMaxKernel::configure(const ICLTensor *input, ICLTensor *output)
input->info()->fixed_point_position(),
input->info()->quantization_info());
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
+ // Perform validation step
+ ARM_COMPUTE_ERROR_THROW_ON(CLLogits1DMaxKernel::validate(input->info(), output->info()));
_input = input;
_output = output;
@@ -146,6 +145,26 @@ void CLLogits1DMaxKernel::configure(const ICLTensor *input, ICLTensor *output)
_config_id += support::cpp11::to_string(input->info()->dimension(1));
}
+Error CLLogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+
+ // Checks performed when output is configured
+ if(output->total_size() != 0)
+ {
+ // Softmax across the x dimension
+ TensorShape output_shape{ input->tensor_shape() };
+ output_shape.set(0, 1);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
+ }
+
+ return Error{};
+}
+
CLLogits1DShiftExpSumKernel::CLLogits1DShiftExpSumKernel()
: _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr)
{
@@ -153,8 +172,7 @@ CLLogits1DShiftExpSumKernel::CLLogits1DShiftExpSumKernel()
void CLLogits1DShiftExpSumKernel::configure(const ICLTensor *input, const ICLTensor *max, ICLTensor *output, ICLTensor *sum, float beta)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output);
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->info()->data_type());
const DataType tmp_data_type = is_quantized_asymmetric ? DataType::S32 : input->info()->data_type();
@@ -163,18 +181,8 @@ void CLLogits1DShiftExpSumKernel::configure(const ICLTensor *input, const ICLTen
auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, tmp_data_type, input->info()->fixed_point_position());
auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, tmp_data_type, input->info()->fixed_point_position());
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum);
- if(is_quantized_asymmetric)
- {
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, max);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(output, sum);
- }
- else
- {
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output, max, sum);
- }
+ // Perform validation step
+ ARM_COMPUTE_ERROR_THROW_ON(CLLogits1DShiftExpSumKernel::validate(input->info(), max->info(), output->info(), sum->info()));
_input = input;
_max = max;
@@ -224,6 +232,46 @@ void CLLogits1DShiftExpSumKernel::configure(const ICLTensor *input, const ICLTen
ICLKernel::configure(win);
}
+Error CLLogits1DShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(max, sum, output);
+
+ const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->data_type());
+
+ // Checks performed when output is configured
+ if(output->total_size() != 0)
+ {
+ if(is_quantized_asymmetric)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
+ }
+
+ // Checks performed when sum is configured
+ if(sum->total_size() != 0)
+ {
+ if(is_quantized_asymmetric)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(sum, 1, DataType::S32);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(max, sum);
+ }
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(max, sum);
+ }
+
+ return Error{};
+}
+
void CLLogits1DShiftExpSumKernel::run(const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
@@ -259,17 +307,14 @@ CLLogits1DMaxShiftExpSumKernel::CLLogits1DMaxShiftExpSumKernel()
void CLLogits1DMaxShiftExpSumKernel::configure(const ICLTensor *input, ICLTensor *max, ICLTensor *output, ICLTensor *sum, float beta)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output);
// Output auto initialization if not yet initialized
auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output, max, sum);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum);
+ // Perform validation step
+ ARM_COMPUTE_ERROR_THROW_ON(CLLogits1DMaxShiftExpSumKernel::validate(input->info(), max->info(), output->info(), sum->info()));
_input = input;
_max = max;
@@ -336,6 +381,33 @@ void CLLogits1DMaxShiftExpSumKernel::configure(const ICLTensor *input, ICLTensor
ICLKernel::configure(win);
}
+Error CLLogits1DMaxShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(max, sum, output);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, max);
+
+ // Checks performed when output is configured
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
+ }
+
+ // Checks performed when sum is configured
+ if(sum->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(max, sum);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(max, sum);
+ }
+
+ return Error{};
+}
+
CLLogits1DMaxShiftExpSumKernel::ParallelReductionInfo CLLogits1DMaxShiftExpSumKernel::is_parallel_reduction(size_t size)
{
bool is_parallel_reduction = (size >= (_grid_size * _serial_vector_size)) && (_grid_size > 1);
@@ -382,10 +454,7 @@ CLLogits1DNormKernel::CLLogits1DNormKernel()
void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *sum, ICLTensor *output, float beta)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output);
// Note: output should always have a scale of 1/256 and offset 0
const QuantizationInfo allowed_quantization_info = QuantizationInfo(1.f / 256, 0);
@@ -396,16 +465,8 @@ void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *su
auto_init_if_empty(*output->info(),
input->info()->clone()->set_data_type(output_data_type).set_quantization_info(allowed_quantization_info));
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
- if(!is_quantized_asymmetric)
- {
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
- }
- else
- {
- ARM_COMPUTE_ERROR_ON(output->info()->quantization_info() != allowed_quantization_info);
- }
+ // Perform validation step
+ ARM_COMPUTE_ERROR_THROW_ON(CLLogits1DNormKernel::validate(input->info(), sum->info(), output->info()));
_input = input;
_sum = sum;
@@ -439,6 +500,36 @@ void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *su
ICLKernel::configure(win);
}
+Error CLLogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
+{
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(sum, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum);
+
+ // Note: output should always have a scale of 1/256 and offset 0
+ const QuantizationInfo allowed_quantization_info = QuantizationInfo(1.f / 256, 0);
+ const bool is_quantized_asymmetric = (input->data_type() == DataType::S32);
+
+ // Checks performed when output is configured
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
+ if(!is_quantized_asymmetric)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
+ ARM_COMPUTE_RETURN_ERROR_ON(output->quantization_info() != allowed_quantization_info);
+ }
+ }
+
+ return Error{};
+}
+
void CLLogits1DNormKernel::run(const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);