From cafec8f19ee126b72ba2b0194bd25a5a93727980 Mon Sep 17 00:00:00 2001 From: John Kesapides Date: Tue, 19 Feb 2019 15:53:59 +0000 Subject: COMPMID-1024 Investigate concatenation for RNN/LSTM OpenCL Change-Id: I0cee0853b82a7e4c487989d4a0890d58ec086045 Signed-off-by: John Kesapides Reviewed-on: https://review.mlplatform.org/c/763 Comments-Addressed: Michele Di Giorgio Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- arm_compute/core/utils/misc/ShapeCalculator.h | 5 ++--- arm_compute/runtime/CL/functions/CLLSTMLayer.h | 14 ++++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) (limited to 'arm_compute') diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 7a34b43028..0d07266403 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1162,12 +1162,11 @@ inline TensorShape extract_shape(T *data) return data->info()->tensor_shape(); } -inline TensorShape extract_shape(const ITensorInfo *data) +inline TensorShape extract_shape(ITensorInfo *data) { return data->tensor_shape(); } - -inline TensorShape extract_shape(ITensorInfo *data) +inline TensorShape extract_shape(const ITensorInfo *data) { return data->tensor_shape(); } diff --git a/arm_compute/runtime/CL/functions/CLLSTMLayer.h b/arm_compute/runtime/CL/functions/CLLSTMLayer.h index 87fb1190b7..a804a4af5b 100644 --- a/arm_compute/runtime/CL/functions/CLLSTMLayer.h +++ b/arm_compute/runtime/CL/functions/CLLSTMLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -30,6 +30,7 @@ #include "arm_compute/core/CL/kernels/CLCopyKernel.h" #include "arm_compute/core/CL/kernels/CLElementwiseOperationKernel.h" #include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h" +#include "arm_compute/core/CL/kernels/CLWidthConcatenate2TensorsKernel.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/CL/CLMemoryGroup.h" #include "arm_compute/runtime/CL/CLTensor.h" @@ -138,6 +139,7 @@ public: // Inherited methods overridden: void run() override; + void prepare() override; private: CLMemoryGroup _memory_group; @@ -182,16 +184,20 @@ private: CLCopyKernel _copy_cell_state; CLCopyKernel _copy_output; CLWidthConcatenateLayer _concat_scratch_buffer; + CLWidthConcatenate2TensorsKernel _concat_inputs_forget_gate; + CLWidthConcatenate2TensorsKernel _concat_weights_forget_gate; + CLWidthConcatenate2TensorsKernel _concat_weights_input_gate; + CLWidthConcatenate2TensorsKernel _concat_weights_output; CLTensor _input_gate_out1; CLTensor _input_gate_out2; CLTensor _input_gate_out3; CLTensor _input_gate_out4; - CLTensor _input_gate_out5; CLTensor _forget_gate_out1; CLTensor _forget_gate_out2; CLTensor _forget_gate_out3; CLTensor _forget_gate_out4; CLTensor _forget_gate_out5; + CLTensor _forget_gate_out6; CLTensor _cell_state_out1; CLTensor _cell_state_out2; CLTensor _cell_state_out3; @@ -201,7 +207,6 @@ private: CLTensor _output2; CLTensor _output3; CLTensor _output4; - CLTensor _output5; CLTensor _cell_state_activation; CLTensor _output_state1; CLTensor _ones; @@ -210,6 +215,7 @@ private: bool _perform_cell_clipping; bool _has_projection_weights; bool _perform_projection_clipping; + bool _is_prepared; }; -} +} // namespace arm_compute #endif /* __ARM_COMPUTE_CLLSTMLAYER_H__ */ -- cgit v1.2.1