aboutsummaryrefslogtreecommitdiff
path: root/arm_compute
diff options
context:
space:
mode:
authorJohn Kesapides <john.kesapides@arm.com>2019-02-19 15:53:59 +0000
committerJohn Kesapides <john.kesapides@arm.com>2019-04-02 11:35:52 +0000
commitcafec8f19ee126b72ba2b0194bd25a5a93727980 (patch)
treefcde0a28d1554e024d8029c81a1c028d3f0444ce /arm_compute
parent108a95e046dde880075b6c278b44033d13f55be3 (diff)
downloadComputeLibrary-cafec8f19ee126b72ba2b0194bd25a5a93727980.tar.gz
COMPMID-1024 Investigate concatenation for RNN/LSTM OpenCL
Change-Id: I0cee0853b82a7e4c487989d4a0890d58ec086045 Signed-off-by: John Kesapides <john.kesapides@arm.com> Reviewed-on: https://review.mlplatform.org/c/763 Comments-Addressed: Michele Di Giorgio <michele.digiorgio@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h5
-rw-r--r--arm_compute/runtime/CL/functions/CLLSTMLayer.h14
2 files changed, 12 insertions, 7 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 7a34b43028..0d07266403 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1162,12 +1162,11 @@ inline TensorShape extract_shape(T *data)
return data->info()->tensor_shape();
}
-inline TensorShape extract_shape(const ITensorInfo *data)
+inline TensorShape extract_shape(ITensorInfo *data)
{
return data->tensor_shape();
}
-
-inline TensorShape extract_shape(ITensorInfo *data)
+inline TensorShape extract_shape(const ITensorInfo *data)
{
return data->tensor_shape();
}
diff --git a/arm_compute/runtime/CL/functions/CLLSTMLayer.h b/arm_compute/runtime/CL/functions/CLLSTMLayer.h
index 87fb1190b7..a804a4af5b 100644
--- a/arm_compute/runtime/CL/functions/CLLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLLSTMLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,6 +30,7 @@
#include "arm_compute/core/CL/kernels/CLCopyKernel.h"
#include "arm_compute/core/CL/kernels/CLElementwiseOperationKernel.h"
#include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h"
+#include "arm_compute/core/CL/kernels/CLWidthConcatenate2TensorsKernel.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLMemoryGroup.h"
#include "arm_compute/runtime/CL/CLTensor.h"
@@ -138,6 +139,7 @@ public:
// Inherited methods overridden:
void run() override;
+ void prepare() override;
private:
CLMemoryGroup _memory_group;
@@ -182,16 +184,20 @@ private:
CLCopyKernel _copy_cell_state;
CLCopyKernel _copy_output;
CLWidthConcatenateLayer _concat_scratch_buffer;
+ CLWidthConcatenate2TensorsKernel _concat_inputs_forget_gate;
+ CLWidthConcatenate2TensorsKernel _concat_weights_forget_gate;
+ CLWidthConcatenate2TensorsKernel _concat_weights_input_gate;
+ CLWidthConcatenate2TensorsKernel _concat_weights_output;
CLTensor _input_gate_out1;
CLTensor _input_gate_out2;
CLTensor _input_gate_out3;
CLTensor _input_gate_out4;
- CLTensor _input_gate_out5;
CLTensor _forget_gate_out1;
CLTensor _forget_gate_out2;
CLTensor _forget_gate_out3;
CLTensor _forget_gate_out4;
CLTensor _forget_gate_out5;
+ CLTensor _forget_gate_out6;
CLTensor _cell_state_out1;
CLTensor _cell_state_out2;
CLTensor _cell_state_out3;
@@ -201,7 +207,6 @@ private:
CLTensor _output2;
CLTensor _output3;
CLTensor _output4;
- CLTensor _output5;
CLTensor _cell_state_activation;
CLTensor _output_state1;
CLTensor _ones;
@@ -210,6 +215,7 @@ private:
bool _perform_cell_clipping;
bool _has_projection_weights;
bool _perform_projection_clipping;
+ bool _is_prepared;
};
-}
+} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLLSTMLAYER_H__ */