diff options
Diffstat (limited to 'src/core')
-rw-r--r-- | src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp index db2ff85db9..e966c6bdba 100644 --- a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp +++ b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp @@ -80,11 +80,9 @@ inline int64x2x2_t mul_add(const int32x4_t &a, const int32x4_t &b, const int32x4 void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *output, const ITensor *weight, const ITensor *bias) { - ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight); - ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), - output ? output->info() : nullptr, - weight->info(), - bias ? bias->info() : nullptr)); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight, bias, output); + ARM_COMPUTE_ERROR_ON(input == output); + ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), weight->info(), bias->info())); static const std::map<DataType, ComputeFuncType> fn_map = { @@ -98,6 +96,7 @@ void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *o _fn = fn_map.at(_input->info()->data_type()); auto_init_if_empty(*_output->info(), *_input->info()); + _output->info()->set_quantization_info(compute_output_qinfo()); const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform(); const Status s = quantization::calculate_quantized_multiplier(wq_info.scale, &_output_multiplier, &_output_shift); @@ -171,6 +170,14 @@ void NEQLSTMLayerNormalizationKernel::run(const Window &window, const ThreadInfo _fn(*this); } +inline QuantizationInfo NEQLSTMLayerNormalizationKernel::compute_output_qinfo() +{ + const UniformQuantizationInfo iq_info = _input->info()->quantization_info().uniform(); + const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform(); + const float output_scale = (wq_info.scale * iq_info.scale) * 1024; + return QuantizationInfo(output_scale); +} + inline std::pair<int64_t, int64_t> NEQLSTMLayerNormalizationKernel::sum_qsymm16(const int16_t *input_ptr) { ARM_COMPUTE_ERROR_ON(!input_ptr); |