Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp')
-rw-r--r--   src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp   84
1 file changed, 64 insertions(+), 20 deletions(-)
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
index fb16c8dcc1..12ef064803 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerBiasAccumulateKernel.cpp
@@ -54,6 +54,11 @@ inline qint16x8_t internal_vld1q(const qint16_t *in)
     return vld1q_qs16(in);
 }
 
+inline qint32x4_t internal_vld1q(const qint32_t *in)
+{
+    return vld1q_s32(in);
+}
+
 // Internal store
 inline void internal_vst1q(float *p, const float32x4_t &v)
 {
@@ -72,6 +77,16 @@ inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
     vst1q_qs16(p, v);
 }
 
+inline void internal_vst1q(qint32_t *p, const qint32x4_t &v)
+{
+    vst1q_s32(p, v);
+}
+
+inline void internal_vst1q(qint16_t *p, const qint32x4_t &v)
+{
+    vst1_qs16(p, vqmovn_qs32(v));
+}
+
 // Internal vdup
 inline float32x4_t internal_vdupq_n(float v)
 {
@@ -86,6 +101,11 @@ inline qint16x8_t internal_vdupq_n(qint16_t v)
     return vdupq_n_qs16(v);
 }
 
+inline qint32x4_t internal_vdupq_n(qint32_t v)
+{
+    return vdupq_n_qs32(v);
+}
+
 // Internal vadd
 inline float32x4_t internal_vqaddq(const float32x4_t &x, const float32x4_t &y)
 {
@@ -99,6 +119,10 @@ inline qint16x8_t internal_vqaddq(const qint16x8_t &x, const qint16x8_t &y)
 {
     return vqaddq_qs16(x, y);
 }
+inline qint32x4_t internal_vqaddq(const qint32x4_t &x, const qint32x4_t &y)
+{
+    return vqaddq_qs32(x, y);
+}
 
 #ifdef ARM_COMPUTE_ENABLE_FP16
 inline float16x8_t internal_vld1q(const float16_t *in)
@@ -162,8 +186,8 @@ NEDirectConvolutionLayerBiasAccumulateKernel::NEDirectConvolutionLayerBiasAccumulateKernel()
 
 void NEDirectConvolutionLayerBiasAccumulateKernel::configure(ITensor *input, const ITensor *bias, ITensor *output)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::QS32, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::QS32, DataType::F32);
     ARM_COMPUTE_ERROR_ON(input->info()->fixed_point_position() != bias->info()->fixed_point_position());
     if(output != nullptr)
     {
@@ -198,27 +222,47 @@ void NEDirectConvolutionLayerBiasAccumulateKernel::configure(ITensor *input, const ITensor *bias, ITensor *output)
     INEKernel::configure(win);
 
     // Set appropriate function
-    if(input->info()->data_type() == DataType::F32)
+    switch(input->info()->data_type())
     {
-        _func = (output == nullptr) ? &accumulate_bias<float, float, true> : &accumulate_bias<float, float, false>;
-    }
+        case DataType::QS8:
+        {
+            _func = (output == nullptr) ? &accumulate_bias<qint8_t, qint8_t, true> : &accumulate_bias<qint8_t, qint8_t, false>;
+            break;
+        }
+        case DataType::QS16:
+        {
+            if(bias->info()->data_type() == DataType::QS8)
+            {
+                _func = (output == nullptr) ? &accumulate_bias<qint16_t, qint8_t, true> : &accumulate_bias<qint16_t, qint8_t, false>;
+            }
+            else
+            {
+                ARM_COMPUTE_ERROR("Not implemented");
+            }
+            break;
+        }
+        case DataType::QS32:
+        {
+            _func = (output == nullptr) ? &accumulate_bias<qint32_t, qint16_t, true> : &accumulate_bias<qint32_t, qint16_t, false>;
+            break;
+        }
 #ifdef ARM_COMPUTE_ENABLE_FP16
-    else if(input->info()->data_type() == DataType::F16)
-    {
-        _func = (output == nullptr) ? &accumulate_bias<float16_t, float16_t, true> : &accumulate_bias<float16_t, float16_t, false>;
-    }
+        case DataType::F16:
+        {
+            _func = (output == nullptr) ? &accumulate_bias<float16_t, float16_t, true> : &accumulate_bias<float16_t, float16_t, false>;
+            break;
+        }
 #endif /* ARM_COMPUTE_ENABLE_FP16 */
-    else if(input->info()->data_type() == DataType::QS8)
-    {
-        _func = (output == nullptr) ? &accumulate_bias<qint8_t, qint8_t, true> : &accumulate_bias<qint8_t, qint8_t, false>;
-    }
-    else if(input->info()->data_type() == DataType::QS16 && bias->info()->data_type() == DataType::QS8)
-    {
-        _func = (output == nullptr) ? &accumulate_bias<qint16_t, qint8_t, true> : &accumulate_bias<qint16_t, qint8_t, false>;
-    }
-    else
-    {
-        ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+        case DataType::F32:
+        {
+            _func = (output == nullptr) ? &accumulate_bias<float, float, true> : &accumulate_bias<float, float, false>;
+            break;
+        }
+        default:
+        {
+            ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
+            break;
+        }
     }
 }
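The pattern worth noting in this patch: internal_vld1q, internal_vdupq_n, internal_vqaddq and internal_vst1q form an overload set, so QS32 support only needs new overloads plus a new switch case; the templated accumulate_bias<T1, T2, in_place> body stays untouched. The new narrowing internal_vst1q(qint16_t *, qint32x4_t) overload (vqmovn_qs32 saturates each 32-bit lane down to 16 bits) is what lets the QS32 path produce a QS16 result. Below is a minimal scalar sketch of that dispatch pattern, not the library's kernel: the helper names, the plain-pointer interface, and the assumed parameter roles are illustrative, and the real kernel iterates ITensor windows with 128-bit NEON vectors and saturating adds.

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Scalar stand-ins for the NEON helper overloads extended by this patch
// (internal_vld1q / internal_vdupq_n / internal_vqaddq / internal_vst1q).
// The names below are illustrative, not the library's.
inline int32_t internal_load(const int32_t *p) { return *p; }
inline int16_t internal_load(const int16_t *p) { return *p; }

inline void internal_store(int32_t *p, int32_t v) { *p = v; }

// Narrowing store, analogous to the new internal_vst1q(qint16_t *, qint32x4_t)
// overload, which saturates a 32-bit value down to 16 bits like vqmovn_qs32.
inline void internal_store(int16_t *p, int32_t v)
{
    *p = static_cast<int16_t>(std::clamp<int32_t>(v, INT16_MIN, INT16_MAX));
}

// Assumed parameter roles: T1 is the input/accumulator type, T2 the bias type
// (and the out-of-place output type, judging from the overloads added in the
// diff), and in_place selects writing the sum back into the input.
template <typename T1, typename T2, bool in_place>
void accumulate_bias_sketch(T1 *in, const T2 *bias, T2 *out, std::size_t n)
{
    const T1 b = static_cast<T1>(*bias); // broadcast, like internal_vdupq_n
    for(std::size_t i = 0; i < n; ++i)
    {
        // The real kernel uses a saturating vector add (internal_vqaddq);
        // a plain '+' keeps the sketch simple.
        const T1 sum = internal_load(in + i) + b;
        if(in_place)
        {
            internal_store(in + i, sum);  // resolves to the same-width store
        }
        else
        {
            internal_store(out + i, sum); // may resolve to the narrowing store
        }
    }
}

With that shape, the QS32 case of the switch amounts to instantiating accumulate_bias_sketch<int32_t, int16_t, false>: overload resolution, not extra control flow, decides that the out-of-place store narrows to 16 bits, which is why the patch adds overloads rather than another specialized kernel body.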