aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp36
1 files changed, 18 insertions, 18 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index c2c85f81ef..8e2b88f5a5 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -86,10 +86,10 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, con
return Status{};
}
-template <typename T, bool has_bias>
+template <typename T>
typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias)
{
/** NEON vector tag type. */
using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
@@ -147,10 +147,10 @@ output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITe
in, out);
}
-template <typename T, bool has_bias>
+template <typename T>
typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias)
{
ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
ARM_COMPUTE_UNUSED(result_shift);
@@ -213,9 +213,9 @@ output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITe
}
// Quantized case
-template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
+template < typename TOut, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias)
{
using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
@@ -292,9 +292,9 @@ void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window
},
in, out);
}
-template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
+template < typename TOut, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift, bool has_bias)
{
using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
@@ -419,7 +419,6 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
INEKernel::configure(win);
- const bool has_bias = bias != nullptr;
const bool is_qasymm8_signed = (output != nullptr) ? is_data_type_quantized_asymmetric_signed(output->info()->data_type()) : false;
// Set appropriate function
@@ -431,24 +430,24 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
{
if(is_qasymm8_signed)
{
- _func = (has_bias) ? &output_stage_nchw<int8_t, true> : &output_stage_nchw<int8_t, false>;
+ _func = &output_stage_nchw<int8_t>;
}
else
{
- _func = (has_bias) ? &output_stage_nchw<uint8_t, true> : &output_stage_nchw<uint8_t, false>;
+ _func = &output_stage_nchw<uint8_t>;
}
break;
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- _func = (has_bias) ? &output_stage_nchw<float16_t, true> : &output_stage_nchw<float16_t, false>;
+ _func = &output_stage_nchw<float16_t>;
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- _func = (has_bias) ? &output_stage_nchw<float, true> : &output_stage_nchw<float, false>;
+ _func = &output_stage_nchw<float>;
break;
}
default:
@@ -465,24 +464,24 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
{
if(is_qasymm8_signed)
{
- _func = (has_bias) ? &output_stage_nhwc<int8_t, true> : &output_stage_nhwc<int8_t, false>;
+ _func = &output_stage_nhwc<int8_t>;
}
else
{
- _func = (has_bias) ? &output_stage_nhwc<uint8_t, true> : &output_stage_nhwc<uint8_t, false>;
+ _func = &output_stage_nhwc<uint8_t>;
}
break;
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- _func = (has_bias) ? &output_stage_nhwc<float16_t, true> : &output_stage_nhwc<float16_t, false>;
+ _func = &output_stage_nhwc<float16_t>;
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- _func = (has_bias) ? &output_stage_nhwc<float, true> : &output_stage_nhwc<float, false>;
+ _func = &output_stage_nhwc<float>;
break;
}
default:
@@ -508,6 +507,7 @@ void NEDirectConvolutionLayerOutputStageKernel::run(const Window &window, const
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
ARM_COMPUTE_ERROR_ON(_func == nullptr);
- (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift);
+ const bool has_bias = _bias != nullptr;
+ (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift, has_bias);
}
} // namespace arm_compute