From beb2d45ed515a2d0f0727c038ff837f21c61d2dd Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Mon, 11 May 2020 16:17:51 +0100 Subject: COMPMID-3470: Modify NE/CLQLSTMLayer interface to provide 3 outputs Change-Id: I895b697c89c9a7509d48a54ac1effb7fbd8cca19 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3174 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Sang-Hoon Park --- src/runtime/CL/functions/CLQLSTMLayer.cpp | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) (limited to 'src/runtime/CL/functions/CLQLSTMLayer.cpp') diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp index d9b5c7c64d..a20ffc6f37 100644 --- a/src/runtime/CL/functions/CLQLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp @@ -76,12 +76,12 @@ void CLQLSTMLayer::configure(const ICLTensor *input, const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights, const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias, const ICLTensor *cell_state_in, const ICLTensor *output_state_in, - ICLTensor *cell_state_out, ICLTensor *output_state_out, + ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output, const LSTMParams &lstm_params) { configure(CLKernelLibrary::get().get_compile_context(), input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, - cell_state_in, output_state_in, cell_state_out, output_state_out, lstm_params); + cell_state_in, output_state_in, cell_state_out, output, output_state_out, lstm_params); } void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, @@ -89,12 +89,13 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights, const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias, const ICLTensor *cell_state_in, const ICLTensor *output_state_in, - ICLTensor *cell_state_out, ICLTensor *output_state_out, + ICLTensor *cell_state_out, ICLTensor *output_state_out, ICLTensor *output, const LSTMParams &lstm_params) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, - forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out); + forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, + cell_state_out, output_state_out, output); // Set lstm parameters LSTMParams lstm_params_info{}; @@ -104,7 +105,8 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT ARM_COMPUTE_ERROR_THROW_ON(CLQLSTMLayer::validate(input->info(), input_to_forget_weights->info(), input_to_cell_weights->info(), input_to_output_weights->info(), recurrent_to_forget_weights->info(), recurrent_to_cell_weights->info(), recurrent_to_output_weights->info(), forget_gate_bias->info(), cell_bias->info(), output_gate_bias->info(), - cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), lstm_params_info)); + cell_state_in->info(), output_state_in->info(), cell_state_out->info(), output_state_out->info(), output->info(), + lstm_params_info)); const int batch_size = input->info()->dimension(1); const int num_units = input_to_output_weights->info()->dimension(1); @@ -446,6 +448,9 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT _has_projection_clipping = true; } } + + // Copy output_state_out to output + _copy_output.configure(compile_context, output_state_out, output); } Status CLQLSTMLayer::validate(const ITensorInfo *input, @@ -453,11 +458,12 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights, const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias, const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in, - const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, + const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out, const ITensorInfo *output, const LSTMParams &lstm_params) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, - recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out); + recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, + cell_state_out, output_state_out, output); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8_SIGNED); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() != 2, "Input must have exactly 2 dimensions"); @@ -768,6 +774,7 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output_state_in, output_state_out); } + ARM_COMPUTE_RETURN_ON_ERROR(CLCopyKernel::validate(output_state_out, output)); return Status{}; } @@ -887,6 +894,9 @@ void CLQLSTMLayer::run() _projection_clip.run(); } } + + // Copy output_state_out to output + CLScheduler::get().enqueue(_copy_output); } void CLQLSTMLayer::prepare() -- cgit v1.2.1