From 9230e2789e421021804bc3a65cf47df4749b0765 Mon Sep 17 00:00:00 2001
From: Sang-Hoon Park
Date: Sat, 18 Apr 2020 00:46:34 +0100
Subject: COMPMID-3241: Add Layer Normalization to NEQLSTMLayer

- Add output quantization calculation to Layer Normalization
- Add members for Layer Normalization to NEQLSTMLayer
- Add configure/validate/run of Layer Normalization to NEQLSTMLayer

Change-Id: I278c8e0edbb21212f3afa4d4a336df0f1a4c1bfb
Signed-off-by: Sang-Hoon Park
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3059
Tested-by: Arm Jenkins
Reviewed-by: Michele Di Giorgio
Comments-Addressed: Arm Jenkins
---
 .../NEON/kernels/NEQLSTMLayerNormalizationKernel.h |   2 +
 arm_compute/runtime/NEON/functions/NEQLSTMLayer.h  |  76 ++++++++++++-
 .../kernels/NEQLSTMLayerNormalizationKernel.cpp    |  17 ++-
 src/runtime/NEON/functions/NEQLSTMLayer.cpp        | 125 ++++++++++++++++++---
 4 files changed, 198 insertions(+), 22 deletions(-)

diff --git a/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h b/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
index 631de66cc2..f5e8da7feb 100644
--- a/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
+++ b/arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h
@@ -130,6 +130,8 @@ private:
                          const int16_t *weight_ptr, const int32_t *bias_ptr,
                          int32_t mean, int32_t inv_std_mul, int32_t inv_std_shift);
+    /** Function to compute output quantization information */
+    QuantizationInfo compute_output_qinfo();
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_NEQLSTMLAYERNORMALIZATIONKERNEL_H */
diff --git a/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h b/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
index a37909b775..312a8984b5 100644
--- a/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEQLSTMLayer.h
@@ -28,6 +28,7 @@
 #include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
 #include "arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h"
+#include "arm_compute/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
 #include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"
@@ -169,6 +170,16 @@ public:
     void prepare() override;
 
 private:
+    enum class LayerNormGate : uint8_t
+    {
+        Forget,
+        Cell,
+        Input,
+        Output,
+        Count
+    };
+    static constexpr uint8_t _layer_norm_count = static_cast<uint8_t>(LayerNormGate::Count);
+
     /** Internal method to configure matrix multiplication plus output stage of each gate.
      *
      * @param[in] mm Matrix multiplication function to use.
@@ -254,12 +265,10 @@ private:
     NEGEMMLowpOutputStage      _projection_outstage{};
     NEArithmeticAdditionKernel _accumulate_projection{};
     NEActivationLayer          _projection_clip{};
+    std::array<NEQLSTMLayerNormalizationKernel, _layer_norm_count> _layer_norms{};
 
     // Tensor pointers
-    const ITensor *_input_to_input_weights
-    {
-        nullptr
-    };
+    const ITensor *_input_to_input_weights{ nullptr };
     const ITensor *_recurrent_to_input_weights{ nullptr };
     const ITensor *_projection_bias{ nullptr };
     const ITensor *_input_to_forget_weights{ nullptr };
@@ -269,6 +278,58 @@ private:
     const ITensor *_recurrent_to_cell_weights{ nullptr };
     const ITensor *_recurrent_to_output_weights{ nullptr };
     const ITensor *_projection_weights{ nullptr };
+    std::array<const ITensor *, _layer_norm_count> _layer_norm_weights{};
+    std::array<const ITensor *, _layer_norm_count> _layer_norm_bias{};
+
+    using LayerNormIndexType = typename std::underlying_type<LayerNormGate>::type;
+    inline LayerNormIndexType getGateIndex(LayerNormGate g)
+    {
+        return static_cast<LayerNormIndexType>(g);
+    }
+
+    inline void set_layer_norm_weight(const ITensor *t, LayerNormGate g)
+    {
+        _layer_norm_weights[getGateIndex(g)] = t;
+    }
+
+    inline void set_layer_norm_bias(const ITensor *t, LayerNormGate g)
+    {
+        _layer_norm_bias[getGateIndex(g)] = t;
+    }
+
+    inline const ITensor *get_layer_norm_weight(LayerNormGate g)
+    {
+        return _layer_norm_weights[getGateIndex(g)];
+    }
+
+    inline const ITensor *get_layer_norm_bias(LayerNormGate g)
+    {
+        return _layer_norm_bias[getGateIndex(g)];
+    }
+
+    inline NEQLSTMLayerNormalizationKernel &get_layer_norm(LayerNormGate g)
+    {
+        return _layer_norms[getGateIndex(g)];
+    }
+
+    inline void configure_layer_norm(LayerNormGate g, const ITensor *in)
+    {
+        ARM_COMPUTE_ERROR_ON(!_has_layer_norm);
+
+        Tensor &out = get_layer_norm_output(g);
+        _memory_group.manage(&out);
+        out.allocator()->init(*(in->info()));
+
+        get_layer_norm(g).configure(in, &out, get_layer_norm_weight(g), get_layer_norm_bias(g));
+    }
+
+    inline static Status validate_layer_norm(const ITensorInfo &in, const ITensorInfo &weight, const ITensorInfo &bias)
+    {
+        // Output quantization scale will be different, but ignored here
+        // since it will be configured at configure() stage.
+        const TensorInfo out{ in };
+        return NEQLSTMLayerNormalizationKernel::validate(&in, &out, &weight, &bias);
+    }
 
     // Temporary tensors
     Tensor _input_to_forget_weights_transposed{ nullptr };
@@ -320,6 +381,12 @@ private:
     Tensor _mm_projection_res{ nullptr };
     Tensor _projection_outstage_res{ nullptr };
     Tensor _ones{ nullptr };
+    std::array<Tensor, _layer_norm_count> _layer_norm_output{};
+
+    inline Tensor &get_layer_norm_output(LayerNormGate g)
+    {
+        return _layer_norm_output[getGateIndex(g)];
+    }
 
     bool _is_prepared{ false };
     bool _has_cifg{ false };
@@ -327,6 +394,7 @@ private:
     bool _has_projection{ false };
     bool _has_projection_clipping{ false };
     bool _has_peephole{ false };
+    bool _has_layer_norm{ false };
 };
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_NEQLSTMLAYER_H */
diff --git a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
index db2ff85db9..e966c6bdba 100644
--- a/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
+++ b/src/core/NEON/kernels/NEQLSTMLayerNormalizationKernel.cpp
@@ -80,11 +80,9 @@ inline int64x2x2_t mul_add(const int32x4_t &a, const int32x4_t &b, const int32x4
 
 void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *output, const ITensor *weight, const ITensor *bias)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight);
-    ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(),
-                                        output ? output->info() : nullptr,
-                                        weight->info(),
-                                        bias ? bias->info() : nullptr));
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weight, bias, output);
+    ARM_COMPUTE_ERROR_ON(input == output);
+    ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), weight->info(), bias->info()));
 
     static const std::map<DataType, ComputeFuncType> fn_map =
     {
@@ -98,6 +96,7 @@ void NEQLSTMLayerNormalizationKernel::configure(const ITensor *input, ITensor *o
     _fn = fn_map.at(_input->info()->data_type());
 
     auto_init_if_empty(*_output->info(), *_input->info());
+    _output->info()->set_quantization_info(compute_output_qinfo());
 
     const UniformQuantizationInfo wq_info = _weight->info()->quantization_info().uniform();
     const Status                  s       = quantization::calculate_quantized_multiplier(wq_info.scale, &_output_multiplier, &_output_shift);
@@ -171,6 +170,14 @@ void NEQLSTMLayerNormalizationKernel::run(const Window &window, const ThreadInfo
     _fn(*this);
 }
 
+inline QuantizationInfo NEQLSTMLayerNormalizationKernel::compute_output_qinfo()
+{
+    const UniformQuantizationInfo iq_info      = _input->info()->quantization_info().uniform();
+    const UniformQuantizationInfo wq_info      = _weight->info()->quantization_info().uniform();
+    const float                   output_scale = (wq_info.scale * iq_info.scale) * 1024;
+    return QuantizationInfo(output_scale);
+}
+
 inline std::pair<int64_t, int64_t> NEQLSTMLayerNormalizationKernel::sum_qsymm16(const int16_t *input_ptr)
 {
     ARM_COMPUTE_ERROR_ON(!input_ptr);
diff --git a/src/runtime/NEON/functions/NEQLSTMLayer.cpp b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
index b02fab227b..a279bba2ab 100644
--- a/src/runtime/NEON/functions/NEQLSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NEQLSTMLayer.cpp
@@ -79,9 +79,6 @@ void NEQLSTMLayer::configure(const ITensor *input,
                              ITensor *cell_state_out, ITensor *output_state_out,
                              const LSTMParams<ITensor> &lstm_params)
 {
-    ARM_COMPUTE_UNUSED(forget_gate_bias);
-    ARM_COMPUTE_UNUSED(cell_bias);
-    ARM_COMPUTE_UNUSED(output_gate_bias);
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
                                  recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
                                  forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
@@ -112,6 +109,21 @@ void NEQLSTMLayer::configure(const ITensor *input,
     _recurrent_to_output_weights = recurrent_to_output_weights;
     _projection_weights          = lstm_params.projection_weights();
 
+    // Layer normalization
+    _has_layer_norm = lstm_params.use_layer_norm();
+    if(_has_layer_norm)
+    {
+        set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
+        set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
+        set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
+        set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
+
+        set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
+        set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
+        set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
+        set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
+    }
+
     _has_cifg       = lstm_params.has_cifg_opt();
     _has_projection = lstm_params.has_projection();
     _has_peephole   = lstm_params.has_peephole_opt();
@@ -203,14 +215,23 @@ void NEQLSTMLayer::configure(const ITensor *input,
         _cell_to_forget_outstage_res.allocator()->allocate();
     }
 
+    Tensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Forget, forget_activation_input);
+        forget_activation_input->allocator()->allocate();
+        forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
+    }
+
     // Output quantization info of Sigmoid and Tanh activations
     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
+    const TensorInfo       forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
 
-    const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     _memory_group.manage(&_forget_gate);
     _forget_gate.allocator()->init(forget_gate_info);
-    _forget_gate_sigmoid.configure(&_recurrent_to_forget_outstage_res, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
-    _recurrent_to_forget_outstage_res.allocator()->allocate();
+    _forget_gate_sigmoid.configure(forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+    forget_activation_input->allocator()->allocate();
 
     // Modulation gate.
     const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
@@ -229,11 +250,21 @@ void NEQLSTMLayer::configure(const ITensor *input,
     _accumulate_input_recurrent_modulation.configure(&_input_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, &_recurrent_to_cell_outstage_res, ConvertPolicy::SATURATE);
     _input_to_cell_outstage_res.allocator()->allocate();
 
+    Tensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Cell, cell_activation_input);
+        cell_activation_input->allocator()->allocate();
+        cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
+    }
+
     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
     _memory_group.manage(&_cell_gate);
     _cell_gate.allocator()->init(cell_gate_info);
-    _cell_gate_tanh.configure(&_recurrent_to_cell_outstage_res, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
-    _recurrent_to_cell_outstage_res.allocator()->allocate();
+    _cell_gate_tanh.configure(cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+    cell_activation_input->allocator()->allocate();
 
     // Input gate.
     const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
@@ -276,8 +307,17 @@ void NEQLSTMLayer::configure(const ITensor *input,
         _cell_to_input_outstage_res.allocator()->allocate();
     }
 
-        _input_gate_tanh.configure(&_recurrent_to_input_outstage_res, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
-        _recurrent_to_input_outstage_res.allocator()->allocate();
+        Tensor *input_activation_input = &_recurrent_to_input_outstage_res;
+
+        if(_has_layer_norm)
+        {
+            configure_layer_norm(LayerNormGate::Input, input_activation_input);
+            input_activation_input->allocator()->allocate();
+            input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
+        }
+
+        _input_gate_tanh.configure(input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+        input_activation_input->allocator()->allocate();
     }
     // Cell.
     // TODO(COMPMID-3395): Perform multiplication in the quantized domain in NEPixelWiseMultiplicationKernel
@@ -325,11 +365,20 @@ void NEQLSTMLayer::configure(const ITensor *input,
         _mul_cell_to_output_res.allocator()->allocate();
     }
 
+    Tensor *output_activation_input = &_recurrent_to_output_outstage_res;
+
+    if(_has_layer_norm)
+    {
+        configure_layer_norm(LayerNormGate::Output, output_activation_input);
+        output_activation_input->allocator()->allocate();
+        output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
+    }
     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
     _memory_group.manage(&_output_gate);
     _output_gate.allocator()->init(output_gate_info);
-    _output_gate_sigmoid.configure(&_recurrent_to_output_outstage_res, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
-    _recurrent_to_output_outstage_res.allocator()->allocate();
+    _output_gate_sigmoid.configure(output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+    output_activation_input->allocator()->allocate();
 
     // Hidden.
     _hidden_tanh.configure(cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
@@ -505,6 +554,8 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
     gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
     gemmlowp_info.output_data_type   = DataType::QSYMM16;
 
+    const bool has_layer_norm = lstm_params.use_layer_norm();
+
     // Forget gate.
     const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
     const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
@@ -527,10 +578,17 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
     }
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
+        const ITensorInfo *b_info = forget_gate_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
+    }
+
     // Output quantization info of Sigmoid and Tanh activations
     const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
+    const TensorInfo       forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
 
-    const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_outstage_info, &forget_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
 
     // Modulation gate.
@@ -543,7 +601,14 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
     ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
+        const ITensorInfo *b_info = cell_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
+    }
     const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
+
     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
 
     // Input gate.
@@ -582,6 +647,13 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
     }
 
+        if(has_layer_norm)
+        {
+            const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
+            const ITensorInfo *b_info = lstm_params.input_gate_bias();
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
+        }
+
         ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
     }
     // Cell.
@@ -614,6 +686,13 @@ Status NEQLSTMLayer::validate(const ITensorInfo *input,
         ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
     }
 
+    if(has_layer_norm)
+    {
+        const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
+        const ITensorInfo *b_info = output_gate_bias;
+        ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
+    }
+
     const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
     ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
 
@@ -695,6 +774,11 @@ void NEQLSTMLayer::run()
         NEScheduler::get().schedule(&_accumulate_cell_forget, Window::DimY);
     }
 
+    if(_has_layer_norm)
+    {
+        NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Forget), Window::DimY);
+    }
+
     _forget_gate_sigmoid.run();
 
     // Modulation gate.
@@ -705,6 +789,11 @@ void NEQLSTMLayer::run()
     _recurrent_to_cell_outstage.run();
     NEScheduler::get().schedule(&_accumulate_input_recurrent_modulation, Window::DimY);
 
+    if(_has_layer_norm)
+    {
+        NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Cell), Window::DimY);
+    }
+
     _cell_gate_tanh.run();
 
     // Input gate
@@ -727,6 +816,11 @@ void NEQLSTMLayer::run()
             NEScheduler::get().schedule(&_accumulate_cell_input, Window::DimY);
         }
 
+        if(_has_layer_norm)
+        {
+            NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Input), Window::DimY);
+        }
+
         _input_gate_tanh.run();
     }
 
@@ -751,6 +845,11 @@ void NEQLSTMLayer::run()
         NEScheduler::get().schedule(&_accumulate_cell_to_output, Window::DimY);
     }
 
+    if(_has_layer_norm)
+    {
+        NEScheduler::get().schedule(&get_layer_norm(LayerNormGate::Output), Window::DimY);
+    }
+
     _output_gate_sigmoid.run();
 
     // Hidden.
--
cgit v1.2.1
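For reference, the output quantization added in NEQLSTMLayerNormalizationKernel::compute_output_qinfo() above reduces to a single scale product: output_scale = weight_scale * input_scale * 1024. The snippet below is a minimal standalone sketch of that arithmetic with hypothetical scale values; it does not use the library's types and only illustrates the formula visible in the patch.

#include <cstdio>

// Sketch of the layer-norm output scale from compute_output_qinfo():
// output_scale = weight_scale * input_scale * 1024.
// Both scale values below are hypothetical; in the library they come from
// the quantization info of the layer-norm input and weight tensors.
int main()
{
    const float input_scale  = 1.f / 32768.f; // assumed QSYMM16 intermediate scale
    const float weight_scale = 0.002f;        // assumed QSYMM16 layer-norm weight scale

    const float output_scale = (weight_scale * input_scale) * 1024.f;
    std::printf("layer-norm output scale: %.9f\n", output_scale);
    return 0;
}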