 arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h |   3
 arm_compute/core/KernelDescriptors.h               |   5
 arm_compute/core/Utils.h                           |   9
 src/core/CL/cl_kernels/softmax_layer_quantized.cl  | 103
 src/core/CL/kernels/CLReshapeLayerKernel.cpp       |   8
 src/core/CL/kernels/CLSoftmaxLayerKernel.cpp       |  44
 src/core/Utils.cpp                                 |  21
 src/runtime/CL/functions/CLSoftmaxLayer.cpp        |  12
 tests/validation/CL/SoftmaxLayer.cpp               |  36
 tests/validation/Helpers.h                         |   2
 tests/validation/fixtures/SoftmaxLayerFixture.h    |   4
 tests/validation/reference/LogSoftmaxLayer.cpp     |   8
 tests/validation/reference/LogSoftmaxLayer.h       |   2
 tests/validation/reference/SoftmaxLayer.cpp        |   8
 tests/validation/reference/SoftmaxLayer.h          |   2
 15 files changed, 175 insertions(+), 92 deletions(-)
diff --git a/arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h b/arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h
index 93e403e257..f64739ae32 100644
--- a/arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h
+++ b/arm_compute/core/CL/kernels/CLSoftmaxLayerKernel.h
@@ -187,10 +187,11 @@ public:
* @param[in] input Source tensor. Data types supported: S32/F16/F32
* @param[in] sum Sum tensor. Dimensions should be dim(input)-1. Data types supported: same as @p input
* @param[in] output Destination tensor. Data types supported: QASYMM8 for S32 @p input, or same as @p input
+ * @param[in] info Contains information consumed by the softmax kernels, described in @ref SoftmaxKernelInfo.
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info);
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h
index 2aa076246e..f358153b0d 100644
--- a/arm_compute/core/KernelDescriptors.h
+++ b/arm_compute/core/KernelDescriptors.h
@@ -79,8 +79,9 @@ struct DWCWeightsKernelInfo
/** Descriptor used by the softmax kernels */
struct SoftmaxKernelInfo
{
- float beta{ 1.f }; /**< A scaling factor for the exponent with default value 1.0 */
- bool is_log{ false }; /**< Flag used to perform Log Softmax operation */
+ float beta{ 1.f }; /**< A scaling factor for the exponent with default value 1.0 */
+ bool is_log{ false }; /**< Flag used to perform Log Softmax operation */
+ DataType input_data_type{ DataType::UNKNOWN }; /**< Input tensor data type */
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CORE_KERNEL_DESCRIPTORS_H */
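The new input_data_type field is expected to be filled in by the runtime before the kernels are configured; a minimal sketch of such caller code (mirroring the change to CLSoftmaxLayer.cpp later in this patch, with input standing in for the caller's tensor):

    SoftmaxKernelInfo softmax_info;
    softmax_info.beta            = 1.f;
    softmax_info.is_log          = false;
    softmax_info.input_data_type = input->info()->data_type(); // e.g. DataType::QASYMM8_SIGNED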
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 18c5471f8f..0a7eeefded 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -958,6 +958,15 @@ std::pair<unsigned int, unsigned int> scaled_dimensions(unsigned int width, unsi
*/
bool needs_serialized_reduction(ReductionOperation op, DataType dt, unsigned int axis);
+/** Returns output quantization information for softmax layer
+ *
+ * @param[in] input_type The data type of the input tensor
+ * @param[in] is_log True for log softmax
+ *
+ * @return Quantization information for the output tensor
+ */
+QuantizationInfo get_softmax_output_quantization_info(DataType input_type, bool is_log);
+
/** Convert a tensor format into a string.
*
* @param[in] format @ref Format to be translated to string.
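A minimal usage sketch for the new helper (hypothetical caller code, not part of the patch):

    // Fixed output quantization for a signed, non-log softmax.
    const QuantizationInfo out_qinfo =
        get_softmax_output_quantization_info(DataType::QASYMM8_SIGNED, /* is_log */ false);
    // out_qinfo is QuantizationInfo(1.f / 256, -128); see the table in src/core/Utils.cpp below.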
diff --git a/src/core/CL/cl_kernels/softmax_layer_quantized.cl b/src/core/CL/cl_kernels/softmax_layer_quantized.cl
index ce3bd7bc43..5d35e50b1f 100644
--- a/src/core/CL/cl_kernels/softmax_layer_quantized.cl
+++ b/src/core/CL/cl_kernels/softmax_layer_quantized.cl
@@ -63,6 +63,7 @@ __constant uint16 idx__ = (uint16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
#define VEC_UCHAR VEC_DATA_TYPE(uchar, VECTOR_SIZE)
#define VEC_UINT VEC_DATA_TYPE(uint, VECTOR_SIZE)
#define VEC_INT VEC_DATA_TYPE(int, VECTOR_SIZE)
+#define VEC_BASE VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
#if defined(DIFF_MIN)
@@ -141,43 +142,43 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_serial(
Image maxo = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(maxo);
Image sum = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(sum);
- VEC_UCHAR max_val_vec = 0;
+ VEC_BASE max_val_vec = (VEC_BASE)(MIN_VALUE);
// Calculate max of row
const uint width4 = width >> LOG_VECTOR_SIZE;
for(uint i = 0; i < width4; i++)
{
- VEC_UCHAR data = VLOAD(VECTOR_SIZE)(0, (__global uchar *)offset(&src, i << LOG_VECTOR_SIZE, 0));
- max_val_vec = MAX_OP(data, max_val_vec, uchar, 16);
+ VEC_BASE data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)offset(&src, i << LOG_VECTOR_SIZE, 0));
+ max_val_vec = MAX_OP(data, max_val_vec, DATA_TYPE, 16);
}
#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
// Handle non multiple of 16
- VEC_UCHAR uchar_min = (VEC_UCHAR)0;
- VEC_UCHAR data = VLOAD(VECTOR_SIZE)(0, (__global uchar *)offset(&src, width4 << LOG_VECTOR_SIZE, 0));
- VEC_UCHAR widx = CONVERT(((VEC_UINT)(width4 << LOG_VECTOR_SIZE) + idx__) < width, VEC_UCHAR);
- max_val_vec = MAX_OP(max_val_vec, select(uchar_min, data, widx), uchar, 16);
+ VEC_BASE vec_min_val = (VEC_BASE)(MIN_VALUE);
+ VEC_BASE data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)offset(&src, width4 << LOG_VECTOR_SIZE, 0));
+ VEC_UCHAR widx = CONVERT(((VEC_UINT)(width4 << LOG_VECTOR_SIZE) + idx__) < width, VEC_UCHAR);
+ max_val_vec = MAX_OP(max_val_vec, select(vec_min_val, data, widx), DATA_TYPE, 16);
#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
// Perform max reduction
#if VECTOR_SIZE == 16
- max_val_vec.s01234567 = MAX_OP(max_val_vec.s01234567, max_val_vec.s89ABCDEF, uchar, 8);
+ max_val_vec.s01234567 = MAX_OP(max_val_vec.s01234567, max_val_vec.s89ABCDEF, DATA_TYPE, 8);
#endif /* VECTOR SIZE 16 END */
#if VECTOR_SIZE >= 8
- max_val_vec.s0123 = MAX_OP(max_val_vec.s0123, max_val_vec.s4567, uchar, 4);
+ max_val_vec.s0123 = MAX_OP(max_val_vec.s0123, max_val_vec.s4567, DATA_TYPE, 4);
#endif /* VECTOR SIZE 8 END */
#if VECTOR_SIZE >= 4
- max_val_vec.s01 = MAX_OP(max_val_vec.s01, max_val_vec.s23, uchar, 2);
+ max_val_vec.s01 = MAX_OP(max_val_vec.s01, max_val_vec.s23, DATA_TYPE, 2);
#endif /* VECTOR SIZE 4 END */
- max_val_vec.s0 = MAX_OP(max_val_vec.s0, max_val_vec.s1, uchar, 1);
+ max_val_vec.s0 = MAX_OP(max_val_vec.s0, max_val_vec.s1, DATA_TYPE, 1);
// Store result
- *((__global uchar *)maxo.ptr) = max_val_vec.s0;
+ *((__global DATA_TYPE *)maxo.ptr) = max_val_vec.s0;
// Second part
// Load max value of 1D logits vector (row)
- int max_val = convert_int(*((__global uchar *)offset(&maxo, 0, 0)));
+ int max_val = convert_int(*((__global DATA_TYPE *)offset(&maxo, 0, 0)));
// Set sum vector, Q(EXP_ACCUMULATION_INT_BITS)
VEC_INT sum1D = 0;
@@ -185,7 +186,7 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_serial(
// Shift values, exp and sum
for(uint i = 0; i < width4; i++)
{
- VEC_UCHAR data = VLOAD(VECTOR_SIZE)(0, (__global uchar *)offset(&src, i << LOG_VECTOR_SIZE, 0));
+ VEC_BASE data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)offset(&src, i << LOG_VECTOR_SIZE, 0));
VEC_INT data_fp = CONVERT(data, VEC_INT);
VEC_INT data_diff = data_fp - max_val;
VEC_INT data_diff_mult = mult_by_quantized_multiplier_serial(data_diff);
@@ -193,12 +194,12 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_serial(
data_fp = asymm_rescale(data_fp, 0, EXP_ACCUMULATION_INT_BITS);
VSTORE(VECTOR_SIZE)
(data_diff, 0, (__global int *)offset(&dst, i << LOG_VECTOR_SIZE, 0));
- sum1D = sum1D + select(0, data_fp, data_diff >= (VEC_INT)(DIFF_MIN));
+ sum1D = sum1D + select(MIN_VALUE, data_fp, data_diff >= (VEC_INT)(DIFF_MIN));
}
#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
// Handle non multiple of 16
- data = VLOAD(VECTOR_SIZE)(0, (__global uchar *)offset(&src, width4 << LOG_VECTOR_SIZE, 0));
+ data = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)offset(&src, width4 << LOG_VECTOR_SIZE, 0));
VEC_INT data_fp = CONVERT(data, VEC_INT);
VEC_INT data_diff = data_fp - max_val;
VEC_INT data_diff_mult = mult_by_quantized_multiplier_serial(data_diff);
@@ -207,21 +208,21 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_serial(
VEC_INT widx_ = CONVERT(((VEC_UINT)(width4 << LOG_VECTOR_SIZE) + idx__) < width, VEC_INT);
VSTORE(VECTOR_SIZE)
(data_diff, 0, (__global int *)offset(&dst, width4 << LOG_VECTOR_SIZE, 0));
- data_fp = select(0, data_fp, data_diff >= (VEC_INT)(DIFF_MIN));
- sum1D = sum1D + select(0, data_fp, widx_);
+ data_fp = select(MIN_VALUE, data_fp, data_diff >= (VEC_INT)(DIFF_MIN));
+ sum1D = sum1D + select(MIN_VALUE, data_fp, widx_);
#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
// Perform sum reduction
#if VECTOR_SIZE == 16
- sum1D.s01234567 = ADD_OP(sum1D.s01234567, sum1D.s89ABCDEF, uchar, 8);
+ sum1D.s01234567 = ADD_OP(sum1D.s01234567, sum1D.s89ABCDEF, DATA_TYPE, 8);
#endif /* VECTOR SIZE 16 END */
#if VECTOR_SIZE >= 8
- sum1D.s0123 = ADD_OP(sum1D.s0123, sum1D.s4567, uchar, 4);
+ sum1D.s0123 = ADD_OP(sum1D.s0123, sum1D.s4567, DATA_TYPE, 4);
#endif /* VECTOR SIZE 8 END */
#if VECTOR_SIZE >= 4
- sum1D.s01 = ADD_OP(sum1D.s01, sum1D.s23, uchar, 2);
+ sum1D.s01 = ADD_OP(sum1D.s01, sum1D.s23, DATA_TYPE, 2);
#endif /* VECTOR SIZE 4 END */
- sum1D.s0 = ADD_OP(sum1D.s0, sum1D.s1, uchar, 1);
+ sum1D.s0 = ADD_OP(sum1D.s0, sum1D.s1, DATA_TYPE, 1);
// Calculate and store result
*((__global int *)sum.ptr) = sum1D.s0;
@@ -284,10 +285,12 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_parallel(
// Define one temporary vector per work-item.
__local int4 tmp_local[GRID_SIZE];
- __local uchar max_local;
+ __local DATA_TYPE max_local;
- uchar4 uchar_min = (uchar4)0;
- uchar4 max_val_vec = uchar_min;
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ vec_min_val = (VEC_DATA_TYPE(DATA_TYPE, 4))(MIN_VALUE);
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ max_val_vec = vec_min_val;
// Number of elements per work-item.
const uint row = width / GRID_SIZE;
@@ -297,8 +300,9 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_parallel(
uint i = 0;
for(; i < width_; i++)
{
- uchar4 data_max = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4, 0));
- max_val_vec = MAX_OP(data_max, max_val_vec, uchar, 4);
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ data_max = vload4(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4, 0));
+ max_val_vec = MAX_OP(data_max, max_val_vec, DATA_TYPE, 4);
}
#ifdef NON_MULTIPLE_OF_GRID_SIZE
// How many work-items needed to complete the computation.
@@ -306,8 +310,9 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_parallel(
int boundary_workitems = (width % (GRID_SIZE * 4)) / 4;
if(lid < boundary_workitems)
{
- uchar4 data_max = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4, 0));
- max_val_vec = MAX_OP(data_max, max_val_vec, uchar, 4);
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ data_max = vload4(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4, 0));
+ max_val_vec = MAX_OP(data_max, max_val_vec, DATA_TYPE, 4);
}
#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
if(boundary_workitems == 0)
@@ -318,9 +323,11 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_parallel(
if(lid == (boundary_workitems - 1))
{
// Handle non multiple of 4
- uchar4 data_max = vload4(0, (__global uchar *)offset(&src, (GRID_SIZE * i * 4) + 4, 0));
- uchar4 widx = convert_uchar4(((uint4)(GRID_SIZE * i * 4) + boundary_workitems * 4 + idx4) < width);
- max_val_vec = MAX_OP(max_val_vec, select(uchar_min, data_max, widx), uchar, 4);
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ data_max = vload4(0, (__global DATA_TYPE *)offset(&src, (GRID_SIZE * i * 4) + 4, 0));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ widx = CONVERT((((uint4)(GRID_SIZE * i * 4) + boundary_workitems * 4 + idx4) < width), VEC_DATA_TYPE(DATA_TYPE, 4));
+ max_val_vec = MAX_OP(max_val_vec, select(vec_min_val, data_max, widx), DATA_TYPE, 4);
}
#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
#endif /* NON_MULTIPLE_OF_GRID_SIZE */
@@ -386,9 +393,9 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_parallel(
}
if(lid == 0)
{
- max_val_vec = MAX_OP(convert_uchar4(tmp_local[lid + 1]), convert_uchar4(tmp_local[lid]), uchar, 4);
- max_val_vec.s01 = MAX_OP(max_val_vec.s01, max_val_vec.s23, uchar, 2);
- max_val_vec.s0 = MAX_OP(max_val_vec.s0, max_val_vec.s1, uchar, 1);
+ max_val_vec = MAX_OP(CONVERT((tmp_local[lid + 1]), VEC_DATA_TYPE(DATA_TYPE, 4)), CONVERT((tmp_local[lid]), VEC_DATA_TYPE(DATA_TYPE, 4)), DATA_TYPE, 4);
+ max_val_vec.s01 = MAX_OP(max_val_vec.s01, max_val_vec.s23, DATA_TYPE, 2);
+ max_val_vec.s0 = MAX_OP(max_val_vec.s0, max_val_vec.s1, DATA_TYPE, 1);
max_local = max_val_vec.s0;
}
barrier(CLK_LOCAL_MEM_FENCE);
@@ -402,28 +409,30 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_parallel(
// Shift values, exp and sum
for(i = 0; i < width_; i++)
{
- uchar4 data = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4, 0));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ data = vload4(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4, 0));
int4 data_fp = convert_int4(data);
int4 data_diff = data_fp - max_val;
int4 data_diff_mult = mult_by_quantized_multiplier_parallel(data_diff);
data_fp = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 4);
data_fp = ASYMM_RESCALE(data_fp, 0, EXP_ACCUMULATION_INT_BITS, 4);
vstore4(data_diff, 0, (__global int *)offset(&dst, i * GRID_SIZE * 4, 0));
- sum1D = sum1D + select(0, data_fp, data_diff >= (int4)(DIFF_MIN));
+ sum1D = sum1D + select(MIN_VALUE, data_fp, data_diff >= (int4)(DIFF_MIN));
}
#ifdef NON_MULTIPLE_OF_GRID_SIZE
//TODO: Optimize the calculation (avoid %).
boundary_workitems = (width % (GRID_SIZE * 4)) / 4;
if(lid < boundary_workitems)
{
- uchar4 data = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4, 0));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ data = vload4(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4, 0));
int4 data_fp = convert_int4(data);
int4 data_diff = data_fp - max_val;
int4 data_diff_mult = mult_by_quantized_multiplier_parallel(data_diff);
data_fp = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 4);
data_fp = ASYMM_RESCALE(data_fp, 0, EXP_ACCUMULATION_INT_BITS, 4);
vstore4(data_diff, 0, (__global int *)offset(&dst, i * GRID_SIZE * 4, 0));
- sum1D = sum1D + select(0, data_fp, data_diff >= (int4)(DIFF_MIN));
+ sum1D = sum1D + select(MIN_VALUE, data_fp, data_diff >= (int4)(DIFF_MIN));
}
#ifdef NON_MULTIPLE_OF_VECTOR_SIZE
if(boundary_workitems == 0)
@@ -434,16 +443,17 @@ __kernel void softmax_layer_max_shift_exp_sum_quantized_parallel(
if(lid == (boundary_workitems - 1))
{
// Handle non multiple of vector size ((GRID_SIZE * i * 4) + 4, 0); move 4 float positions ahead, *4 is due to the stride
- uchar4 data = vload4(0, (__global uchar *)offset(&src, i * GRID_SIZE * 4 + 4, 0));
+ VEC_DATA_TYPE(DATA_TYPE, 4)
+ data = vload4(0, (__global DATA_TYPE *)offset(&src, i * GRID_SIZE * 4 + 4, 0));
int4 data_fp = convert_int4(data);
int4 data_diff = data_fp - max_val;
int4 data_diff_mult = mult_by_quantized_multiplier_parallel(data_diff);
data_fp = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 4);
data_fp = ASYMM_RESCALE(data_fp, 0, EXP_ACCUMULATION_INT_BITS, 4);
int4 widx = convert_int4(((uint4)(GRID_SIZE * i * 4) + boundary_workitems * 4 + idx4) < width);
- data_fp = select(0, data_fp, widx);
+ data_fp = select(MIN_VALUE, data_fp, widx);
vstore4(data_diff, 0, (__global int *)offset(&dst, i * GRID_SIZE * 4 + 4, 0));
- sum1D = sum1D + select(0, data_fp, data_diff >= (int4)(DIFF_MIN));
+ sum1D = sum1D + select(MIN_VALUE, data_fp, data_diff >= (int4)(DIFF_MIN));
}
#endif /* NON_MULTIPLE_OF_VECTOR_SIZE */
#endif /* NON_MULTIPLE_OF_GRID_SIZE */
@@ -582,13 +592,16 @@ __kernel void softmax_layer_norm_quantized(
#ifdef LOG_SOFTMAX
long16 data = SUB_OP(convert_long16(data_diff_mult), (long16)(sum_val), long, 16);
data = select(0L, data, convert_long16(data_diff) >= (long16)(DIFF_MIN));
-#else /* LOG_SOFTMAX */
+#else /* LOG_SOFTMAX */
int16 data = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 16);
data = ASYMM_MULT(shifted_scale, data, 16);
data = ASYMM_ROUNDING_DIVIDE_BY_POW2(data, num_bits_over_unit + 31 - 8, 16);
- data = select(0, data, data_diff >= (int16)(DIFF_MIN));
+#ifdef QASYMM8_SIGNED
+ data = ADD_OP(data, (int16)(MIN_VALUE), int, 16);
+#endif /* QASYMM8_SIGNED */
+ data = select(MIN_VALUE, data, data_diff >= (int16)(DIFF_MIN));
#endif /* LOG_SOFTMAX */
- vstore16(convert_uchar16_sat(data), 0, (__global uchar *)offset(&dst, 0, 0));
+ vstore16(CONVERT_SAT(data, VEC_DATA_TYPE(DATA_TYPE, 16)), 0, (__global DATA_TYPE *)offset(&dst, 0, 0));
}
#endif /* defined(DIFF_MIN) */
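For context, DATA_TYPE and MIN_VALUE in the kernel above are injected when the host compiles the kernel; a sketch of the defines expected for a QASYMM8_SIGNED input (illustrative only — the real option strings are assembled in CLSoftmaxLayerKernel.cpp below):

    // See CLLogits1DMaxShiftExpSumKernel::configure() in this patch for the actual code.
    build_opts.add_option("-DDATA_TYPE=char"); // get_cl_type_from_data_type(QASYMM8_SIGNED)
    build_opts.add_option("-DMIN_VALUE=-128"); // CL_SCHAR_MIN
    build_opts.add_option("-DQASYMM8_SIGNED");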
diff --git a/src/core/CL/kernels/CLReshapeLayerKernel.cpp b/src/core/CL/kernels/CLReshapeLayerKernel.cpp
index 040e442845..a6053d97e3 100644
--- a/src/core/CL/kernels/CLReshapeLayerKernel.cpp
+++ b/src/core/CL/kernels/CLReshapeLayerKernel.cpp
@@ -44,11 +44,9 @@ namespace
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
- DataType::U16, DataType::S16,
- DataType::U32, DataType::S32, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
@@ -73,7 +71,7 @@ void CLReshapeLayerKernel::configure(const ICLTensor *input, ICLTensor *output)
_output = output;
// Create kernel
- std::set<std::string> build_opts = { "-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()) };
+ std::set<std::string> build_opts = { "-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(input->info()->element_size()) };
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("reshape_layer", build_opts));
// Add static arguments
diff --git a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
index f24c25f507..215fa838c4 100644
--- a/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
+++ b/src/core/CL/kernels/CLSoftmaxLayerKernel.cpp
@@ -84,7 +84,7 @@ CLBuildOptions prepare_quantized_softmax_build_options(float input_scale, float
Status validate_arguments_1DMaxShiftExpSum(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum)
{
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(max, sum, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max);
@@ -122,7 +122,7 @@ Status validate_arguments_1DMaxShiftExpSum(const ITensorInfo *input, const ITens
return Status{};
}
-Status validate_arguments_1DNorm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
+Status validate_arguments_1DNorm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::S32, DataType::F16, DataType::F32);
@@ -130,8 +130,8 @@ Status validate_arguments_1DNorm(const ITensorInfo *input, const ITensorInfo *su
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
// Note: output should always have a scale of 1/256 and offset 0
- const QuantizationInfo allowed_quantization_info = QuantizationInfo(1.f / 256, 0);
- const bool is_quantized_asymmetric = (input->data_type() == DataType::S32);
+ const QuantizationInfo allowed_quantization_info = get_softmax_output_quantization_info(info.input_data_type, info.is_log);
+ const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(info.input_data_type);
// Checks performed when output is configured
if(output->total_size() != 0)
@@ -143,7 +143,7 @@ Status validate_arguments_1DNorm(const ITensorInfo *input, const ITensorInfo *su
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON(output->quantization_info() != allowed_quantization_info);
}
}
@@ -178,11 +178,10 @@ std::pair<Status, Window> validate_and_configure_window_1DMaxShiftExpSum(ITensor
return std::make_pair(err, win);
}
-std::pair<Status, Window> validate_and_configure_window_1DNorm(ITensorInfo *input, ITensorInfo *output, ITensorInfo *sum)
+std::pair<Status, Window> validate_and_configure_window_1DNorm(ITensorInfo *input, ITensorInfo *output, ITensorInfo *sum, const SoftmaxKernelInfo &info)
{
- const QuantizationInfo allowed_quantization_info = QuantizationInfo(1.f / 256, 0);
- const bool is_quantized_asymmetric = (input->data_type() == DataType::S32);
- const DataType output_data_type = is_quantized_asymmetric ? DataType::QASYMM8 : input->data_type();
+ const DataType output_data_type = info.input_data_type;
+ const QuantizationInfo allowed_quantization_info = get_softmax_output_quantization_info(info.input_data_type, info.is_log);
// Output auto initialization if not yet initialized
auto_init_if_empty(*output,
@@ -238,10 +237,14 @@ void CLLogits1DMaxShiftExpSumKernel::configure(const ICLTensor *input, ICLTensor
const UniformQuantizationInfo qinfo = input->info()->quantization_info().uniform();
const size_t reduction_dim_size = input->info()->dimension(0);
const float beta = info.beta;
+ const auto is_signed_qasymm8 = is_data_type_quantized_asymmetric_signed(info.input_data_type);
+ const int min_value = is_signed_qasymm8 ? CL_SCHAR_MIN : 0;
// Set build options
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dt));
+ build_opts.add_option("-DMIN_VALUE=" + support::cpp11::to_string(min_value));
+ build_opts.add_option_if(is_signed_qasymm8, "-DQASYMM8_SIGNED");
build_opts.add_option_if(dt == DataType::F16, "-DUSE_F16");
build_opts.add_option_if(is_data_type_float(dt) && (beta != 1.0f), "-DBETA=" + float_to_string_with_full_precision(beta));
build_opts.add_options_if(is_data_type_quantized_asymmetric(dt), prepare_quantized_softmax_build_options(qinfo.scale, beta).options());
@@ -342,9 +345,9 @@ void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *su
ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output);
// Note: output should always have a scale of 1/256 and offset 0
- const QuantizationInfo allowed_quantization_info = QuantizationInfo(1.F / 256, 0);
- const bool is_quantized_asymmetric = (input->info()->data_type() == DataType::S32);
- const DataType output_data_type = is_quantized_asymmetric ? DataType::QASYMM8 : input->info()->data_type();
+ const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(info.input_data_type);
+ const DataType output_data_type = info.input_data_type;
+ const QuantizationInfo allowed_quantization_info = get_softmax_output_quantization_info(info.input_data_type, info.is_log);
const UniformQuantizationInfo qinfo = input->info()->quantization_info().uniform();
// Output auto initialization if not yet initialized
@@ -352,15 +355,20 @@ void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *su
input->info()->clone()->set_data_type(output_data_type).set_quantization_info(allowed_quantization_info));
// Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_1DNorm(input->info(), sum->info(), output->info()));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_1DNorm(input->info(), sum->info(), output->info(), info));
_input = input;
_sum = sum;
_output = output;
+ const auto is_signed_qasymm8 = is_data_type_quantized_asymmetric_signed(info.input_data_type);
+ const int min_value = is_signed_qasymm8 ? CL_SCHAR_MIN : 0;
+
// Set build options
CLBuildOptions build_opts;
- build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(info.input_data_type));
+ build_opts.add_option("-DMIN_VALUE=" + support::cpp11::to_string(min_value));
+ build_opts.add_option_if(is_data_type_quantized_asymmetric_signed(info.input_data_type), "-DQASYMM8_SIGNED");
build_opts.add_options_if(is_quantized_asymmetric,
prepare_quantized_softmax_build_options(qinfo.scale, info.beta).options());
build_opts.add_option_if(info.is_log, "-DLOG_SOFTMAX");
@@ -370,15 +378,15 @@ void CLLogits1DNormKernel::configure(const ICLTensor *input, const ICLTensor *su
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
// Configure window
- auto win_config = validate_and_configure_window_1DNorm(input->info(), output->info(), sum->info());
+ auto win_config = validate_and_configure_window_1DNorm(input->info(), output->info(), sum->info(), info);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure_internal(win_config.second);
}
-Status CLLogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
+Status CLLogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, const SoftmaxKernelInfo &info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_1DNorm(input, sum, output));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_1DNorm(input->clone().get(), output->clone().get(), sum->clone().get()).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_1DNorm(input, sum, output, info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_1DNorm(input->clone().get(), output->clone().get(), sum->clone().get(), info).first);
return Status{};
}
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index fa56118587..7e7dea5e34 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -438,6 +438,27 @@ bool arm_compute::needs_serialized_reduction(ReductionOperation op, DataType dt,
return !is_first_dim || is_min_max || is_quantized_type;
}
+QuantizationInfo arm_compute::get_softmax_output_quantization_info(DataType input_type, bool is_log)
+{
+ // Note: Output quantization info for softmax should always have
+ // * Softmax with QASYMM8: scale = 1/256, offset = 0
+ // * Softmax with QASYMM8_SIGNED: scale = 1/256, offset = -128
+ // * LogSoftmax with QASYMM8: scale = 1/256, offset = 0
+ // * LogSoftmax with QASYMM8_SIGNED: scale = 16/256, offset = 127
+ if(is_data_type_quantized_asymmetric_signed(input_type))
+ {
+ if(is_log)
+ {
+ return QuantizationInfo(16.f / 256, 127);
+ }
+ else
+ {
+ return QuantizationInfo(1.f / 256, -128);
+ }
+ }
+ return QuantizationInfo(1.f / 256, 0);
+}
+
#ifdef ARM_COMPUTE_ASSERTS_ENABLED
void arm_compute::print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim)
{
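A quick numeric check of the mapping above (illustrative sketch, not part of the patch): dequantization computes (q - offset) * scale, so scale = 1/256 with offset = -128 maps the signed byte range [-128, 127] onto [0, 255/256], covering softmax outputs in [0, 1); scale = 16/256 with offset = 127 maps it onto [-15.9375, 0], covering the non-positive outputs of log softmax.

    // Standalone sketch verifying the ranges implied by the comment above.
    #include <cassert>
    int main()
    {
        const float scale = 1.f / 256; // QASYMM8_SIGNED softmax
        assert(((-128) - (-128)) * scale == 0.f);        // lowest code  -> 0.0
        assert((127 - (-128)) * scale == 255.f / 256);   // highest code -> just under 1.0
        const float log_scale = 16.f / 256; // QASYMM8_SIGNED log softmax
        assert((127 - 127) * log_scale == 0.f);          // highest code -> 0.0
        assert(((-128) - 127) * log_scale == -15.9375f); // lowest code  -> -15.9375
        return 0;
    }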
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
index 32d7f4423d..e01d2c75ca 100644
--- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp
+++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
@@ -118,8 +118,9 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const ICLTensor *input, ICLTensor
_memory_group.manage(&_sum);
SoftmaxKernelInfo softmax_info;
- softmax_info.beta = beta;
- softmax_info.is_log = IS_LOG;
+ softmax_info.beta = beta;
+ softmax_info.is_log = IS_LOG;
+ softmax_info.input_data_type = input_2D->info()->data_type();
// Configure kernels
_max_shift_exp_sum_kernel.configure(input_2D, &_max, &_tmp, &_sum, softmax_info);
@@ -184,8 +185,13 @@ Status CLSoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I
}
}
+ SoftmaxKernelInfo softmax_info;
+ softmax_info.beta = beta;
+ softmax_info.is_log = IS_LOG;
+ softmax_info.input_data_type = input->data_type();
+
ARM_COMPUTE_RETURN_ON_ERROR(CLLogits1DMaxShiftExpSumKernel::validate(input, &tensor_info_max, &tensor_info_tmp, &tensor_info_sum));
- ARM_COMPUTE_RETURN_ON_ERROR(CLLogits1DNormKernel::validate(&tensor_info_tmp, &tensor_info_sum, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLLogits1DNormKernel::validate(&tensor_info_tmp, &tensor_info_sum, output, softmax_info));
if(needs_flattening)
{
diff --git a/tests/validation/CL/SoftmaxLayer.cpp b/tests/validation/CL/SoftmaxLayer.cpp
index ae7adec9f2..5ee929f6b9 100644
--- a/tests/validation/CL/SoftmaxLayer.cpp
+++ b/tests/validation/CL/SoftmaxLayer.cpp
@@ -49,6 +49,15 @@ RelativeTolerance<float> tolerance_f32(0.001f);
/** Tolerance for quantized operations */
constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);
+constexpr AbsoluteTolerance<int8_t> tolerance_qasymm8_signed(1);
+
+/*
+  The following tolerance number is used as a workaround for the mismatches
+  caused by floating-point computation in the reference (and NEON) kernels
+  and integer computation in the OpenCL kernel.
+  COMPMID-2958 was created to investigate this.
+*/
+constexpr float tolerance_number_qasymm8_signed = 0.05f;
/** CNN data types */
const auto CNNDataTypes = framework::dataset::make("DataType",
@@ -110,6 +119,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U), 1, DataType::QASYMM8,
QuantizationInfo(1.f/256, 12)),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::QASYMM8_SIGNED,
+ QuantizationInfo(1.f/256, 12))
}),
framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U), 1, DataType::F16),
TensorInfo(TensorShape(27U, 11U), 1, DataType::F32),
@@ -120,8 +131,10 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(
TensorInfo(TensorShape(32U, 13U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U), 1, DataType::QASYMM8,
QuantizationInfo(1.f/256, 0)),
+ TensorInfo(TensorShape(32U, 13U), 1, DataType::QASYMM8_SIGNED,
+ QuantizationInfo(1.f/256, -128)),
})),
- framework::dataset::make("Expected", { false, false, false, false, false, true, true })),
+ framework::dataset::make("Expected", { false, false, false, false, false, true, true, true })),
input_info, output_info, expected)
{
ARM_COMPUTE_EXPECT(bool(CLSoftmaxLayer::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
@@ -221,11 +234,24 @@ FIXTURE_DATA_TEST_CASE(Run4D, CLSoftmaxLayerQuantizedFixture<uint8_t>, framework
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE_END() // QASYMM8
-TEST_SUITE_END()
-TEST_SUITE_END()
+TEST_SUITE(QASYMM8_SIGNED)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SoftmaxLayerSmallShapes(),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ combine(framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, -10) }),
+ framework::dataset::make("Beta", { 1.0f, 2.f }))),
+ framework::dataset::make("Axis", { 1, 2 })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8_signed, tolerance_number_qasymm8_signed);
+}
+
+TEST_SUITE_END() // QASYMM8_SIGNED
+TEST_SUITE_END() // Quantized
+TEST_SUITE_END() // SoftmaxLayer
+TEST_SUITE_END() // CL
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index 100f4f05c1..942b2396bf 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -177,7 +177,7 @@ void fill_lookuptable(T &&table)
}
}
-/** Convert a quantized simple tensor into float using tensor quantization information.
+/** Convert an asymmetric quantized simple tensor into float using tensor quantization information.
*
* @param[in] src Quantized tensor.
*
diff --git a/tests/validation/fixtures/SoftmaxLayerFixture.h b/tests/validation/fixtures/SoftmaxLayerFixture.h
index f747ab3574..82daf34f13 100644
--- a/tests/validation/fixtures/SoftmaxLayerFixture.h
+++ b/tests/validation/fixtures/SoftmaxLayerFixture.h
@@ -65,7 +65,7 @@ protected:
std::uniform_real_distribution<> distribution(-1000.f, 1000.f);
library->fill(tensor, distribution, 0);
}
- else // data type is quantized_asymmetric
+ else // data type is quantized_asymmetric (signed or unsigned)
{
std::uniform_int_distribution<> distribution(0, 100);
library->fill(tensor, distribution, 0);
@@ -77,7 +77,7 @@ protected:
{
// Create tensors
TensorType src = create_tensor<TensorType>(shape, data_type, 1, quantization_info);
- TensorType dst = create_tensor<TensorType>(shape, data_type, 1, QuantizationInfo(1.f / 256, 0));
+ TensorType dst = create_tensor<TensorType>(shape, data_type, 1, get_softmax_output_quantization_info(data_type, IS_LOG));
// Create and configure function
FunctionType smx_layer;
diff --git a/tests/validation/reference/LogSoftmaxLayer.cpp b/tests/validation/reference/LogSoftmaxLayer.cpp
index 3f21d85dd0..e4403956ab 100644
--- a/tests/validation/reference/LogSoftmaxLayer.cpp
+++ b/tests/validation/reference/LogSoftmaxLayer.cpp
@@ -40,21 +40,21 @@ SimpleTensor<T> log_softmax_layer(const SimpleTensor<T> &src, float beta, size_t
return softmax_layer_generic<T>(src, beta, axis, true);
}
-template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value, int>::type>
+template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int>::type>
SimpleTensor<T> log_softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis)
{
- // Note: Output quantization info should always have scale = 1/256 and offset = 0
- const QuantizationInfo output_quantization_info = QuantizationInfo(1.f / 256, 0);
+ const QuantizationInfo output_quantization_info = arm_compute::get_softmax_output_quantization_info(src.data_type(), true);
SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
SimpleTensor<float> dst_tmp = log_softmax_layer<float>(src_tmp, beta, axis);
- SimpleTensor<T> dst = convert_to_asymmetric<uint8_t>(dst_tmp, output_quantization_info);
+ SimpleTensor<T> dst = convert_to_asymmetric<T>(dst_tmp, output_quantization_info);
return dst;
}
template SimpleTensor<float> log_softmax_layer(const SimpleTensor<float> &src, float beta, size_t axis);
template SimpleTensor<half> log_softmax_layer(const SimpleTensor<half> &src, float beta, size_t axis);
template SimpleTensor<uint8_t> log_softmax_layer(const SimpleTensor<uint8_t> &src, float beta, size_t axis);
+template SimpleTensor<int8_t> log_softmax_layer(const SimpleTensor<int8_t> &src, float beta, size_t axis);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/LogSoftmaxLayer.h b/tests/validation/reference/LogSoftmaxLayer.h
index 065315ff2c..c2e3f5974e 100644
--- a/tests/validation/reference/LogSoftmaxLayer.h
+++ b/tests/validation/reference/LogSoftmaxLayer.h
@@ -38,7 +38,7 @@ namespace reference
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
SimpleTensor<T> log_softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis = 1);
-template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value, int>::type = 0>
+template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int>::type = 0>
SimpleTensor<T> log_softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis = 1);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/reference/SoftmaxLayer.cpp b/tests/validation/reference/SoftmaxLayer.cpp
index ef2468df59..0e470260a9 100644
--- a/tests/validation/reference/SoftmaxLayer.cpp
+++ b/tests/validation/reference/SoftmaxLayer.cpp
@@ -107,21 +107,21 @@ SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axi
return softmax_layer_generic<T>(src, beta, axis, false);
}
-template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value, int>::type>
+template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int>::type>
SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis)
{
- // Note: Output quantization info should always have scale = 1/256 and offset = 0
- const QuantizationInfo output_quantization_info = QuantizationInfo(1.f / 256, 0);
+ const QuantizationInfo output_quantization_info = arm_compute::get_softmax_output_quantization_info(src.data_type(), false);
SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
SimpleTensor<float> dst_tmp = softmax_layer<float>(src_tmp, beta, axis);
- SimpleTensor<T> dst = convert_to_asymmetric<uint8_t>(dst_tmp, output_quantization_info);
+ SimpleTensor<T> dst = convert_to_asymmetric<T>(dst_tmp, output_quantization_info);
return dst;
}
template SimpleTensor<float> softmax_layer(const SimpleTensor<float> &src, float beta, size_t axis);
template SimpleTensor<half> softmax_layer(const SimpleTensor<half> &src, float beta, size_t axis);
template SimpleTensor<uint8_t> softmax_layer(const SimpleTensor<uint8_t> &src, float beta, size_t axis);
+template SimpleTensor<int8_t> softmax_layer(const SimpleTensor<int8_t> &src, float beta, size_t axis);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/SoftmaxLayer.h b/tests/validation/reference/SoftmaxLayer.h
index 2708c807f2..2be575c2af 100644
--- a/tests/validation/reference/SoftmaxLayer.h
+++ b/tests/validation/reference/SoftmaxLayer.h
@@ -41,7 +41,7 @@ SimpleTensor<T> softmax_layer_generic(const SimpleTensor<T> &src, float beta, si
template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis = 1);
-template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value, int>::type = 0>
+template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int>::type = 0>
SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis = 1);
} // namespace reference
} // namespace validation