From 3a35398ed6cc5d9c0f45f33dabb2bfbb017bcf60 Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Tue, 21 Apr 2020 13:10:24 +0100 Subject: COMPMID-3240: Add support for layer normalization to CLQLSTMLayer Signed-off-by: Sheri Zhang Change-Id: I45359a4ddb46c059097a2d77c008f802e8f4c143 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3065 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Sang-Hoon Park Comments-Addressed: Arm Jenkins --- arm_compute/runtime/CL/functions/CLQLSTMLayer.h | 74 +++++++++++++++++++++++++ 1 file changed, 74 insertions(+) (limited to 'arm_compute/runtime/CL/functions/CLQLSTMLayer.h') diff --git a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h index 72a61f8505..722275e269 100644 --- a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h +++ b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h @@ -27,6 +27,7 @@ #include "arm_compute/core/CL/kernels/CLElementwiseOperationKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h" #include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h" +#include "arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/CL/functions/CLActivationLayer.h" #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h" @@ -216,6 +217,16 @@ public: void prepare() override; private: + enum class LayerNormGate : uint8_t + { + Forget, + Cell, + Input, + Output, + Count + }; + static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count); + /** Internal method to configure matrix multiplication plus output stage of each gate. * * @param[in] compile_context The compile context to be used. 
@@ -302,6 +313,7 @@ private: CLGEMMLowpOutputStage _projection_outstage{}; CLSaturatedArithmeticOperationKernel _accumulate_projection{}; CLActivationLayer _projection_clip{}; + std::array<CLQLSTMLayerNormalizationKernel, _layer_norm_count> _layer_norms{ {} }; // Tensor pointers const ICLTensor *_input_to_input_weights @@ -317,6 +329,61 @@ private: const ICLTensor *_recurrent_to_cell_weights{ nullptr }; const ICLTensor *_recurrent_to_output_weights{ nullptr }; const ICLTensor *_projection_weights{ nullptr }; + std::array<const ICLTensor *, _layer_norm_count> _layer_norm_weights{ {} }; + std::array<const ICLTensor *, _layer_norm_count> _layer_norm_bias{ {} }; + + using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type; + inline LayerNormIndexType getGateIndex(LayerNormGate g) + { + return static_cast<LayerNormIndexType>(g); + } + + inline void set_layer_norm_weight(const ICLTensor *t, LayerNormGate g) + { + _layer_norm_weights[getGateIndex(g)] = t; + } + + inline void set_layer_norm_bias(const ICLTensor *t, LayerNormGate g) + { + _layer_norm_bias[getGateIndex(g)] = t; + } + + inline const ICLTensor *get_layer_norm_weight(LayerNormGate g) + { + return _layer_norm_weights[getGateIndex(g)]; + } + + inline const ICLTensor *get_layer_norm_bias(LayerNormGate g) + { + return _layer_norm_bias[getGateIndex(g)]; + } + + inline CLQLSTMLayerNormalizationKernel &get_layer_norm(LayerNormGate g) + { + return _layer_norms[getGateIndex(g)]; + } + + inline void configure_layer_norm(LayerNormGate g, const ICLTensor *in) + { + ARM_COMPUTE_ERROR_ON(!_has_layer_norm); + + CLTensor *out = &get_layer_norm_output(g); + _memory_group.manage(out); + out->allocator()->init(*(in->info())); + + get_layer_norm(g).configure(in, out, get_layer_norm_weight(g), get_layer_norm_bias(g)); + } + + inline static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias) + { + // Output quantization scale will be different, but ignored here + // since it will be configured at configure() stage. 
+ const TensorInfo out + { + in + }; + return CLQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias); + } // Temporary tensors CLTensor _input_to_forget_weights_transposed{ nullptr }; @@ -368,6 +435,12 @@ private: CLTensor _mm_projection_res{ nullptr }; CLTensor _projection_outstage_res{ nullptr }; CLTensor _ones{ nullptr }; + std::array<CLTensor, _layer_norm_count> _layer_norm_output{ {} }; + + inline CLTensor &get_layer_norm_output(LayerNormGate g) + { + return _layer_norm_output[getGateIndex(g)]; + } bool _is_prepared{ false }; bool _has_cifg{ false }; @@ -375,6 +448,7 @@ private: bool _has_projection{ false }; bool _has_projection_clipping{ false }; bool _has_peephole{ false }; + bool _has_layer_norm{ false }; }; } // namespace arm_compute #endif /* ARM_COMPUTE_CLQLSTMLAYER_H */ -- cgit v1.2.1