diff options
-rw-r--r-- | src/runtime/CL/functions/CLLSTMLayer.cpp | 40 | ||||
-rw-r--r-- | tests/datasets/LSTMLayerDataset.h | 1 |
2 files changed, 21 insertions, 20 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp index 86e5eb9090..872325175d 100644 --- a/src/runtime/CL/functions/CLLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLLSTMLayer.cpp @@ -295,27 +295,27 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state, cell_state); - ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() != 2); + ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2); ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0) && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0)); if(lstm_params.has_peephole_opt()) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights()); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1); } TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights); @@ -340,10 +340,10 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_ if(!lstm_params.has_cifg_opt()) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.cell_to_input_weights(), lstm_params.input_gate_bias()); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() != 2); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() != 1); - ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1); ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state, true, false)); ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo())); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE)); diff --git a/tests/datasets/LSTMLayerDataset.h b/tests/datasets/LSTMLayerDataset.h index 51802fd75b..a976caa0ba 100644 --- a/tests/datasets/LSTMLayerDataset.h +++ b/tests/datasets/LSTMLayerDataset.h @@ -160,6 +160,7 @@ class SmallLSTMLayerDataset final : public LSTMLayerDataset public: SmallLSTMLayerDataset() { + add_config(TensorShape(8U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U), TensorShape(16U), TensorShape(64U), ActivationLayerInfo(), 0.05f, 0.93f); add_config(TensorShape(8U, 2U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U, 2U), TensorShape(16U, 2U), TensorShape(64U, 2U), ActivationLayerInfo(), 0.05f, 0.93f); add_config(TensorShape(8U, 2U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U, 2U), TensorShape(16U, 2U), TensorShape(48U, 2U), ActivationLayerInfo(), 0.05f, 0.93f); } |