From 21079dd320c00068208acdfd59177895265a53f2 Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Tue, 29 Oct 2019 17:20:09 +0000 Subject: COMPMID-2700: Use NEON wrapper on SoftmaxLayer Change-Id: Id8901e865c9f355dcf7b2a1a539493099591377e Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/2186 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Giorgio Arena Tested-by: Arm Jenkins --- arm_compute/core/NEON/NEColorConvertHelper.inl | 110 ++--- arm_compute/core/NEON/NEMath.h | 29 +- arm_compute/core/NEON/NEMath.inl | 33 ++ src/core/NEON/kernels/NESoftmaxLayerKernel.cpp | 562 ++++++------------------- tests/validation/NEON/SoftmaxLayer.cpp | 4 +- 5 files changed, 230 insertions(+), 508 deletions(-) diff --git a/arm_compute/core/NEON/NEColorConvertHelper.inl b/arm_compute/core/NEON/NEColorConvertHelper.inl index 68f437116c..62c6eb5aea 100644 --- a/arm_compute/core/NEON/NEColorConvertHelper.inl +++ b/arm_compute/core/NEON/NEColorConvertHelper.inl @@ -24,6 +24,7 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/IMultiImage.h" +#include "arm_compute/core/NEON/NEMath.h" #include "arm_compute/core/Utils.h" #include @@ -49,37 +50,6 @@ constexpr float rgb2u8_red_coef = 0.2126f; constexpr float rgb2u8_green_coef = 0.7152f; constexpr float rgb2u8_blue_coef = 0.0722f; -inline float32x4x4_t convert_uint8x16_to_float32x4x4(const uint8x16_t &in) -{ - float32x4x4_t out; - const auto tmp1 = vmovl_u8(vget_low_u8(in)); - out.val[0] = vcvtq_f32_u32(vmovl_u16(vget_low_u16(tmp1))); - out.val[1] = vcvtq_f32_u32(vmovl_u16(vget_high_u16(tmp1))); - const auto tmp2 = vmovl_u8(vget_high_u8(in)); - out.val[2] = vcvtq_f32_u32(vmovl_u16(vget_low_u16(tmp2))); - out.val[3] = vcvtq_f32_u32(vmovl_u16(vget_high_u16(tmp2))); - return out; -} - -inline void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const float32x4x3_t &in2, uint8x8x3_t &out) -{ - out.val[0] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[0])), - vqmovn_u32(vcvtq_u32_f32(in2.val[0])))); - out.val[1] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[1])), - vqmovn_u32(vcvtq_u32_f32(in2.val[1])))); - out.val[2] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[2])), - vqmovn_u32(vcvtq_u32_f32(in2.val[2])))); -} - -inline void convert_float32x4x4_to_unit8x16(const float32x4x4_t &in, uint8x16_t &out) -{ - const auto low = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[0])), - vqmovn_u32(vcvtq_u32_f32(in.val[1]))); - const auto high = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[2])), - vqmovn_u32(vcvtq_u32_f32(in.val[3]))); - out = vcombine_u8(vqmovn_u16(low), vqmovn_u16(high)); -} - inline float32x4_t rgb_to_greyscale_calculation(const float32x4_t &rcolor, const float32x4_t &gcolor, const float32x4_t &bcolor, const float rcoef, const float gcoef, const float bcoef) { @@ -94,9 +64,9 @@ inline void rgb_to_u8_conversion(const uint8x16x3_t &in, uint8x16_t &out) float32x4x4_t out_float32; //Conversion from 3(RGB) 4 uint8s to 3(RGB) 4 floats - const float32x4x4_t r_float32 = convert_uint8x16_to_float32x4x4(in.val[0]); - const float32x4x4_t g_float32 = convert_uint8x16_to_float32x4x4(in.val[1]); - const float32x4x4_t b_float32 = convert_uint8x16_to_float32x4x4(in.val[2]); + const float32x4x4_t r_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[0]); + const float32x4x4_t g_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[1]); + const float32x4x4_t b_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[2]); //New grayscale image = ( (RED_COEFF * R) + (GREEN_COEFF * G) + (BLUE_COEFF * B) ) //Computation of 1(Greyscale) 4 uint8 using 3(RGB) 4 uint8s float @@ -113,7 +83,7 @@ inline void rgb_to_u8_conversion(const uint8x16x3_t &in, uint8x16_t &out) rgb2u8_red_coef, rgb2u8_green_coef, rgb2u8_blue_coef); //Conversion from 1(Greyscale) 4 floats to 1(Greyscale) 4 uint8s - convert_float32x4x4_to_unit8x16(out_float32, out); + arm_compute::convert_float32x4x4_to_unit8x16(out_float32, out); } inline void rgb_to_yuv_calculation(const float32x4_t &rvec, const float32x4_t &gvec, const float32x4_t &bvec, @@ -172,7 +142,7 @@ inline void yuyv_to_rgb_calculation(const float32x4_t &yvec_val, float32x4_t uve rgb2.val[2] = vaddq_f32(yyvec_val, blue); uint8x8x3_t u8_rgb; - convert_float32x4x3_to_uint8x8x3(rgb1, rgb2, u8_rgb); + arm_compute::convert_float32x4x3_to_uint8x8x3(rgb1, rgb2, u8_rgb); if(!alpha) { @@ -225,13 +195,13 @@ inline uint8x16x3_t load_rgb(const unsigned char *const ptr, const bool alpha) inline void rgb_to_yuv_conversion(uint8x16x3_t &vec_top, uint8x16x3_t &vec_bottom) { // Convert the uint8x16_t to float32x4x4_t - const float32x4x4_t frvec_top = convert_uint8x16_to_float32x4x4(vec_top.val[0]); - const float32x4x4_t fgvec_top = convert_uint8x16_to_float32x4x4(vec_top.val[1]); - const float32x4x4_t fbvec_top = convert_uint8x16_to_float32x4x4(vec_top.val[2]); + const float32x4x4_t frvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[0]); + const float32x4x4_t fgvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[1]); + const float32x4x4_t fbvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[2]); - const float32x4x4_t frvec_bottom = convert_uint8x16_to_float32x4x4(vec_bottom.val[0]); - const float32x4x4_t fgvec_bottom = convert_uint8x16_to_float32x4x4(vec_bottom.val[1]); - const float32x4x4_t fbvec_bottom = convert_uint8x16_to_float32x4x4(vec_bottom.val[2]); + const float32x4x4_t frvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[0]); + const float32x4x4_t fgvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[1]); + const float32x4x4_t fbvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[2]); float32x4x4_t fyvec_top, fuvec_top, fvvec_top; float32x4x4_t fyvec_bottom, fuvec_bottom, fvvec_bottom; @@ -244,12 +214,12 @@ inline void rgb_to_yuv_conversion(uint8x16x3_t &vec_top, uint8x16x3_t &vec_botto fyvec_bottom.val[i], fuvec_bottom.val[i], fvvec_bottom.val[i]); } - convert_float32x4x4_to_unit8x16(fyvec_top, vec_top.val[0]); - convert_float32x4x4_to_unit8x16(fuvec_top, vec_top.val[1]); - convert_float32x4x4_to_unit8x16(fvvec_top, vec_top.val[2]); - convert_float32x4x4_to_unit8x16(fyvec_bottom, vec_bottom.val[0]); - convert_float32x4x4_to_unit8x16(fuvec_bottom, vec_bottom.val[1]); - convert_float32x4x4_to_unit8x16(fvvec_bottom, vec_bottom.val[2]); + arm_compute::convert_float32x4x4_to_unit8x16(fyvec_top, vec_top.val[0]); + arm_compute::convert_float32x4x4_to_unit8x16(fuvec_top, vec_top.val[1]); + arm_compute::convert_float32x4x4_to_unit8x16(fvvec_top, vec_top.val[2]); + arm_compute::convert_float32x4x4_to_unit8x16(fyvec_bottom, vec_bottom.val[0]); + arm_compute::convert_float32x4x4_to_unit8x16(fuvec_bottom, vec_bottom.val[1]); + arm_compute::convert_float32x4x4_to_unit8x16(fvvec_bottom, vec_bottom.val[2]); } inline void store_rgb_to_nv12(const uint8x16_t &rvec_top, const uint8x16_t &gvec_top, const uint8x16_t &bvec_top, @@ -316,9 +286,9 @@ inline void store_rgb_to_yuv4(const uint8x16_t &rvec, const uint8x16_t &gvec, co unsigned char *const __restrict out_v) { // Convert the uint8x16_t to float32x4x4_t - const float32x4x4_t frvec = convert_uint8x16_to_float32x4x4(rvec); - const float32x4x4_t fgvec = convert_uint8x16_to_float32x4x4(gvec); - const float32x4x4_t fbvec = convert_uint8x16_to_float32x4x4(bvec); + const float32x4x4_t frvec = arm_compute::convert_uint8x16_to_float32x4x4(rvec); + const float32x4x4_t fgvec = arm_compute::convert_uint8x16_to_float32x4x4(gvec); + const float32x4x4_t fbvec = arm_compute::convert_uint8x16_to_float32x4x4(bvec); float32x4x4_t fyvec, fuvec, fvvec; for(auto i = 0; i < 4; ++i) @@ -328,9 +298,9 @@ inline void store_rgb_to_yuv4(const uint8x16_t &rvec, const uint8x16_t &gvec, co } uint8x16_t yvec, uvec, vvec; - convert_float32x4x4_to_unit8x16(fyvec, yvec); - convert_float32x4x4_to_unit8x16(fuvec, uvec); - convert_float32x4x4_to_unit8x16(fvvec, vvec); + arm_compute::convert_float32x4x4_to_unit8x16(fyvec, yvec); + arm_compute::convert_float32x4x4_to_unit8x16(fuvec, uvec); + arm_compute::convert_float32x4x4_to_unit8x16(fvvec, vvec); vst1q_u8(out_y, yvec); vst1q_u8(out_u, uvec); @@ -461,10 +431,10 @@ void colorconvert_yuyv_to_rgb(const void *__restrict input, void *__restrict out //ta.val[3] = V0 V2 V4 V7 ... // Convert the uint8x16x4_t to float32x4x4_t - const float32x4x4_t yvec = convert_uint8x16_to_float32x4x4(ta.val[0 + shift]); - const float32x4x4_t uvec = convert_uint8x16_to_float32x4x4(ta.val[1 - shift]); - const float32x4x4_t yyvec = convert_uint8x16_to_float32x4x4(ta.val[2 + shift]); - const float32x4x4_t vvec = convert_uint8x16_to_float32x4x4(ta.val[3 - shift]); + const float32x4x4_t yvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[0 + shift]); + const float32x4x4_t uvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[1 - shift]); + const float32x4x4_t yyvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[2 + shift]); + const float32x4x4_t vvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[3 - shift]); yuyv_to_rgb_calculation(yvec.val[0], uvec.val[0], yyvec.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha); yuyv_to_rgb_calculation(yvec.val[1], uvec.val[1], yyvec.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha); @@ -516,12 +486,12 @@ void colorconvert_nv12_to_rgb(const void *__restrict input, void *__restrict out //ta_uv.val[1] = V0 V2 V4 V6 ... // Convert the uint8x16x4_t to float32x4x4_t - float32x4x4_t yvec_top = convert_uint8x16_to_float32x4x4(ta_y_top.val[0]); - float32x4x4_t yyvec_top = convert_uint8x16_to_float32x4x4(ta_y_top.val[1]); - float32x4x4_t yvec_bottom = convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]); - float32x4x4_t yyvec_bottom = convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]); - float32x4x4_t uvec = convert_uint8x16_to_float32x4x4(ta_uv.val[0 + shift]); - float32x4x4_t vvec = convert_uint8x16_to_float32x4x4(ta_uv.val[1 - shift]); + float32x4x4_t yvec_top = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[0]); + float32x4x4_t yyvec_top = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[1]); + float32x4x4_t yvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]); + float32x4x4_t yyvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]); + float32x4x4_t uvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_uv.val[0 + shift]); + float32x4x4_t vvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_uv.val[1 - shift]); yuyv_to_rgb_calculation(yvec_top.val[0], uvec.val[0], yyvec_top.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha); yuyv_to_rgb_calculation(yvec_top.val[1], uvec.val[1], yyvec_top.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha); @@ -579,12 +549,12 @@ void colorconvert_iyuv_to_rgb(const void *__restrict input, void *__restrict out //ta_v.val[0] = V0 V2 V4 V6 ... // Convert the uint8x16x4_t to float32x4x4_t - float32x4x4_t yvec_top = convert_uint8x16_to_float32x4x4(ta_y_top.val[0]); - float32x4x4_t yyvec_top = convert_uint8x16_to_float32x4x4(ta_y_top.val[1]); - float32x4x4_t yvec_bottom = convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]); - float32x4x4_t yyvec_bottom = convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]); - float32x4x4_t uvec = convert_uint8x16_to_float32x4x4(ta_u); - float32x4x4_t vvec = convert_uint8x16_to_float32x4x4(ta_v); + float32x4x4_t yvec_top = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[0]); + float32x4x4_t yyvec_top = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[1]); + float32x4x4_t yvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]); + float32x4x4_t yyvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]); + float32x4x4_t uvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_u); + float32x4x4_t vvec = arm_compute::convert_uint8x16_to_float32x4x4(ta_v); yuyv_to_rgb_calculation(yvec_top.val[0], uvec.val[0], yyvec_top.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha); yuyv_to_rgb_calculation(yvec_top.val[1], uvec.val[1], yyvec_top.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha); diff --git a/arm_compute/core/NEON/NEMath.h b/arm_compute/core/NEON/NEMath.h index 8593059b1a..aa3054306c 100644 --- a/arm_compute/core/NEON/NEMath.h +++ b/arm_compute/core/NEON/NEMath.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef __ARM_COMPUTE_NEMATH_H__ -#define __ARM_COMPUTE_NEMATH_H__ +#ifndef ARM_COMPUTE_NEMATH_H +#define ARM_COMPUTE_NEMATH_H #include @@ -157,6 +157,29 @@ int32x4_t rounding_divide_by_pow2(int32x4_t x, int exponent); */ int32_t rounding_divide_by_pow2(int32_t x, int exponent); +/** Converts from uint8x16 to float32x4x4_t + * + * @param[in] in Vector of uint8 to be converted + * + * @return Converted vector of float + */ +float32x4x4_t convert_uint8x16_to_float32x4x4(const uint8x16_t &in); + +/** Converts from two float32x4x3_t to just one uint8x8x3_t + * + * @param[in] in1 First input vector of float to be converted + * @param[in] in2 Second input vector of float to be converted + * @param[out] out Converted output vector uint8 to store the result + */ +void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const float32x4x3_t &in2, uint8x8x3_t &out); + +/** Converts from two float32x4x4_t to just one uint8x16_t + * + * @param[in] in Vector of float to be converted + * @param[out] out Converted vector of uint8 to store the result + */ +void convert_float32x4x4_to_unit8x16(const float32x4x4_t &in, uint8x16_t &out); + /** Calculate sine. * * @param[in] val Input vector value in radians, F32 format. @@ -256,4 +279,4 @@ float16x8_t vsinq_f16(float16x8_t val); #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ } // namespace arm_compute #include "arm_compute/core/NEON/NEMath.inl" -#endif /* __ARM_COMPUTE_NEMATH_H__ */ +#endif /* ARM_COMPUTE_NEMATH_H */ diff --git a/arm_compute/core/NEON/NEMath.inl b/arm_compute/core/NEON/NEMath.inl index f1c9c2024b..a3601f6a25 100644 --- a/arm_compute/core/NEON/NEMath.inl +++ b/arm_compute/core/NEON/NEMath.inl @@ -317,6 +317,39 @@ inline int32_t rounding_divide_by_pow2(int32_t x, int exponent) return (x >> exponent) + ((x & mask) > threshold ? 1 : 0); } +inline float32x4x4_t convert_uint8x16_to_float32x4x4(const uint8x16_t &in) +{ + float32x4x4_t out; + + const auto tmp1 = vmovl_u8(vget_low_u8(in)); + out.val[0] = vcvtq_f32_u32(vmovl_u16(vget_low_u16(tmp1))); + out.val[1] = vcvtq_f32_u32(vmovl_u16(vget_high_u16(tmp1))); + + const auto tmp2 = vmovl_u8(vget_high_u8(in)); + out.val[2] = vcvtq_f32_u32(vmovl_u16(vget_low_u16(tmp2))); + out.val[3] = vcvtq_f32_u32(vmovl_u16(vget_high_u16(tmp2))); + return out; +} + +inline void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const float32x4x3_t &in2, uint8x8x3_t &out) +{ + out.val[0] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[0])), + vqmovn_u32(vcvtq_u32_f32(in2.val[0])))); + out.val[1] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[1])), + vqmovn_u32(vcvtq_u32_f32(in2.val[1])))); + out.val[2] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[2])), + vqmovn_u32(vcvtq_u32_f32(in2.val[2])))); +} + +inline void convert_float32x4x4_to_unit8x16(const float32x4x4_t &in, uint8x16_t &out) +{ + const auto low = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[0])), + vqmovn_u32(vcvtq_u32_f32(in.val[1]))); + const auto high = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[2])), + vqmovn_u32(vcvtq_u32_f32(in.val[3]))); + out = vcombine_u8(vqmovn_u16(low), vqmovn_u16(high)); +} + #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC /** Exponent polynomial coefficients */ /** Logarithm polynomial coefficients */ diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp index 1003ebd2e3..a3ecce3a1e 100644 --- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp +++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp @@ -30,6 +30,7 @@ #include "arm_compute/core/ITensor.h" #include "arm_compute/core/NEON/NEFixedPoint.h" #include "arm_compute/core/NEON/NEMath.h" +#include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" @@ -43,309 +44,6 @@ namespace arm_compute { -template -struct vec_n_type; - -#define DECLARE_NEON_VEC_TYPE(T, N, V) \ - template <> \ - struct vec_n_type \ - { \ - using type = V; \ - }; - -DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t) -DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t) - -DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t) -DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t) - -DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t) -DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t) - -DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t) -DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t) - -DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t) -DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t) - -DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t) -DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t) - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t) -DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t) -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - -DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t) -DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t) - -template -using vec_n_t = typename vec_n_type::type; - -template -using vec_n_byte_t = vec_n_t < T, N / sizeof(T) >; - -template -using vec_16_byte_t = vec_n_byte_t; - -template -using vec_8_byte_t = vec_n_byte_t; - -template -using const_ptr_t = const T *; - -template -using ptr_t = T *; - -#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \ - template \ - TYPE vget_lane(vec_8_byte_t vec); \ - template \ - TYPE vget_lane(vec_16_byte_t vec); - -FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t) -FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t) -FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t) -FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t) -FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t) -FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t) -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t) -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float) -template -float vget_lane(float32x4x4_t vec); - -template -using elem_type_t = decltype(vget_lane<0>(std::declval())); - -template -constexpr size_t vec_size_of(const V &vec) -{ - return sizeof(vec) / sizeof(elem_type_t); -} - -template -V vdup_n(elem_type_t val); -template -V vld(const_ptr_t> ptr); - -#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG) \ - template <> \ - inline vec_8_byte_t vdup_n>(TYPE val) \ - { \ - return vdup_n_##TAG(val); \ - } \ - template <> \ - inline vec_16_byte_t vdup_n>(TYPE val) \ - { \ - return vdupq_n_##TAG(val); \ - } \ - template <> \ - inline vec_8_byte_t vld>(const_ptr_t ptr) \ - { \ - return vld1_##TAG(ptr); \ - } \ - template <> \ - inline vec_16_byte_t vld>(const_ptr_t ptr) \ - { \ - return vld1q_##TAG(ptr); \ - } \ - inline void vst(ptr_t ptr, vec_8_byte_t vec) \ - { \ - vst1_##TAG(ptr, vec); \ - } \ - inline void vst(ptr_t ptr, vec_16_byte_t vec) \ - { \ - vst1q_##TAG(ptr, vec); \ - } \ - inline vec_16_byte_t vmax(vec_16_byte_t a, vec_16_byte_t b) \ - { \ - return vmaxq_##TAG(a, b); \ - } \ - inline vec_8_byte_t vpmax(vec_8_byte_t a, vec_8_byte_t b) \ - { \ - return vpmax_##TAG(a, b); \ - } \ - inline vec_8_byte_t vget_low(vec_16_byte_t vec) \ - { \ - return vget_low_##TAG(vec); \ - } \ - inline vec_8_byte_t vget_high(vec_16_byte_t vec) \ - { \ - return vget_high_##TAG(vec); \ - } \ - template \ - inline TYPE vget_lane(vec_8_byte_t vec) \ - { \ - static_assert(lane >= 0, "lane is out of bounds"); \ - static_assert(lane < vec_size_of(vec), "lane is out of bounds"); \ - return vget_lane_##TAG(vec, lane); \ - } \ - template \ - inline TYPE vget_lane(vec_16_byte_t vec) \ - { \ - static_assert(lane >= 0, "lane is out of bounds"); \ - static_assert(lane < vec_size_of(vec), "lane is out of bounds"); \ - return vgetq_lane_##TAG(vec, lane); \ - } - -template -T sqadd(T a, T b); -template -T sqsub(T a, T b); -template -T sqmul(T a, T b); - -#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG) \ - inline vec_8_byte_t vadd(vec_8_byte_t a, vec_8_byte_t b) \ - { \ - return vadd_##TAG(a, b); \ - } \ - inline vec_16_byte_t vadd(vec_16_byte_t a, vec_16_byte_t b) \ - { \ - return vaddq_##TAG(a, b); \ - } \ - inline vec_16_byte_t vsub(vec_16_byte_t a, vec_16_byte_t b) \ - { \ - return vsubq_##TAG(a, b); \ - } \ - inline vec_16_byte_t vmul_n(vec_16_byte_t vec, TYPE val) \ - { \ - return vmulq_n_##TAG(vec, val); \ - } - -DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8) -DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8) -DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16) -DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16) -DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32) -DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32) -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16) -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32) - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16) -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32) - -template -VO vcvt(VI vec); - -template <> -float32x4x4_t vcvt(uint8x16_t vec) -{ - const auto low = vmovl_u8(vget_low(vec)); - const auto high = vmovl_u8(vget_high(vec)); - float32x4x4_t res = { { - vcvtq_f32_u32(vmovl_u16(vget_low(low))), - vcvtq_f32_u32(vmovl_u16(vget_high(low))), - vcvtq_f32_u32(vmovl_u16(vget_low(high))), - vcvtq_f32_u32(vmovl_u16(vget_high(high))) - } - }; - return res; -} - -template <> -uint8x16_t vcvt(float32x4x4_t vec) -{ - uint16x8x2_t resU16 = { { - vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])), - vqmovn_u32(vcvtq_u32_f32(vec.val[1]))), - vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])), - vqmovn_u32(vcvtq_u32_f32(vec.val[3]))) - } - }; - - uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1])); - return res; -} - -float32x4x4_t vexp(float32x4x4_t vec) -{ - float32x4x4_t res = { { - vexpq_f32(vec.val[0]), - vexpq_f32(vec.val[1]), - vexpq_f32(vec.val[2]), - vexpq_f32(vec.val[3]) - } - }; - return res; -} - -float32x4_t vexp(const float32x4_t &vec) -{ - return vexpq_f32(vec); -} - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -// TODO (COMPMID-1535) : Revisit FP16 approximations -float16x8_t vexp(const float16x8_t &vec) -{ - float16x4x2_t res = - { - { - vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_low_f16(vec)))), - vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_high_f16(vec)))) - } - }; - return vcombine_f16(res.val[0], res.val[1]); -} -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - -template <> -float32x4x4_t vdup_n(float val) -{ - float32x4x4_t res = { { - vdupq_n_f32(val), - vdupq_n_f32(val), - vdupq_n_f32(val), - vdupq_n_f32(val) - } - }; - return res; -} - -float32x4x4_t vmul_n(float32x4x4_t vec, float val) -{ - float32x4x4_t res = { { - vmulq_n_f32(vec.val[0], val), - vmulq_n_f32(vec.val[1], val), - vmulq_n_f32(vec.val[2], val), - vmulq_n_f32(vec.val[3], val) - } - }; - return res; -} - -float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b) -{ - float32x4x4_t res = { { - vaddq_f32(a.val[0], b.val[0]), - vaddq_f32(a.val[1], b.val[1]), - vaddq_f32(a.val[2], b.val[2]), - vaddq_f32(a.val[3], b.val[3]) - } - }; - return res; -} - -float32x4x4_t vsub_n(float32x4x4_t a, float val) -{ - auto scalar_vector = vdup_n(val); - float32x4x4_t res = { { - vsubq_f32(a.val[0], scalar_vector.val[0]), - vsubq_f32(a.val[1], scalar_vector.val[1]), - vsubq_f32(a.val[2], scalar_vector.val[2]), - vsubq_f32(a.val[3], scalar_vector.val[3]) - } - }; - return res; -} - namespace { Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output) @@ -390,30 +88,20 @@ std::pair validate_and_configure_window_logits_1d_max(ITensorInf return std::make_pair(err, win); } -template -auto reduce_max(V vec) -> elem_type_t -{ - constexpr int N = vec_size_of(vec); - - auto carry_max = vpmax(vget_high(vec), vget_low(vec)); - - for(int k = N / 2; k > 1; k /= 2) - { - carry_max = vpmax(carry_max, carry_max); - } - - return vget_lane<0>(carry_max); -} - template void logits_1d_max(const ITensor &in, ITensor &out, const Window &window) { const auto start_x = in.info()->valid_region().anchor.x(); const size_t input_width = in.info()->valid_region().shape.x(); + /** NEON vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; + Iterator input(&in, window); Iterator output(&out, window); + constexpr int window_step_x = 16 / sizeof(T); + const int sum_stages = log2(window_step_x / 2); execute_window_loop(window, [&](const Coordinates &) { // Get pointers @@ -421,16 +109,22 @@ void logits_1d_max(const ITensor &in, ITensor &out, const Window &window) const auto out_ptr = reinterpret_cast(output.ptr()); // Init max value - auto vec_max = vdup_n>(support::cpp11::lowest()); + auto vec_max = wrapper::vdup_n(support::cpp11::lowest(), ExactTagType{}); // Loop over input row - for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max)) + for(const T *it = in_ptr; it < (in_ptr + input_width); it += window_step_x) { - const auto current_value = vld>(it); - vec_max = vmax(vec_max, current_value); + const auto current_value = wrapper::vloadq(it); + vec_max = wrapper::vmax(vec_max, current_value); } - const T max_val = reduce_max(vec_max); + auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max)); + + for(int i = 0; i < sum_stages; ++i) + { + carry_max = wrapper::vpmax(carry_max, carry_max); + } + const T max_val = wrapper::vgetlane(carry_max, 0); *out_ptr = max_val; }, input, output); @@ -575,45 +269,19 @@ std::pair validate_and_configure_window_logits_softmax(ITensorIn return std::make_pair(err, win); } -template -struct reduce_add_impl -{ - template - static T reduce(F add_fn, vec_n_t vec) - { - constexpr int H = (S + E + 1) / 2; - const auto reduced_high = reduce_add_impl < T, N, S, H - 1 >::reduce(add_fn, vec); - const auto reduced_low = reduce_add_impl::reduce(add_fn, vec); - return add_fn(reduced_high, reduced_low); - } -}; -template -struct reduce_add_impl -{ - template - static T reduce(F /*add_fn*/, vec_n_t vec) - { - return vget_lane(vec); - } -}; -template -elem_type_t reduce_add(F add_fn, V vec) -{ - constexpr int N = vec_size_of(vec); - return reduce_add_impl < elem_type_t, N, 0, N - 1 >::reduce(add_fn, vec); -} - template void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window) { const int start_x = in.info()->valid_region().anchor.x(); const int input_width = in.info()->valid_region().shape.x(); - const float scale_beta = -beta * in.info()->quantization_info().uniform().scale; + const float scale_beta = -beta * in.info()->quantization_info().uniform().scale; + const auto scale_beta_vec = vdupq_n_f32(scale_beta); - Iterator in_it(&in, window); - Iterator max_it(&max, window); - Iterator out_it(&out, window); + Iterator in_it(&in, window); + Iterator max_it(&max, window); + Iterator out_it(&out, window); + constexpr int vec_size = 16; execute_window_loop(window, [&](const Coordinates &) { @@ -629,57 +297,73 @@ void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *cons { /* Get max value */ const auto max_val = *reinterpret_cast(max_it.ptr()); - const auto vec_max = vdup_n>(max_val); + const auto vec_max = vdupq_n_u8(max_val); /* Init sum to zero */ - auto vec_sum = vdup_n(0.f); + float32x4x4_t vec_sum = + { + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + vdupq_n_f32(0.f), + }; /* Loop over row and compute exponentials and sum */ - int i = 0; - constexpr int vec_size = vec_size_of(vec_max); - - for(; i <= (input_width - vec_size); i += vec_size) + int x = 0; + for(; x <= (input_width - vec_size); x += vec_size) { - auto vec_elements = vld>(in_ptr + i); - vec_elements = vsubq_u8(vec_max, vec_elements); - - auto vec_elements_flt = vcvt(vec_elements); + auto vec_elements = wrapper::vloadq(in_ptr + x); + vec_elements = vsubq_u8(vec_max, vec_elements); + auto vec_elements_flt = convert_uint8x16_to_float32x4x4(vec_elements); if(is_log) { - vec_elements_flt = vmul_n(vec_elements_flt, scale_beta); - vec_sum = vadd(vec_sum, vexp(vec_elements_flt)); + vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec); + vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec); + vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec); + vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec); + vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0])); + vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1])); + vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2])); + vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3])); } else { - vec_elements_flt = vexp(vmul_n(vec_elements_flt, scale_beta)); - vec_sum = vadd(vec_sum, vec_elements_flt); + vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec)); + vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec)); + vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec)); + vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec)); + vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]); + vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]); + vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]); + vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]); } - vst4q_f32(tmp_ptr + i, vec_elements_flt); + + vst4q_f32(tmp_ptr + x, vec_elements_flt); } /* Reduce sum */ - const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), - vaddq_f32(vec_sum.val[2], vec_sum.val[3])); - const auto sum_8_byte = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte)); - sum = reduce_add(std::plus(), sum_8_byte); + const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3])); + auto sum_res = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte)); + sum_res = vpadd_f32(sum_res, sum_res); + sum = wrapper::vgetlane(sum_res, 0); /* Run remaining elements */ - for(; i < input_width; ++i) + for(; x < input_width; ++x) { float element{}; if(is_log) { - element = (max_val - in_ptr[i]) * scale_beta; + element = (max_val - in_ptr[x]) * scale_beta; sum += std::exp(element); } else { - element = std::exp((max_val - in_ptr[i]) * scale_beta); + element = std::exp((max_val - in_ptr[x]) * scale_beta); sum += element; } - tmp_ptr[i] = element; + tmp_ptr[x] = element; } if(!is_log) @@ -691,35 +375,45 @@ void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *cons /* Normalize exponentials */ { /* Loop over row and compute softmax */ - int i = 0; + int x = 0; + for(; x <= (input_width - vec_size); x += vec_size) { - constexpr int vec_size = 16; - - for(; i <= (input_width - vec_size); i += vec_size) + float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x); + uint8x16_t normalized_value{}; + if(is_log) { - float32x4x4_t vec_in = vld4q_f32(tmp_ptr + i); - vec_16_byte_t normalized_value{}; - if(is_log) + const float32x4x4_t sub = { - normalized_value = vcvt>(vsub_n(vec_in, sum)); - } - else + vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)), + vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)), + vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)), + vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)), + }; + convert_float32x4x4_to_unit8x16(sub, normalized_value); + } + else + { + const float32x4x4_t mul = { - normalized_value = vcvt>(vmul_n(vec_in, sum_inversed)); - } - vst(out_ptr + i, normalized_value); + vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)), + vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)), + vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)), + vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)), + }; + convert_float32x4x4_to_unit8x16(mul, normalized_value); } + vst1q_u8(out_ptr + x, normalized_value); } /* Run remaining elements */ - for(; i < input_width; ++i) + for(; x < input_width; ++x) { if(is_log) { - out_ptr[i] = utils::cast::saturate_cast(tmp_ptr[i] - sum); + out_ptr[x] = utils::cast::saturate_cast(tmp_ptr[x] - sum); } else { - out_ptr[i] = utils::cast::saturate_cast(tmp_ptr[i] * sum_inversed); + out_ptr[x] = utils::cast::saturate_cast(tmp_ptr[x] * sum_inversed); } } } @@ -738,6 +432,12 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const Iterator max_it(&max, window); Iterator out_it(&out, window); + /** NEON vector tag type. */ + using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; + + constexpr int vec_size = 16 / sizeof(T); + const int sum_stages = log2(vec_size / 2); + execute_window_loop(window, [&](const Coordinates &) { /* Get pointers */ @@ -752,53 +452,54 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const { /* Get max value */ const auto max_val = *reinterpret_cast(max_it.ptr()); - const auto vec_max = vdup_n>(max_val); + const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{}); /* Init sum to zero */ - auto vec_sum = vdup_n>(0); + auto vec_sum = wrapper::vdup_n(static_cast(0), ExactTagType{}); /* Loop over row and compute exponentials and sum */ - int i = 0; - constexpr int vec_size = vec_size_of(vec_sum); - - for(; i <= (input_width - vec_size); i += vec_size) + int x = 0; + for(; x <= (input_width - vec_size); x += vec_size) { - auto vec_elements = vld>(in_ptr + i); - vec_elements = vsub(vec_elements, vec_max); + auto vec_elements = wrapper::vloadq(in_ptr + x); + vec_elements = wrapper::vsub(vec_elements, vec_max); if(is_log) { - vec_elements = vmul_n(vec_elements, static_cast(beta)); - vec_sum = vadd(vec_sum, vexp(vec_elements)); + vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast(beta), ExactTagType{})); + vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements)); } else { - vec_elements = vexp(vmul_n(vec_elements, static_cast(beta))); - vec_sum = vadd(vec_sum, vec_elements); + vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast(beta), ExactTagType{}))); + vec_sum = wrapper::vadd(vec_sum, vec_elements); } - vst(tmp_ptr + i, vec_elements); + wrapper::vstore(tmp_ptr + x, vec_elements); } /* Reduce sum */ - const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum)); - sum = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte); + auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum)); + for(int i = 0; i < sum_stages; ++i) + { + sum_res = wrapper::vpadd(sum_res, sum_res); + } + sum = wrapper::vgetlane(sum_res, 0); /* Run remaining elements */ - - for(; i < input_width; ++i) + for(; x < input_width; ++x) { T element{}; if(is_log) { - element = (in_ptr[i] - max_val) * beta; + element = (in_ptr[x] - max_val) * beta; sum += std::exp(element); } else { - element = std::exp((in_ptr[i] - max_val) * beta); + element = std::exp((in_ptr[x] - max_val) * beta); sum += element; } - tmp_ptr[i] = element; + tmp_ptr[x] = element; } if(!is_log) @@ -810,36 +511,31 @@ void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const /* Normalize exponentials */ { /* Loop over row and compute softmax */ - int i = 0; - + int x = 0; + for(; x <= (input_width - vec_size); x += vec_size) { - constexpr int vec_size = vec_size_of(vec_16_byte_t {}); - - for(; i <= (input_width - vec_size); i += vec_size) + auto vec_in = wrapper::vloadq(tmp_ptr + x); + auto normalized_value = wrapper::vdup_n(static_cast(0), ExactTagType{}); + if(is_log) { - auto vec_in = vld>(tmp_ptr + i); - vec_16_byte_t normalized_value{}; - if(is_log) - { - normalized_value = vsub(vec_in, vdup_n>(sum)); - } - else - { - normalized_value = vmul_n(vec_in, sum_inversed); - } - vst(out_ptr + i, normalized_value); + normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast(sum), ExactTagType{})); + } + else + { + normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast(sum_inversed), ExactTagType{})); } + wrapper::vstore(out_ptr + x, normalized_value); } /* Run remaining elements */ - for(; i < input_width; ++i) + for(; x < input_width; ++x) { if(is_log) { - out_ptr[i] = tmp_ptr[i] - sum; + out_ptr[x] = tmp_ptr[x] - sum; } else { - out_ptr[i] = tmp_ptr[i] * sum_inversed; + out_ptr[x] = tmp_ptr[x] * sum_inversed; } } } diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp index 8f91b51d9a..7f8c622ef9 100644 --- a/tests/validation/NEON/SoftmaxLayer.cpp +++ b/tests/validation/NEON/SoftmaxLayer.cpp @@ -162,12 +162,12 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixture, framework::Dataset validate(Accessor(_target), _reference, tolerance_f16); } FIXTURE_DATA_TEST_CASE(RunSmall4D, NESoftmaxLayerFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::Small4DShapes(), - framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Beta", { 1.0f, 2.0f })), framework::dataset::make("Axis", { 1, 2, 3 }))) { // Validate output - validate(Accessor(_target), _reference, tolerance_f32); + validate(Accessor(_target), _reference, tolerance_f16); } FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(), framework::dataset::make("DataType", DataType::F16)), -- cgit v1.2.1