aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h67
1 files changed, 50 insertions, 17 deletions
diff --git a/arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h b/arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h
index 9c004b85d0..8c116b1482 100644
--- a/arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h
+++ b/arm_compute/runtime/CL/functions/CLLSTMLayerQuantized.h
@@ -35,7 +35,6 @@
#include "arm_compute/runtime/CL/functions/CLQuantizationLayer.h"
#include "arm_compute/runtime/CL/functions/CLSlice.h"
#include "arm_compute/runtime/CL/functions/CLTranspose.h"
-
#include "arm_compute/runtime/common/LSTMParams.h"
namespace arm_compute
@@ -100,11 +99,22 @@ public:
* @param[out] output_state_out Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size].Data types supported: Same as @p input.
*/
void configure(const ICLTensor *input,
- const ICLTensor *input_to_input_weights, const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
- const ICLTensor *recurrent_to_input_weights, const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
- const ICLTensor *input_gate_bias, const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
- ICLTensor *cell_state_in, const ICLTensor *output_state_in,
- ICLTensor *cell_state_out, ICLTensor *output_state_out);
+ const ICLTensor *input_to_input_weights,
+ const ICLTensor *input_to_forget_weights,
+ const ICLTensor *input_to_cell_weights,
+ const ICLTensor *input_to_output_weights,
+ const ICLTensor *recurrent_to_input_weights,
+ const ICLTensor *recurrent_to_forget_weights,
+ const ICLTensor *recurrent_to_cell_weights,
+ const ICLTensor *recurrent_to_output_weights,
+ const ICLTensor *input_gate_bias,
+ const ICLTensor *forget_gate_bias,
+ const ICLTensor *cell_bias,
+ const ICLTensor *output_gate_bias,
+ ICLTensor *cell_state_in,
+ const ICLTensor *output_state_in,
+ ICLTensor *cell_state_out,
+ ICLTensor *output_state_out);
/** Initialize function's tensors.
*
* @param[in] compile_context The compile context to be used.
@@ -126,12 +136,24 @@ public:
* @param[out] cell_state_out Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size]. Data type supported: QSYMM16.
* @param[out] output_state_out Destination tensor. Output is a 2D tensor with dimensions [output_size, batch_size].Data types supported: Same as @p input.
*/
- void configure(const CLCompileContext &compile_context, const ICLTensor *input,
- const ICLTensor *input_to_input_weights, const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
- const ICLTensor *recurrent_to_input_weights, const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
- const ICLTensor *input_gate_bias, const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
- ICLTensor *cell_state_in, const ICLTensor *output_state_in,
- ICLTensor *cell_state_out, ICLTensor *output_state_out);
+ void configure(const CLCompileContext &compile_context,
+ const ICLTensor *input,
+ const ICLTensor *input_to_input_weights,
+ const ICLTensor *input_to_forget_weights,
+ const ICLTensor *input_to_cell_weights,
+ const ICLTensor *input_to_output_weights,
+ const ICLTensor *recurrent_to_input_weights,
+ const ICLTensor *recurrent_to_forget_weights,
+ const ICLTensor *recurrent_to_cell_weights,
+ const ICLTensor *recurrent_to_output_weights,
+ const ICLTensor *input_gate_bias,
+ const ICLTensor *forget_gate_bias,
+ const ICLTensor *cell_bias,
+ const ICLTensor *output_gate_bias,
+ ICLTensor *cell_state_in,
+ const ICLTensor *output_state_in,
+ ICLTensor *cell_state_out,
+ ICLTensor *output_state_out);
/** Static function to check if given info will lead to a valid configuration of @ref CLLSTMLayerQuantized
*
@@ -156,11 +178,22 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo *input,
- const ITensorInfo *input_to_input_weights, const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
- const ITensorInfo *recurrent_to_input_weights, const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
- const ITensorInfo *input_gate_bias, const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
- const ITensorInfo *cell_state_in, const ITensorInfo *output_state_in,
- const ITensorInfo *cell_state_out, const ITensorInfo *output_state_out);
+ const ITensorInfo *input_to_input_weights,
+ const ITensorInfo *input_to_forget_weights,
+ const ITensorInfo *input_to_cell_weights,
+ const ITensorInfo *input_to_output_weights,
+ const ITensorInfo *recurrent_to_input_weights,
+ const ITensorInfo *recurrent_to_forget_weights,
+ const ITensorInfo *recurrent_to_cell_weights,
+ const ITensorInfo *recurrent_to_output_weights,
+ const ITensorInfo *input_gate_bias,
+ const ITensorInfo *forget_gate_bias,
+ const ITensorInfo *cell_bias,
+ const ITensorInfo *output_gate_bias,
+ const ITensorInfo *cell_state_in,
+ const ITensorInfo *output_state_in,
+ const ITensorInfo *cell_state_out,
+ const ITensorInfo *output_state_out);
// Inherited methods overridden:
void run() override;