aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLQLSTMLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLQLSTMLayer.h50
1 files changed, 44 insertions, 6 deletions
diff --git a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
index 219f46ee48..67e8bc7751 100644
--- a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
@@ -230,7 +230,8 @@ private:
Output,
Count
};
- static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+ static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+ static constexpr uint32_t _out_state_output_size_dimension_idx = 0;
/** Internal method to configure matrix multiplication plus output stage of each gate.
*
@@ -254,6 +255,35 @@ private:
MemoryGroup _memory_group{};
+ /** A small internel kernel do the copy between two tensors */
+ class TensorCopyKernel
+ {
+ static constexpr uint32_t max_dimension_supported = 2;
+
+ ICLTensor *_src{ nullptr };
+ ICLTensor *_dst{ nullptr };
+ size_t _row_size{};
+ Window _window{};
+
+ public:
+ /** Static function to check if given info will lead to a valid configuration of @ref CLQLSTMLayer::TensorCopyKernel
+ *
+ * @param[in] src Source tensor info.
+ * @param[in] dst Destination tensor info
+ *
+ * @return a status
+ */
+ static Status validate(const ITensorInfo &src, const ITensorInfo &dst);
+ /** Set the input and output tensors.
+ *
+ * @param[in] src Source tensor
+ * @param[out] dst Destination tensor
+ */
+ void configure(ICLTensor &src, ICLTensor &dst);
+ /** run the kernel */
+ void run();
+ };
+
// Functions used
CLTranspose _transpose_input_to_forget_weights{};
CLTranspose _transpose_input_to_cell_weights{};
@@ -298,7 +328,7 @@ private:
CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_to_input{};
CLGEMMLowpOutputStage _cell_to_input_outstage{};
CLSaturatedArithmeticOperationKernel _accumulate_cell_input{};
- CLActivationLayer _input_gate_tanh{};
+ CLActivationLayer _input_gate_sigmoid{};
CLPixelWiseMultiplicationKernel _pixelwise_mul_forget_cell{};
CLPixelWiseMultiplicationKernel _pixelwise_mul_input_cell{};
CLSaturatedArithmeticOperationKernel _add_forget_cell{};
@@ -309,6 +339,7 @@ private:
CLGEMMLowpOutputStage _recurrent_to_output_outstage{};
CLSaturatedArithmeticOperationKernel _accumulate_input_recurrent_output{};
CLPixelWiseMultiplicationKernel _pixelwise_mul_cell_to_output{};
+ CLGEMMLowpOutputStage _cell_to_output_outstage{};
CLSaturatedArithmeticOperationKernel _accumulate_cell_to_output{};
CLActivationLayer _output_gate_sigmoid{};
CLActivationLayer _hidden_tanh{};
@@ -321,11 +352,13 @@ private:
std::array<CLQLSTMLayerNormalizationKernel, _layer_norm_count> _layer_norms{ {} };
CLCopyKernel _copy_output{};
+ TensorCopyKernel _projection_bias_copy{};
+ TensorCopyKernel _projection_output_to_accumulate_copy{};
+ TensorCopyKernel _projection_accumulate_to_output_copy{};
+ TensorCopyKernel _hidden_to_output_copy{};
+
// Tensor pointers
- const ICLTensor *_input_to_input_weights
- {
- nullptr
- };
+ const ICLTensor *_input_to_input_weights{ nullptr };
const ICLTensor *_recurrent_to_input_weights{ nullptr };
const ICLTensor *_projection_bias{ nullptr };
const ICLTensor *_input_to_forget_weights{ nullptr };
@@ -435,11 +468,15 @@ private:
CLTensor _input_to_output_outstage_res{ nullptr };
CLTensor _mm_recurrent_to_output_res{ nullptr };
CLTensor _mul_cell_to_output_res{ nullptr };
+ CLTensor _cell_to_output_outstage_res{ nullptr };
CLTensor _recurrent_to_output_outstage_res{ nullptr };
CLTensor _output_gate{ nullptr };
CLTensor _hidden_mul_res{ nullptr };
+ CLTensor _hidden_gate{ nullptr };
CLTensor _mm_projection_res{ nullptr };
CLTensor _projection_outstage_res{ nullptr };
+ CLTensor _projection_out_res{ nullptr };
+ CLTensor _projection_accumulate_res{ nullptr };
CLTensor _ones{ nullptr };
std::array<CLTensor, _layer_norm_count> _layer_norm_output{ {} };
@@ -455,6 +492,7 @@ private:
bool _has_projection_clipping{ false };
bool _has_peephole{ false };
bool _has_layer_norm{ false };
+ bool _projection_tensor_copy_required{ false };
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CLQLSTMLAYER_H */