diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEQLSTMLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEQLSTMLayer.cpp | 33 |
1 files changed, 23 insertions, 10 deletions
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp index fdfe95fb64..76bb8c01d2 100644 --- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp +++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -167,6 +167,13 @@ void NEQLSTMLayer::configure(const ITensor *input, LSTMParams<ITensorInfo> lstm_params_info{}; build_lstm_params_tensor_info(lstm_params, &lstm_params_info); + _input_to_forget_weights_transposed.info()->set_quantization_info(input_to_forget_weights->info()->quantization_info()); + _input_to_cell_weights_transposed.info()->set_quantization_info(input_to_cell_weights->info()->quantization_info()); + _input_to_output_weights_transposed.info()->set_quantization_info(input_to_output_weights->info()->quantization_info()); + _recurrent_to_forget_weights_transposed.info()->set_quantization_info(recurrent_to_forget_weights->info()->quantization_info()); + _recurrent_to_cell_weights_transposed.info()->set_quantization_info(recurrent_to_cell_weights->info()->quantization_info()); + _recurrent_to_output_weights_transposed.info()->set_quantization_info(recurrent_to_output_weights->info()->quantization_info()); + // Validate ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(), recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(), @@ -689,20 +696,26 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input, } } - const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_forget_weights->quantization_info()); + const TensorInfo input_weights_transposed(TensorShape(num_units, input_size), 1, input_to_forget_weights->data_type(), input_to_cell_weights->quantization_info()); + const TensorInfo input_to_output_weights_transposed(TensorShape(num_units, input_size), 1, input_to_output_weights->data_type(), input_to_output_weights->quantization_info()); + const TensorInfo recurrent_to_forget_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info()); + const TensorInfo recurrent_to_cell_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_cell_weights->quantization_info()); + const TensorInfo recurrent_to_output_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_output_weights->quantization_info()); const TensorInfo recurrent_weights_transposed(TensorShape(num_units, output_size), 1, recurrent_to_forget_weights->data_type(), recurrent_to_forget_weights->quantization_info()); - // Validate weights transpose - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_forget_weights, &input_weights_transposed)); ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_cell_weights, &input_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(input_to_output_weights, &input_to_output_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_forget_weights, &recurrent_to_forget_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_cell_weights, &recurrent_to_cell_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(recurrent_to_output_weights, &recurrent_to_output_weights_transposed)); if(!lstm_params.has_cifg_opt()) { - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_weights_transposed)); - ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_weights_transposed)); + const TensorInfo recurrent_to_input_weights_transposed(TensorShape(num_units, output_size), 1, + recurrent_to_forget_weights->data_type(), lstm_params.recurrent_to_input_weights()->quantization_info()); + const TensorInfo input_to_input_weights_transposed(TensorShape(num_units, input_size), 1, + input_to_forget_weights->data_type(), lstm_params.input_to_input_weights()->quantization_info()); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.input_to_input_weights(), &input_to_input_weights_transposed)); + ARM_COMPUTE_RETURN_ON_ERROR(NETranspose::validate(lstm_params.recurrent_to_input_weights(), &recurrent_to_input_weights_transposed)); } if(lstm_params.has_projection()) { |