aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLQLSTMLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLQLSTMLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLQLSTMLayer.cpp24
1 files changed, 17 insertions, 7 deletions
diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp
index d9b5c7c64d..a20ffc6f37 100644
--- a/src/runtime/CL/functions/CLQLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp
@@ -76,12 +76,12 @@ void CLQLSTMLayer::configure(const ICLTensor *input,
const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
- ICLTensor *cell_state_out, ICLTensor *output_state_out,
+ ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
const LSTMParams<ICLTensor> &lstm_params)
{
configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias,
- cell_state_in, output_state_in, cell_state_out, output_state_out, lstm_params);
+ cell_state_in, output_state_in, cell_state_out, output, output_state_out, lstm_params);
}
void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input,
@@ -89,12 +89,13 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
const ICLTensor *cell_state_in, const ICLTensor *output_state_in,
- ICLTensor *cell_state_out, ICLTensor *output_state_out,
+ ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output,
const LSTMParams<ICLTensor> &lstm_params)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
- forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
+ forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
+ cell_state_out, output_state_out, output);
// Set lstm parameters
LSTMParams<ITensorInfo> lstm_params_info{};
@@ -104,7 +105,8 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(),
recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(),
forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(),
- cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), lstm_params_info));
+ cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(),
+ lstm_params_info));
const int batch_size = input->info()->dimension(1);
const int num_units = input_to_output_weights->info()->dimension(1);
@@ -446,6 +448,9 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_has_projection_clipping = true;
}
}
+
+ // Copy output_state_out to output
+ _copy_output.configure(compile_context, output_state_out, output);
}
Status CLQLSTMLayer::validate(const ITensorInfo *input,
@@ -453,11 +458,12 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
- const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out,
+ const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output,
const LSTMParams<ITensorInfo> &lstm_params)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights,
- recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
+ recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in,
+ cell_state_out, output_state_out, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions");
@@ -768,6 +774,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out);
}
+ ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output));
return Status{};
}
@@ -887,6 +894,9 @@ void CLQLSTMLayer::run()
_projection_clip.run();
}
}
+
+ // Copy output_state_out to output
+ CLScheduler::get().enqueue(_copy_output);
}
void CLQLSTMLayer::prepare()