aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLLSTMLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLLSTMLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLLSTMLayer.h91
1 files changed, 65 insertions, 26 deletions
diff --git a/arm_compute/runtime/CL/functions/CLLSTMLayer.h b/arm_compute/runtime/CL/functions/CLLSTMLayer.h
index d26b4c5595..fe494991af 100644
--- a/arm_compute/runtime/CL/functions/CLLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLLSTMLayer.h
@@ -24,8 +24,6 @@
#ifndef ARM_COMPUTE_CLLSTMLAYER_H
#define ARM_COMPUTE_CLLSTMLAYER_H
-#include "arm_compute/runtime/IFunction.h"
-
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
@@ -37,9 +35,10 @@
#include "arm_compute/runtime/CL/functions/CLGEMM.h"
#include "arm_compute/runtime/CL/functions/CLMeanStdDevNormalizationLayer.h"
#include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h"
+#include "arm_compute/runtime/common/LSTMParams.h"
+#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
-#include "arm_compute/runtime/common/LSTMParams.h"
#include <memory>
@@ -53,7 +52,7 @@ namespace kernels
{
class ClTransposeKernel;
}
-}
+} // namespace opencl
/** This function performs a single time step in a Long Short-Term Memory (LSTM) layer.
*
@@ -120,13 +119,26 @@ public:
* @param[in] projection_threshold (Optional) The clipping threshold for the output from the projection layer, such that values are bound within [-proj_clip, proj_clip].
* If set to 0.0f then clipping is disabled.
*/
- void configure(const ICLTensor *input,
- const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
- const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
- const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
- const ICLTensor *output_state_in, ICLTensor *cell_state_in,
- ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
- const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold = 0.f, float projection_threshold = 0.f);
+ void configure(const ICLTensor *input,
+ const ICLTensor *input_to_forget_weights,
+ const ICLTensor *input_to_cell_weights,
+ const ICLTensor *input_to_output_weights,
+ const ICLTensor *recurrent_to_forget_weights,
+ const ICLTensor *recurrent_to_cell_weights,
+ const ICLTensor *recurrent_to_output_weights,
+ const ICLTensor *forget_gate_bias,
+ const ICLTensor *cell_bias,
+ const ICLTensor *output_gate_bias,
+ const ICLTensor *output_state_in,
+ ICLTensor *cell_state_in,
+ ICLTensor *scratch_buffer,
+ ICLTensor *output_state_out,
+ ICLTensor *cell_state_out,
+ ICLTensor *output,
+ const LSTMParams<ICLTensor> &lstm_params,
+ const ActivationLayerInfo &activation_info,
+ float cell_threshold = 0.f,
+ float projection_threshold = 0.f);
/** Initialize function's tensors.
*
* @param[in] compile_context The compile context to be used.
@@ -166,13 +178,27 @@ public:
* @param[in] projection_threshold (Optional) The clipping threshold for the output from the projection layer, such that values are bound within [-proj_clip, proj_clip].
* If set to 0.0f then clipping is disabled.
*/
- void configure(const CLCompileContext &compile_context, const ICLTensor *input,
- const ICLTensor *input_to_forget_weights, const ICLTensor *input_to_cell_weights, const ICLTensor *input_to_output_weights,
- const ICLTensor *recurrent_to_forget_weights, const ICLTensor *recurrent_to_cell_weights, const ICLTensor *recurrent_to_output_weights,
- const ICLTensor *forget_gate_bias, const ICLTensor *cell_bias, const ICLTensor *output_gate_bias,
- const ICLTensor *output_state_in, ICLTensor *cell_state_in,
- ICLTensor *scratch_buffer, ICLTensor *output_state_out, ICLTensor *cell_state_out, ICLTensor *output,
- const LSTMParams<ICLTensor> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold = 0.f, float projection_threshold = 0.f);
+ void configure(const CLCompileContext &compile_context,
+ const ICLTensor *input,
+ const ICLTensor *input_to_forget_weights,
+ const ICLTensor *input_to_cell_weights,
+ const ICLTensor *input_to_output_weights,
+ const ICLTensor *recurrent_to_forget_weights,
+ const ICLTensor *recurrent_to_cell_weights,
+ const ICLTensor *recurrent_to_output_weights,
+ const ICLTensor *forget_gate_bias,
+ const ICLTensor *cell_bias,
+ const ICLTensor *output_gate_bias,
+ const ICLTensor *output_state_in,
+ ICLTensor *cell_state_in,
+ ICLTensor *scratch_buffer,
+ ICLTensor *output_state_out,
+ ICLTensor *cell_state_out,
+ ICLTensor *output,
+ const LSTMParams<ICLTensor> &lstm_params,
+ const ActivationLayerInfo &activation_info,
+ float cell_threshold = 0.f,
+ float projection_threshold = 0.f);
/** Static function to check if given info will lead to a valid configuration of @ref CLLSTMLayer
*
@@ -214,13 +240,26 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *input,
- const ITensorInfo *input_to_forget_weights, const ITensorInfo *input_to_cell_weights, const ITensorInfo *input_to_output_weights,
- const ITensorInfo *recurrent_to_forget_weights, const ITensorInfo *recurrent_to_cell_weights, const ITensorInfo *recurrent_to_output_weights,
- const ITensorInfo *forget_gate_bias, const ITensorInfo *cell_bias, const ITensorInfo *output_gate_bias,
- const ITensorInfo *output_state_in, const ITensorInfo *cell_state_in,
- const ITensorInfo *scratch_buffer, const ITensorInfo *output_state_out, const ITensorInfo *cell_state_out, const ITensorInfo *output,
- const LSTMParams<ITensorInfo> &lstm_params, const ActivationLayerInfo &activation_info, float cell_threshold = 0.f, float projection_threshold = 0.f);
+ static Status validate(const ITensorInfo *input,
+ const ITensorInfo *input_to_forget_weights,
+ const ITensorInfo *input_to_cell_weights,
+ const ITensorInfo *input_to_output_weights,
+ const ITensorInfo *recurrent_to_forget_weights,
+ const ITensorInfo *recurrent_to_cell_weights,
+ const ITensorInfo *recurrent_to_output_weights,
+ const ITensorInfo *forget_gate_bias,
+ const ITensorInfo *cell_bias,
+ const ITensorInfo *output_gate_bias,
+ const ITensorInfo *output_state_in,
+ const ITensorInfo *cell_state_in,
+ const ITensorInfo *scratch_buffer,
+ const ITensorInfo *output_state_out,
+ const ITensorInfo *cell_state_out,
+ const ITensorInfo *output,
+ const LSTMParams<ITensorInfo> &lstm_params,
+ const ActivationLayerInfo &activation_info,
+ float cell_threshold = 0.f,
+ float projection_threshold = 0.f);
// Inherited methods overridden:
void run() override;
@@ -311,7 +350,7 @@ private:
bool _perform_projection_clipping;
bool _is_prepared;
bool _is_layer_norm_lstm;
- const ICLTensor *_recurrent_to_cell_weights{ nullptr };
+ const ICLTensor *_recurrent_to_cell_weights{nullptr};
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CLLSTMLAYER_H */