about summary refs log tree commit diff
path: root/src/runtime/NEON/functions/NEQLSTMLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEQLSTMLayer.cpp')
-rw-r--r-- src/runtime/NEON/functions/NEQLSTMLayer.cpp 17
1 files changed, 13 insertions, 4 deletions
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
index 9c78ea8b75..466c41307b 100644
--- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
@@ -106,7 +106,7 @@ void NEQLSTMLayer::configure(const ITensor *input,
const ITensor *recurrent_to_forget_weights, const ITensor *recurrent_to_cell_weights, const ITensor *recurrent_to_output_weights,
const ITensor *forget_gate_bias, const ITensor *cell_bias, const ITensor *output_gate_bias,
const ITensor *cell_state_in, const ITensor *output_state_in,
- ITensor *cell_state_out, ITensor *output_state_out,
+ ITensor *cell_state_out, ITensor *output_state_out, ITensor *output,
const LSTMParams<ITensor> &lstm_params)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
@@ -121,7 +121,8 @@ void NEQLSTMLayer::configure(const ITensor *input,
ARM_COMPUTE_ERROR_THROW_ON(NEQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
- cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), lstm_params_info));
+ cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
+ lstm_params_info));
const int batch_size = input->info()->dimension(1);
const int num_units = input_to_output_weights->info()->dimension(1);
@@ -515,6 +516,9 @@ void NEQLSTMLayer::configure(const ITensor *input,
_hidden_gate.allocator()->allocate();
}
}
+
+ // Copy output_state_out to output
+ _copy_output.configure(output_state_out, output);
}
Status NEQLSTMLayer::validate(const ITensorInfo *input,
@@ -522,11 +526,12 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
- const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out,
+ const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
const LSTMParams<ITensorInfo> &lstm_params)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
- recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
+ recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
+ cell_state_out, output_state_out, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
@@ -867,6 +872,7 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
}
+ ARM_COMPUTE_RETURN_ON_ERROR(NECopyKernel::validate(output_state_out, output));
return Status{};
}
@@ -1011,6 +1017,9 @@ void NEQLSTMLayer::run()
_hidden_to_output_copy.run();
}
}
+
+ // Copy output_state_out to output
+ NEScheduler::get().schedule(&_copy_output, Window::DimY);
}
void NEQLSTMLayer::prepare()