aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/softmax_layer_quantized.cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/cl_kernels/softmax_layer_quantized.cl')
-rw-r--r--src/core/CL/cl_kernels/softmax_layer_quantized.cl13
1 files changed, 7 insertions, 6 deletions
diff --git a/src/core/CL/cl_kernels/softmax_layer_quantized.cl b/src/core/CL/cl_kernels/softmax_layer_quantized.cl
index 8ccc5d3dd5..ce3bd7bc43 100644
--- a/src/core/CL/cl_kernels/softmax_layer_quantized.cl
+++ b/src/core/CL/cl_kernels/softmax_layer_quantized.cl
@@ -560,8 +560,8 @@ __kernel void softmax_layer_norm_quantized(
int sum_val = *((__global int *)offset(&sum, 0, get_global_id(1)));
// It will be better to calculate this in prev layer and pass here as parameter
- uint sum_val_u = convert_uint(sum_val);
#ifndef LOG_SOFTMAX
+ uint sum_val_u = convert_uint(sum_val);
int headroom_plus_one = clz(sum_val_u);
int num_bits_over_unit = EXP_ACCUMULATION_INT_BITS - headroom_plus_one;
int shifted_sum_minus_one_1 = convert_int((sum_val_u << headroom_plus_one) - (1u << 31));
@@ -578,15 +578,16 @@ __kernel void softmax_layer_norm_quantized(
data_diff_mult = ASYMM_MULT(data_diff * (1 << INPUT_BETA_LEFT_SHIFT), INPUT_BETA_MULTIPLIER, 16);
}
#endif /* defined(INPUT_BETA_MULTIPLIER) && defined(INPUT_BETA_LEFT_SHIFT) */
- int16 data = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 16);
#ifdef LOG_SOFTMAX
- data = SUB_OP(data_diff_mult, (int16)sum_val_u, int, 16);
+ long16 data = SUB_OP(convert_long16(data_diff_mult), (long16)(sum_val), long, 16);
+ data = select(0L, data, convert_long16(data_diff) >= (long16)(DIFF_MIN));
#else /* LOG_SOFTMAX */
- data = ASYMM_MULT(shifted_scale, data, 16);
- data = ASYMM_ROUNDING_DIVIDE_BY_POW2(data, num_bits_over_unit + 31 - 8, 16);
+ int16 data = ASYMM_EXP_ON_NEGATIVE_VALUES(data_diff_mult, SCALED_DIFF_INT_BITS, 16);
+ data = ASYMM_MULT(shifted_scale, data, 16);
+ data = ASYMM_ROUNDING_DIVIDE_BY_POW2(data, num_bits_over_unit + 31 - 8, 16);
+ data = select(0, data, data_diff >= (int16)(DIFF_MIN));
#endif /* LOG_SOFTMAX */
- data = select(0, data, data_diff >= (int16)(DIFF_MIN));
vstore16(convert_uchar16_sat(data), 0, (__global uchar *)offset(&dst, 0, 0));
}