aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLLSTMLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLLSTMLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLLSTMLayer.h14
1 files changed, 10 insertions, 4 deletions
diff --git a/arm_compute/runtime/CL/functions/CLLSTMLayer.h b/arm_compute/runtime/CL/functions/CLLSTMLayer.h
index 87fb1190b7..a804a4af5b 100644
--- a/arm_compute/runtime/CL/functions/CLLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLLSTMLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,6 +30,7 @@
#include "arm_compute/core/CL/kernels/CLCopyKernel.h"
#include "arm_compute/core/CL/kernels/CLElementwiseOperationKernel.h"
#include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h"
+#include "arm_compute/core/CL/kernels/CLWidthConcatenate2TensorsKernel.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLMemoryGroup.h"
#include "arm_compute/runtime/CL/CLTensor.h"
@@ -138,6 +139,7 @@ public:
// Inherited methods overridden:
void run() override;
+ void prepare() override;
private:
CLMemoryGroup _memory_group;
@@ -182,16 +184,20 @@ private:
CLCopyKernel _copy_cell_state;
CLCopyKernel _copy_output;
CLWidthConcatenateLayer _concat_scratch_buffer;
+ CLWidthConcatenate2TensorsKernel _concat_inputs_forget_gate;
+ CLWidthConcatenate2TensorsKernel _concat_weights_forget_gate;
+ CLWidthConcatenate2TensorsKernel _concat_weights_input_gate;
+ CLWidthConcatenate2TensorsKernel _concat_weights_output;
CLTensor _input_gate_out1;
CLTensor _input_gate_out2;
CLTensor _input_gate_out3;
CLTensor _input_gate_out4;
- CLTensor _input_gate_out5;
CLTensor _forget_gate_out1;
CLTensor _forget_gate_out2;
CLTensor _forget_gate_out3;
CLTensor _forget_gate_out4;
CLTensor _forget_gate_out5;
+ CLTensor _forget_gate_out6;
CLTensor _cell_state_out1;
CLTensor _cell_state_out2;
CLTensor _cell_state_out3;
@@ -201,7 +207,6 @@ private:
CLTensor _output2;
CLTensor _output3;
CLTensor _output4;
- CLTensor _output5;
CLTensor _cell_state_activation;
CLTensor _output_state1;
CLTensor _ones;
@@ -210,6 +215,7 @@ private:
bool _perform_cell_clipping;
bool _has_projection_weights;
bool _perform_projection_clipping;
+ bool _is_prepared;
};
-}
+} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLLSTMLAYER_H__ */