diff options
Diffstat (limited to 'src/core/NEON/kernels/NEIm2ColKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEIm2ColKernel.cpp | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp index 8c9d12c57c..5bb8b1c22a 100644 --- a/src/core/NEON/kernels/NEIm2ColKernel.cpp +++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp @@ -134,10 +134,14 @@ inline void linearize_volume(const uint8_t *const in_ptr, // Append 1 if the convolution layer has biases if(has_bias) { - if(std::is_same<T, arm_compute::qint8_t>::value) + if(std::is_same<T, qint8_t>::value) { *out_ptr = scvt_qs8_f32(1.0f, fixed_point_position); } + else if(std::is_same<T, qint16_t>::value) + { + *out_ptr = scvt_qs16_f32(1.0f, fixed_point_position); + } else { *out_ptr = static_cast<T>(1); @@ -249,10 +253,14 @@ void NEIm2ColKernel::run_reduced(const Window &window) // Add bias if(_has_bias) { - if(std::is_same<T, arm_compute::qint8_t>::value) + if(std::is_same<T, qint8_t>::value) { *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = scvt_qs8_f32(1.0f, _input->info()->fixed_point_position()); } + else if(std::is_same<T, qint16_t>::value) + { + *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = scvt_qs16_f32(1.0f, _input->info()->fixed_point_position()); + } else { *(reinterpret_cast<T *>(out_ptr) + out_width - 1) = static_cast<T>(1); @@ -269,8 +277,9 @@ NEIm2ColKernel::NEIm2ColKernel() void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QS8); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output); _input = input; _output = output; @@ -309,6 +318,9 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size case DataType::QS8: _func = &NEIm2ColKernel::run_reduced<qint8_t>; break; + case DataType::QS16: + _func = &NEIm2ColKernel::run_reduced<qint16_t>; + break; default: ARM_COMPUTE_ERROR("Data type not supported"); break; @@ -329,6 +341,9 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size case DataType::QS8: _func = ((pad_x == 0) && (pad_y == 0)) ? &NEIm2ColKernel::run_generic<qint8_t, false> : &NEIm2ColKernel::run_generic<qint8_t, true>; break; + case DataType::QS16: + _func = ((pad_x == 0) && (pad_y == 0)) ? &NEIm2ColKernel::run_generic<qint16_t, false> : &NEIm2ColKernel::run_generic<qint16_t, true>; + break; default: ARM_COMPUTE_ERROR("Data type not supported"); break; |