aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLLSTMLayer.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-16 17:01:20 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit42447c120809ee6e767cdeda8a2e52d011519b1d (patch)
treeba54497796628cd0c12cc0c461cbf226ec975893 /src/runtime/CL/functions/CLLSTMLayer.cpp
parent6f109bdaed41bf9f37f201189f11ba30c60170fb (diff)
downloadComputeLibrary-42447c120809ee6e767cdeda8a2e52d011519b1d.tar.gz
COMPMID-1188: Fixes LSTM IO dimension requirements.
Change-Id: Iee92ccce6422368c19173174e6f58e7aada12233 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140143 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLLSTMLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLLSTMLayer.cpp40
1 files changed, 20 insertions, 20 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index 86e5eb9090..872325175d 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -295,27 +295,27 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, output_state, cell_state);
- ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() != 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_forget_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_cell_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(input_to_output_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_forget_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_cell_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(recurrent_to_output_weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(forget_gate_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_gate_bias->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_state->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(cell_state->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(scratch_buffer->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(output->num_dimensions() > 2);
ARM_COMPUTE_RETURN_ERROR_ON(cell_bias->dimension(0) * 4 != scratch_buffer->dimension(0) && cell_bias->dimension(0) * 3 != scratch_buffer->dimension(0));
if(lstm_params.has_peephole_opt())
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.cell_to_output_weights(), lstm_params.cell_to_forget_weights());
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_forget_weights()->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_output_weights()->num_dimensions() > 1);
}
TensorShape units_out_transposed_shape = compute_transposed_shape(*recurrent_to_output_weights);
@@ -340,10 +340,10 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *input_
if(!lstm_params.has_cifg_opt())
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.cell_to_input_weights(), lstm_params.input_gate_bias());
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() != 2);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() != 1);
- ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() != 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_to_input_weights()->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.recurrent_to_input_weights()->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.cell_to_input_weights()->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON(lstm_params.input_gate_bias()->num_dimensions() > 1);
ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(input, lstm_params.input_to_input_weights(), lstm_params.input_gate_bias(), cell_state, true, false));
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(cell_state, &num_units_transposed_info, nullptr, &gemmv_shape_info, 1.f, 0.f, GEMMInfo()));
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(cell_state, &gemmv_shape_info, cell_state, ConvertPolicy::SATURATE));