diff options
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp | 36 |
1 files changed, 18 insertions, 18 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp index c2c85f81ef..8e2b88f5a5 100644 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp +++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp @@ -86,10 +86,10 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con return Status{}; } -template <typename T, bool has_bias> +template <typename T> typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias) { /** NEON vector tag type. */ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>; @@ -147,10 +147,10 @@ output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITe in, out); } -template <typename T, bool has_bias> +template <typename T> typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias) { ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier); ARM_COMPUTE_UNUSED(result_shift); @@ -213,9 +213,9 @@ output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITe } // Quantized case -template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 > +template < typename TOut, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 > void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias) { using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>; using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>; @@ -292,9 +292,9 @@ void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window }, in, out); } -template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 > +template < typename TOut, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 > void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output, - int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift) + int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias) { using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>; using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>; @@ -419,7 +419,6 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const INEKernel::configure(win); - const bool has_bias = bias != nullptr; const bool is_qasymm8_signed = (output != nullptr) ? is_data_type_quantized_asymmetric_signed(output->info()->data_type()) : false; // Set appropriate function @@ -431,24 +430,24 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const { if(is_qasymm8_signed) { - _func = (has_bias) ? &output_stage_nchw<int8_t, true> : &output_stage_nchw<int8_t, false>; + _func = &output_stage_nchw<int8_t>; } else { - _func = (has_bias) ? &output_stage_nchw<uint8_t, true> : &output_stage_nchw<uint8_t, false>; + _func = &output_stage_nchw<uint8_t>; } break; } #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - _func = (has_bias) ? &output_stage_nchw<float16_t, true> : &output_stage_nchw<float16_t, false>; + _func = &output_stage_nchw<float16_t>; break; } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { - _func = (has_bias) ? &output_stage_nchw<float, true> : &output_stage_nchw<float, false>; + _func = &output_stage_nchw<float>; break; } default: @@ -465,24 +464,24 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const { if(is_qasymm8_signed) { - _func = (has_bias) ? &output_stage_nhwc<int8_t, true> : &output_stage_nhwc<int8_t, false>; + _func = &output_stage_nhwc<int8_t>; } else { - _func = (has_bias) ? &output_stage_nhwc<uint8_t, true> : &output_stage_nhwc<uint8_t, false>; + _func = &output_stage_nhwc<uint8_t>; } break; } #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - _func = (has_bias) ? &output_stage_nhwc<float16_t, true> : &output_stage_nhwc<float16_t, false>; + _func = &output_stage_nhwc<float16_t>; break; } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { - _func = (has_bias) ? &output_stage_nhwc<float, true> : &output_stage_nhwc<float, false>; + _func = &output_stage_nhwc<float>; break; } default: @@ -508,6 +507,7 @@ void NEDirectConvolutionLayerOutputStageKernel::run(const Window &window, const ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); ARM_COMPUTE_ERROR_ON(_func == nullptr); - (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift); + const bool has_bias = _bias != nullptr; + (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift, has_bias); } } // namespace arm_compute |