From 3a35398ed6cc5d9c0f45f33dabb2bfbb017bcf60 Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Tue, 21 Apr 2020 13:10:24 +0100 Subject: COMPMID-3240: Add support for layer normalization to CLQLSTMLayer Signed-off-by: Sheri Zhang Change-Id: I45359a4ddb46c059097a2d77c008f802e8f4c143 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3065 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Sang-Hoon Park Comments-Addressed: Arm Jenkins --- .../CL/kernels/CLQLSTMLayerNormalizationKernel.h | 2 +- arm_compute/runtime/CL/functions/CLQLSTMLayer.h | 74 +++++++++++++ .../CL/kernels/CLQLSTMLayerNormalizationKernel.cpp | 10 +- .../kernels/NEQLSTMLayerNormalizationKernel.cpp | 5 +- src/runtime/CL/functions/CLQLSTMLayer.cpp | 120 +++++++++++++++++++-- 5 files changed, 193 insertions(+), 18 deletions(-) diff --git a/arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h b/arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h index 1a2f3111f5..2d4707245f 100644 --- a/arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h +++ b/arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h @@ -73,7 +73,7 @@ public: * * @return a status */ - static Status validate(const ITensorInfo *input, ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias); + static Status validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias); // Inherited methods overridden: void run(const Window &window, cl::CommandQueue &queue) override; diff --git a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h index 72a61f8505..722275e269 100644 --- a/arm_compute/runtime/CL/functions/CLQLSTMLayer.h +++ b/arm_compute/runtime/CL/functions/CLQLSTMLayer.h @@ -27,6 +27,7 @@ #include "arm_compute/core/CL/kernels/CLElementwiseOperationKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h" #include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h" +#include "arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/CL/functions/CLActivationLayer.h" #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h" @@ -216,6 +217,16 @@ public: void prepare() override; private: + enum class LayerNormGate : uint8_t + { + Forget, + Cell, + Input, + Output, + Count + }; + static constexpr uint8_t _layer_norm_count = static_cast(LayerNormGate::Count); + /** Internal method to configure matrix multiplication plus output stage of each gate. * * @param[in] compile_context The compile context to be used. @@ -302,6 +313,7 @@ private: CLGEMMLowpOutputStage _projection_outstage{}; CLSaturatedArithmeticOperationKernel _accumulate_projection{}; CLActivationLayer _projection_clip{}; + std::array _layer_norms{ {} }; // Tensor pointers const ICLTensor *_input_to_input_weights @@ -317,6 +329,61 @@ private: const ICLTensor *_recurrent_to_cell_weights{ nullptr }; const ICLTensor *_recurrent_to_output_weights{ nullptr }; const ICLTensor *_projection_weights{ nullptr }; + std::array _layer_norm_weights{ {} }; + std::array _layer_norm_bias{ {} }; + + using LayerNormIndexType = typename std::underlying_type::type; + inline LayerNormIndexType getGateIndex(LayerNormGate g) + { + return static_cast(g); + } + + inline void set_layer_norm_weight(const ICLTensor *t, LayerNormGate g) + { + _layer_norm_weights[getGateIndex(g)] = t; + } + + inline void set_layer_norm_bias(const ICLTensor *t, LayerNormGate g) + { + _layer_norm_bias[getGateIndex(g)] = t; + } + + inline const ICLTensor *get_layer_norm_weight(LayerNormGate g) + { + return _layer_norm_weights[getGateIndex(g)]; + } + + inline const ICLTensor *get_layer_norm_bias(LayerNormGate g) + { + return _layer_norm_bias[getGateIndex(g)]; + } + + inline CLQLSTMLayerNormalizationKernel &get_layer_norm(LayerNormGate g) + { + return _layer_norms[getGateIndex(g)]; + } + + inline void configure_layer_norm(LayerNormGate g, const ICLTensor *in) + { + ARM_COMPUTE_ERROR_ON(!_has_layer_norm); + + CLTensor *out = &get_layer_norm_output(g); + _memory_group.manage(out); + out->allocator()->init(*(in->info())); + + get_layer_norm(g).configure(in, out, get_layer_norm_weight(g), get_layer_norm_bias(g)); + } + + inline static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias) + { + // Output quantization scale will be different, but ignored here + // since it will be configured at configure() stage. + const TensorInfo out + { + in + }; + return CLQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias); + } // Temporary tensors CLTensor _input_to_forget_weights_transposed{ nullptr }; @@ -368,6 +435,12 @@ private: CLTensor _mm_projection_res{ nullptr }; CLTensor _projection_outstage_res{ nullptr }; CLTensor _ones{ nullptr }; + std::array _layer_norm_output{ {} }; + + inline CLTensor &get_layer_norm_output(LayerNormGate g) + { + return _layer_norm_output[getGateIndex(g)]; + } bool _is_prepared{ false }; bool _has_cifg{ false }; @@ -375,6 +448,7 @@ private: bool _has_projection{ false }; bool _has_projection_clipping{ false }; bool _has_peephole{ false }; + bool _has_layer_norm{ false }; }; } // namespace arm_compute #endif /* ARM_COMPUTE_CLQLSTMLAYER_H */ diff --git a/src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.cpp b/src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.cpp index b9767e8ec2..d9da3cb36e 100644 --- a/src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.cpp +++ b/src/core/CL/kernels/CLQLSTMLayerNormalizationKernel.cpp @@ -31,11 +31,17 @@ namespace arm_compute { namespace { +QuantizationInfo compute_output_qinfo() +{ + return QuantizationInfo(1.f / 4096); +} + std::pair validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) { ARM_COMPUTE_ERROR_ON_NULLPTR(input); // Output auto inizialitation if not yet initialized auto_init_if_empty(*output, *input); + output->set_quantization_info(compute_output_qinfo()); const uint32_t temp_num_elems_processed_per_iteration = max_cl_vector_width / input->element_size(); /* If width is less then step, then make step same as width to avoid global size being step instead of actual width. */ @@ -48,7 +54,7 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen return std::make_pair(Status{}, win); } -Status validate_arguments(const ITensorInfo *input, ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias) +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weight, bias, output); @@ -129,7 +135,7 @@ void CLQLSTMLayerNormalizationKernel::configure(const ICLTensor *input, ICLTenso configure(CLKernelLibrary::get().get_compile_context(), input, output, weight, bias); } -Status CLQLSTMLayerNormalizationKernel::validate(const ITensorInfo *input, ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias) +Status CLQLSTMLayerNormalizationKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const ITensorInfo *weight, const ITensorInfo *bias) { ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, weight, bias)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get()).first); diff --git a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp index e966c6bdba..29ffee867b 100644 --- a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp +++ b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp @@ -172,10 +172,7 @@ void NEQLSTMLayerNormalizationKernel::run(const Window &window, const ThreadInfo inline QuantizationInfo NEQLSTMLayerNormalizationKernel::compute_output_qinfo() { - const UniformQuantizationInfo iq_info = _input->info()->quantization_info().uniform(); - const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform(); - const float output_scale = (wq_info.scale * iq_info.scale) * 1024; - return QuantizationInfo(output_scale); + return QuantizationInfo(1.f / 4096); } inline std::pair NEQLSTMLayerNormalizationKernel::sum_qsymm16(const int16_t *input_ptr) diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp index 88c5f77b9f..d9b5c7c64d 100644 --- a/src/runtime/CL/functions/CLQLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp @@ -92,9 +92,6 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT ICLTensor *cell_state_out, ICLTensor *output_state_out, const LSTMParams &lstm_params) { - ARM_COMPUTE_UNUSED(forget_gate_bias); - ARM_COMPUTE_UNUSED(cell_bias); - ARM_COMPUTE_UNUSED(output_gate_bias); ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights, recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights, forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out); @@ -125,6 +122,21 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT _recurrent_to_output_weights = recurrent_to_output_weights; _projection_weights = lstm_params.projection_weights(); + // Layer normalization + _has_layer_norm = lstm_params.use_layer_norm(); + if(_has_layer_norm) + { + set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget); + set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell); + set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input); + set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output); + + set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget); + set_layer_norm_bias(cell_bias, LayerNormGate::Cell); + set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input); + set_layer_norm_bias(output_gate_bias, LayerNormGate::Output); + } + _has_cifg = lstm_params.has_cifg_opt(); _has_projection = lstm_params.has_projection(); _has_peephole = lstm_params.has_peephole_opt(); @@ -218,14 +230,23 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT _cell_to_forget_outstage_res.allocator()->allocate(); } + CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res; + + if(_has_layer_norm) + { + configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res); + _recurrent_to_forget_outstage_res.allocator()->allocate(); + forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget); + } + // Output quantization info of Sigmoid and Tanh activations const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0); const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo); _memory_group.manage(&_forget_gate); _forget_gate.allocator()->init(forget_gate_info); - _forget_gate_sigmoid.configure(compile_context, &_recurrent_to_forget_outstage_res, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); - _recurrent_to_forget_outstage_res.allocator()->allocate(); + _forget_gate_sigmoid.configure(compile_context, forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); + forget_activation_input->allocator()->allocate(); // Modulation gate. const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0)); @@ -245,11 +266,20 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT ConvertPolicy::SATURATE); _input_to_cell_outstage_res.allocator()->allocate(); + CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res; + + if(_has_layer_norm) + { + configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res); + _recurrent_to_cell_outstage_res.allocator()->allocate(); + cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell); + } + const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo); _memory_group.manage(&_cell_gate); _cell_gate.allocator()->init(cell_gate_info); - _cell_gate_tanh.configure(compile_context, &_recurrent_to_cell_outstage_res, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)); - _recurrent_to_cell_outstage_res.allocator()->allocate(); + _cell_gate_tanh.configure(compile_context, cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)); + cell_activation_input->allocator()->allocate(); // Input gate. const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo); @@ -293,8 +323,17 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT _cell_to_input_outstage_res.allocator()->allocate(); } - _input_gate_tanh.configure(compile_context, &_recurrent_to_input_outstage_res, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)); - _recurrent_to_input_outstage_res.allocator()->allocate(); + CLTensor *input_activation_input = &_recurrent_to_input_outstage_res; + + if(_has_layer_norm) + { + configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res); + _recurrent_to_input_outstage_res.allocator()->allocate(); + input_activation_input = &get_layer_norm_output(LayerNormGate::Input); + } + + _input_gate_tanh.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)); + input_activation_input->allocator()->allocate(); } // Cell. // TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel @@ -344,11 +383,20 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT _mul_cell_to_output_res.allocator()->allocate(); } + CLTensor *output_activation_input = &_recurrent_to_output_outstage_res; + + if(_has_layer_norm) + { + configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res); + _recurrent_to_output_outstage_res.allocator()->allocate(); + output_activation_input = &get_layer_norm_output(LayerNormGate::Output); + } + const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo); _memory_group.manage(&_output_gate); _output_gate.allocator()->init(output_gate_info); - _output_gate_sigmoid.configure(compile_context, &_recurrent_to_output_outstage_res, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); - _recurrent_to_output_outstage_res.allocator()->allocate(); + _output_gate_sigmoid.configure(compile_context, output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); + output_activation_input->allocator()->allocate(); // Hidden. _hidden_tanh.configure(compile_context, cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)); @@ -525,6 +573,8 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, gemmlowp_info.gemmlowp_max_bound = std::numeric_limits::max(); gemmlowp_info.output_data_type = DataType::QSYMM16; + const bool has_layer_norm = lstm_params.use_layer_norm(); + // Forget gate. const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0)); const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32); @@ -547,6 +597,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE)); } + if(has_layer_norm) + { + const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights(); + const ITensorInfo *b_info = forget_gate_bias; + ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info)); + } + // Output quantization info of Sigmoid and Tanh activations const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0); @@ -563,6 +620,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE)); + if(has_layer_norm) + { + const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights(); + const ITensorInfo *b_info = cell_bias; + ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info)); + } + const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo); ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f))); @@ -602,6 +666,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE)); } + if(has_layer_norm) + { + const ITensorInfo *w_info = lstm_params.input_layer_norm_weights(); + const ITensorInfo *b_info = lstm_params.input_gate_bias(); + ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info)); + } + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f))); } // Cell. @@ -634,6 +705,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE)); } + if(has_layer_norm) + { + const ITensorInfo *w_info = lstm_params.output_layer_norm_weights(); + const ITensorInfo *b_info = output_gate_bias; + ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info)); + } + const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo); ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); @@ -715,6 +793,11 @@ void CLQLSTMLayer::run() CLScheduler::get().enqueue(_accumulate_cell_forget); } + if(_has_layer_norm) + { + CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Forget)); + } + _forget_gate_sigmoid.run(); // Modulation gate. @@ -725,6 +808,11 @@ void CLQLSTMLayer::run() _recurrent_to_cell_outstage.run(); CLScheduler::get().enqueue(_accumulate_input_recurrent_modulation); + if(_has_layer_norm) + { + CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Cell)); + } + _cell_gate_tanh.run(); // Input gate @@ -747,6 +835,11 @@ void CLQLSTMLayer::run() CLScheduler::get().enqueue(_accumulate_cell_input); } + if(_has_layer_norm) + { + CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input)); + } + _input_gate_tanh.run(); } @@ -771,6 +864,11 @@ void CLQLSTMLayer::run() CLScheduler::get().enqueue(_accumulate_cell_to_output); } + if(_has_layer_norm) + { + CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Output)); + } + _output_gate_sigmoid.run(); // Hidden. -- cgit v1.2.1