aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/runtime/CL/functions/CLLSTMLayer.cpp40
-rw-r--r--tests/datasets/LSTMLayerDataset.h1
2 files changed, 21 insertions, 20 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index 86e5eb9090..872325175d 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -295,27 +295,27 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state, cell_state);
- ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() != 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0) && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
if(lstm_params.has_peephole_opt())
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
}
TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
@@ -340,10 +340,10 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
if(!lstm_params.has_cifg_opt())
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.cell_to_input_weights(), lstm_params.input_gate_bias());
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state, true, false));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo()));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE));
diff --git a/tests/datasets/LSTMLayerDataset.h b/tests/datasets/LSTMLayerDataset.h
index 51802fd75b..a976caa0ba 100644
--- a/tests/datasets/LSTMLayerDataset.h
+++ b/tests/datasets/LSTMLayerDataset.h
@@ -160,6 +160,7 @@ class SmallLSTMLayerDataset final : public LSTMLayerDataset
public:
SmallLSTMLayerDataset()
{
+ add_config(TensorShape(8U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U), TensorShape(16U), TensorShape(64U), ActivationLayerInfo(), 0.05f, 0.93f);
add_config(TensorShape(8U, 2U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U, 2U), TensorShape(16U, 2U), TensorShape(64U, 2U), ActivationLayerInfo(), 0.05f, 0.93f);
add_config(TensorShape(8U, 2U), TensorShape(8U, 16U), TensorShape(16U, 16U), TensorShape(16U), TensorShape(16U, 2U), TensorShape(16U, 2U), TensorShape(48U, 2U), ActivationLayerInfo(), 0.05f, 0.93f);
}