aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
diff options
context:
space:
mode:
authorSheri Zhang <sheri.zhang@arm.com>2020-04-21 13:10:24 +0100
committerSheri Zhang <sheri.zhang@arm.com>2020-04-26 21:31:26 +0000
commit3a35398ed6cc5d9c0f45f33dabb2bfbb017bcf60 (patch)
treeaaa4a8949e288157ab92404d1745214529c0c69b /arm_compute/runtime/CL/functions/CLQLSTMLayer.h
parent31b49caa2ca9308a5ba62a598afc9d1982b4af18 (diff)
downloadComputeLibrary-3a35398ed6cc5d9c0f45f33dabb2bfbb017bcf60.tar.gz
COMPMID-3240: Add support for layer normalization to CLQLSTMLayer
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com> Change-Id: I45359a4ddb46c059097a2d77c008f802e8f4c143 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3065 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLQLSTMLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLQLSTMLayer.h74
1 files changed, 74 insertions, 0 deletions
diff --git a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
index 72a61f8505..722275e269 100644
--- a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
+++ b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h
@@ -27,6 +27,7 @@
#include "arm_compute/core/CL/kernels/CLElementwiseOperationKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h"
#include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h"
+#include "arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
#include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
@@ -216,6 +217,16 @@ public:
void prepare() override;
private:
+ enum class LayerNormGate : uint8_t
+ {
+ Forget,
+ Cell,
+ Input,
+ Output,
+ Count
+ };
+ static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+
/** Internal method to configure matrix multiplication plus output stage of each gate.
*
* @param[in] compile_context The compile context to be used.
@@ -302,6 +313,7 @@ private:
CLGEMMLowpOutputStage _projection_outstage{};
CLSaturatedArithmeticOperationKernel _accumulate_projection{};
CLActivationLayer _projection_clip{};
+ std::array<CLQLSTMLayerNormalizationKernel, _layer_norm_count> _layer_norms{ {} };
// Tensor pointers
const ICLTensor *_input_to_input_weights
@@ -317,6 +329,61 @@ private:
const ICLTensor *_recurrent_to_cell_weights{ nullptr };
const ICLTensor *_recurrent_to_output_weights{ nullptr };
const ICLTensor *_projection_weights{ nullptr };
+ std::array<const ICLTensor *, _layer_norm_count> _layer_norm_weights{ {} };
+ std::array<const ICLTensor *, _layer_norm_count> _layer_norm_bias{ {} };
+
+ using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type;
+ inline LayerNormIndexType getGateIndex(LayerNormGate g)
+ {
+ return static_cast<LayerNormIndexType>(g);
+ }
+
+ inline void set_layer_norm_weight(const ICLTensor *t, LayerNormGate g)
+ {
+ _layer_norm_weights[getGateIndex(g)] = t;
+ }
+
+ inline void set_layer_norm_bias(const ICLTensor *t, LayerNormGate g)
+ {
+ _layer_norm_bias[getGateIndex(g)] = t;
+ }
+
+ inline const ICLTensor *get_layer_norm_weight(LayerNormGate g)
+ {
+ return _layer_norm_weights[getGateIndex(g)];
+ }
+
+ inline const ICLTensor *get_layer_norm_bias(LayerNormGate g)
+ {
+ return _layer_norm_bias[getGateIndex(g)];
+ }
+
+ inline CLQLSTMLayerNormalizationKernel &get_layer_norm(LayerNormGate g)
+ {
+ return _layer_norms[getGateIndex(g)];
+ }
+
+ inline void configure_layer_norm(LayerNormGate g, const ICLTensor *in)
+ {
+ ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
+
+ CLTensor *out = &get_layer_norm_output(g);
+ _memory_group.manage(out);
+ out->allocator()->init(*(in->info()));
+
+ get_layer_norm(g).configure(in, out, get_layer_norm_weight(g), get_layer_norm_bias(g));
+ }
+
+ inline static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
+ {
+ // Output quantization scale will be different, but ignored here
+ // since it will be configured at configure() stage.
+ const TensorInfo out
+ {
+ in
+ };
+ return CLQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
+ }
// Temporary tensors
CLTensor _input_to_forget_weights_transposed{ nullptr };
@@ -368,6 +435,12 @@ private:
CLTensor _mm_projection_res{ nullptr };
CLTensor _projection_outstage_res{ nullptr };
CLTensor _ones{ nullptr };
+ std::array<CLTensor, _layer_norm_count> _layer_norm_output{ {} };
+
+ inline CLTensor &get_layer_norm_output(LayerNormGate g)
+ {
+ return _layer_norm_output[getGateIndex(g)];
+ }
bool _is_prepared{ false };
bool _has_cifg{ false };
@@ -375,6 +448,7 @@ private:
bool _has_projection{ false };
bool _has_projection_clipping{ false };
bool _has_peephole{ false };
+ bool _has_layer_norm{ false };
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CLQLSTMLAYER_H */