From 42447c120809ee6e767cdeda8a2e52d011519b1d Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 16 Jul 2018 17:01:20 +0100 Subject: COMPMID-1188: Fixes LSTM IO dimension requirements. Change-Id: Iee92ccce6422368c19173174e6f58e7aada12233 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140143 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/runtime/CL/functions/CLLSTMLayer.cpp | 40 ++++++++++++++++---------------- 1 file changed, 20 insertions(+), 20 deletions(-) (limited to 'src/runtime/CL/functions/CLLSTMLayer.cpp') diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp index 86e5eb9090..872325175d 100644 --- a/src/runtime/CL/functions/CLLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLLSTMLayer.cpp @@ -295,27 +295,27 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state, cell_state); - ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() != 2); + ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2); ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0) && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0)); if(lstm_params.has_peephole_opt()) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights()); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1); } TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights); @@ -340,10 +340,10 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ if(!lstm_params.has_cifg_opt()) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.cell_to_input_weights(), lstm_params.input_gate_bias()); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1); ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state, true, false)); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo())); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE)); -- cgit v1.2.1