diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2017-11-14 15:32:57 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | 30902ed3befd225cb3a6915223d0941949b8d265 (patch) | |
tree | 6f701094ae52e8e4dc41b993febbff404e660c82 /src/core/CL/kernels/CLSoftmaxLayerKernel.cpp | |
parent | 77f0f879f8a9371e50fb1c5b2b5f7252b839883c (diff) | |
download | ComputeLibrary-30902ed3befd225cb3a6915223d0941949b8d265.tar.gz |
COMPMID-617: Add validation methods to ML CL functions.
Adds validation support to:
- CLDirectConvolution
- CLNormalizationLayer
- CLSoftmaxLayer
Change-Id: I9bd1e925e6db057c799169405f82ed21d20b87ee
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/95939
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLSoftmaxLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLSoftmaxLayerKernel.cpp | 169 |
1 file changed, 130 insertions, 39 deletions
diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp index 3eae9e5749..53a78f7c99 100644 --- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp +++ b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp @@ -39,6 +39,7 @@ #include <string> using namespace arm_compute; + namespace { /** Calculates softmax parameters from the quantized input scale and scaling factor for the exponent and places them as build options. @@ -81,8 +82,7 @@ CLBuildOptions prepare_quantized_softmax_build_options(float input_scale, float void CLLogits1DMaxKernel::configure(const ICLTensor *input, ICLTensor *output) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_NULLPTR(output); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); // Softmax across the x dimension TensorShape output_shape{ input->info()->tensor_shape() }; @@ -96,9 +96,8 @@ void CLLogits1DMaxKernel::configure(const ICLTensor *input, ICLTensor *output) input->info()->fixed_point_position(), input->info()->quantization_info()); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(CLLogits1DMaxKernel::validate(input->info(), output->info())); _input = input; _output = output; @@ -146,6 +145,26 @@ void CLLogits1DMaxKernel::configure(const ICLTensor *input, ICLTensor *output) _config_id += support::cpp11::to_string(input->info()->dimension(1)); } +Error CLLogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output) +{ + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); + + // Checks performed when 
output is configured + if(output->total_size() != 0) + { + // Softmax across the x dimension + TensorShape output_shape{ input->tensor_shape() }; + output_shape.set(0, 1); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); + } + + return Error{}; +} + CLLogits1DShiftExpSumKernel::CLLogits1DShiftExpSumKernel() : _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr) { @@ -153,8 +172,7 @@ CLLogits1DShiftExpSumKernel::CLLogits1DShiftExpSumKernel() void CLLogits1DShiftExpSumKernel::configure(const ICLTensor *input, const ICLTensor *max, ICLTensor *output, ICLTensor *sum, float beta) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output); const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->info()->data_type()); const DataType tmp_data_type = is_quantized_asymmetric ? 
DataType::S32 : input->info()->data_type(); @@ -163,18 +181,8 @@ void CLLogits1DShiftExpSumKernel::configure(const ICLTensor *input, const ICLTen auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, tmp_data_type, input->info()->fixed_point_position()); auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, tmp_data_type, input->info()->fixed_point_position()); - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output); - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum); - if(is_quantized_asymmetric) - { - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, max); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(output, sum); - } - else - { - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum); - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output, max, sum); - } + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(CLLogits1DShiftExpSumKernel::validate(input->info(), max->info(), output->info(), sum->info())); _input = input; _max = max; @@ -224,6 +232,46 @@ void CLLogits1DShiftExpSumKernel::configure(const ICLTensor *input, const ICLTen ICLKernel::configure(win); } +Error CLLogits1DShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum) +{ + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(max, sum, output); + + const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input->data_type()); + + // Checks performed when output is configured + if(output->total_size() != 0) + { + if(is_quantized_asymmetric) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + } + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + 
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output); + } + + // Checks performed when sum is configured + if(sum->total_size() != 0) + { + if(is_quantized_asymmetric) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(sum, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(max, sum); + } + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(max, sum); + } + + return Error{}; +} + void CLLogits1DShiftExpSumKernel::run(const Window &window, cl::CommandQueue &queue) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); @@ -259,17 +307,14 @@ CLLogits1DMaxShiftExpSumKernel::CLLogits1DMaxShiftExpSumKernel() void CLLogits1DMaxShiftExpSumKernel::configure(const ICLTensor *input, ICLTensor *max, ICLTensor *output, ICLTensor *sum, float beta) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output); // Output auto initialization if not yet initialized auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position()); auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position()); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output, max, sum); - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output, max, sum); - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output); - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(max, sum); + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(CLLogits1DMaxShiftExpSumKernel::validate(input->info(), max->info(), output->info(), sum->info())); _input = input; _max = max; @@ -336,6 +381,33 @@ void CLLogits1DMaxShiftExpSumKernel::configure(const ICLTensor *input, ICLTensor 
ICLKernel::configure(win); } +Error CLLogits1DMaxShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum) +{ + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(max, sum, output); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, max); + + // Checks performed when output is configured + if(output->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output); + } + + // Checks performed when sum is configured + if(sum->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(max, sum); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(max, sum); + } + + return Error{}; +} + CLLogits1DMaxShiftExpSumKernel::ParallelReductionInfo CLLogits1DMaxShiftExpSumKernel::is_parallel_reduction(size_t size) { bool is_parallel_reduction = (size >= (_grid_size * _serial_vector_size)) && (_grid_size > 1); @@ -382,10 +454,7 @@ CLLogits1DNormKernel::CLLogits1DNormKernel() void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *sum, ICLTensor *output, float beta) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum); - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output); // Note: output should always have a scale of 1/256 and offset 0 const QuantizationInfo allowed_quantization_info = 
QuantizationInfo(1.f / 256, 0); @@ -396,16 +465,8 @@ void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *su auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(output_data_type).set_quantization_info(allowed_quantization_info)); - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output); - if(!is_quantized_asymmetric) - { - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output); - } - else - { - ARM_COMPUTE_ERROR_ON(output->info()->quantization_info() != allowed_quantization_info); - } + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(CLLogits1DNormKernel::validate(input->info(), sum->info(), output->info())); _input = input; _sum = sum; @@ -439,6 +500,36 @@ void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *su ICLKernel::configure(win); } +Error CLLogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output) +{ + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(sum, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum); + + // Note: output should always have a scale of 1/256 and offset 0 + const QuantizationInfo allowed_quantization_info = QuantizationInfo(1.f / 256, 0); + const bool is_quantized_asymmetric = (input->data_type() == DataType::S32); + + // Checks performed when output is configured + if(output->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output); + if(!is_quantized_asymmetric) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + } + else + { + 
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8); + ARM_COMPUTE_RETURN_ERROR_ON(output->quantization_info() != allowed_quantization_info); + } + } + + return Error{}; +} + void CLLogits1DNormKernel::run(const Window &window, cl::CommandQueue &queue) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); |