aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp17
1 files changed, 12 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
index db2ff85db9..e966c6bdba 100644
--- a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
+++ b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
@@ -80,11 +80,9 @@ inline int64x2x2_t mul_add(const int32x4_t &a, const int32x4_t &b, const int32x4
void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *output, const ITensor *weight, const ITensor *bias)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight);
- ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(),
- output ? output->info() : nullptr,
- weight->info(),
- bias ? bias->info() : nullptr));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight, bias, output);
+ ARM_COMPUTE_ERROR_ON(input == output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), weight->info(), bias->info()));
static const std::map<DataType, ComputeFuncType> fn_map =
{
@@ -98,6 +96,7 @@ void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *o
_fn = fn_map.at(_input->info()->data_type());
auto_init_if_empty(*_output->info(), *_input->info());
+ _output->info()->set_quantization_info(compute_output_qinfo());
const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform();
const Status s = quantization::calculate_quantized_multiplier(wq_info.scale, &_output_multiplier, &_output_shift);
@@ -171,6 +170,14 @@ void NEQLSTMLayerNormalizationKernel::run(const Window &window, const ThreadInfo
_fn(*this);
}
+inline QuantizationInfo NEQLSTMLayerNormalizationKernel::compute_output_qinfo()
+{
+ const UniformQuantizationInfo iq_info = _input->info()->quantization_info().uniform();
+ const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform();
+ const float output_scale = (wq_info.scale * iq_info.scale) * 1024;
+ return QuantizationInfo(output_scale);
+}
+
inline std::pair<int64_t, int64_t> NEQLSTMLayerNormalizationKernel::sum_qsymm16(const int16_t *input_ptr)
{
ARM_COMPUTE_ERROR_ON(!input_ptr);