diff options
Diffstat (limited to 'src/runtime/NEON')
-rw-r--r-- | src/runtime/NEON/functions/NEQLSTMLayer.cpp | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp index cb45b647c7..5a6b51337a 100644 --- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp +++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp @@ -105,7 +105,7 @@ void NEQLSTMLayer::configure(const ITensor *input, const ITensor *input_to_forget_weights, const ITensor *input_to_cell_weights, const ITensor *input_to_output_weights, const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights, const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias, - const ITensor *cell_state_in, const ITensor *output_state_in, + const ITensor *cell_state_in, ITensor *output_state_in, ITensor *cell_state_out, ITensor *output_state_out, ITensor *output, const LSTMParams<ITensor> &lstm_params) { @@ -477,9 +477,9 @@ void NEQLSTMLayer::configure(const ITensor *input, if(_projection_tensor_copy_required) { _hidden_gate.allocator()->allocate(); - _projection_accumulate_res.allocator()->init(*output_state_out->info()); + _projection_accumulate_res.allocator()->init(*output_state_in->info()); _projection_accumulate_res.info()->set_tensor_shape(_projection_outstage_res.info()->tensor_shape()); - _projection_output_to_accumulate_copy.configure(*output_state_out, _projection_accumulate_res); + _projection_output_to_accumulate_copy.configure(*output_state_in, _projection_accumulate_res); accumulate_destination = &_projection_accumulate_res; } @@ -834,7 +834,7 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input, if(projection_tensor_copy_required) { - ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_out, projection_outstage_info)); + ARM_COMPUTE_RETURN_ON_ERROR(NEQLSTMLayer::TensorCopyKernel::validate(*output_state_in, projection_outstage_info)); } ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(output_state_out, output_state_out, output_state_out, ConvertPolicy::SATURATE)); |