From bfd10a1a766831fedc3a2991e70dfe388ebeb422 Mon Sep 17 00:00:00 2001 From: Pablo Marquez Tello Date: Mon, 21 Mar 2022 14:56:38 +0000 Subject: QLSTM add support for different qinfo in weights Resolves COMPMID-5185 Change-Id: I61e1453e8851ab84c1cadc10587ebd23fd94799e Signed-off-by: Pablo Marquez Tello Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7330 Tested-by: Arm Jenkins Reviewed-by: Giorgio Arena Comments-Addressed: Arm Jenkins --- src/runtime/NEON/functions/NEQLSTMLayer.cpp | 33 ++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp index fdfe95fb64..76bb8c01d2 100644 --- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp +++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -167,6 +167,13 @@ void NEQLSTMLayer::configure(const ITensor *input, LSTMParams lstm_params_info{}; build_lstm_params_tensor_info(lstm_params, &lstm_params_info); + _input_to_forget_weights_transposed.info()->set_quantization_info(input_to_forget_weights->info()->quantization_info()); + _input_to_cell_weights_transposed.info()->set_quantization_info(input_to_cell_weights->info()->quantization_info()); + _input_to_output_weights_transposed.info()->set_quantization_info(input_to_output_weights->info()->quantization_info()); + _recurrent_to_forget_weights_transposed.info()->set_quantization_info(recurrent_to_forget_weights->info()->quantization_info()); + _recurrent_to_cell_weights_transposed.info()->set_quantization_info(recurrent_to_cell_weights->info()->quantization_info()); + _recurrent_to_output_weights_transposed.info()->set_quantization_info(recurrent_to_output_weights->info()->quantization_info()); + // Validate ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(), recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(), @@ -689,20 +696,26 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input, } } - const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info()); + const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_cell_weights->quantization_info()); + const TensorInfo input_to_output_weights_transposed(TensorShape(num_units, input_size), 1, input_to_output_weights->data_type(), input_to_output_weights->quantization_info()); + const TensorInfo recurrent_to_forget_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info()); + const TensorInfo recurrent_to_cell_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_cell_weights->quantization_info()); + const TensorInfo recurrent_to_output_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_output_weights->quantization_info()); const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info()); - // Validate weights transpose - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_forget_weights, &input_weights_transposed)); ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_cell_weights, &input_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_to_output_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_to_forget_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_to_cell_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_to_output_weights_transposed)); if(!lstm_params.has_cifg_opt()) { - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed)); + const TensorInfo recurrent_to_input_weights_transposed(TensorShape(num_units, output_size), 1, + recurrent_to_forget_weights->data_type(), lstm_params.recurrent_to_input_weights()->quantization_info()); + const TensorInfo input_to_input_weights_transposed(TensorShape(num_units, input_size), 1, + input_to_forget_weights->data_type(), lstm_params.input_to_input_weights()->quantization_info()); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_to_input_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_to_input_weights_transposed)); } if(lstm_params.has_projection()) { -- cgit v1.2.1