path: root/src/runtime/CL/functions/CLQLSTMLayer.cpp
author     Sheri Zhang <sheri.zhang@arm.com>  2020-04-21 13:10:24 +0100
committer  Sheri Zhang <sheri.zhang@arm.com>  2020-04-26 21:31:26 +0000
commit     3a35398ed6cc5d9c0f45f33dabb2bfbb017bcf60 (patch)
tree       aaa4a8949e288157ab92404d1745214529c0c69b /src/runtime/CL/functions/CLQLSTMLayer.cpp
parent     31b49caa2ca9308a5ba62a598afc9d1982b4af18 (diff)
download   ComputeLibrary-3a35398ed6cc5d9c0f45f33dabb2bfbb017bcf60.tar.gz
COMPMID-3240: Add support for layer normalization to CLQLSTMLayer
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I45359a4ddb46c059097a2d77c008f802e8f4c143
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3065
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLQLSTMLayer.cpp')
-rw-r--r--  src/runtime/CL/functions/CLQLSTMLayer.cpp  120
1 file changed, 109 insertions(+), 11 deletions(-)
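
The diff below applies the same pattern to all four gates: when lstm_params.use_layer_norm() is set, each gate's out-stage result is routed through a per-gate layer normalization step before the sigmoid/tanh activation. As a conceptual reference only, the following sketch shows what layer normalization computes per row in the float domain; the kernel configured by this commit operates on QSYMM16 data in the quantized domain, and gamma/beta here are illustrative stand-ins for the per-gate norm weights and gate bias, not ComputeLibrary names.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Float-domain reference for per-row layer normalization:
// y[i] = (x[i] - mean(x)) / sqrt(var(x) + eps) * gamma[i] + beta[i]
std::vector<float> layer_norm(const std::vector<float> &x,
                              const std::vector<float> &gamma,
                              const std::vector<float> &beta,
                              float eps = 1e-8f)
{
    float mean = 0.f;
    for (float v : x) mean += v;
    mean /= static_cast<float>(x.size());

    float var = 0.f;
    for (float v : x) var += (v - mean) * (v - mean);
    var /= static_cast<float>(x.size());

    const float inv_std = 1.f / std::sqrt(var + eps);
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
    {
        y[i] = (x[i] - mean) * inv_std * gamma[i] + beta[i];
    }
    return y;
}

int main()
{
    const std::vector<float> x{1.f, 2.f, 3.f, 4.f};
    const std::vector<float> gamma(4, 1.f);
    const std::vector<float> beta(4, 0.f);
    for (float v : layer_norm(x, gamma, beta))
    {
        std::printf("%f ", v);
    }
    std::printf("\n");
    return 0;
}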
diff --git a/src/runtime/CL/functions/CLQLSTMLayer.cpp b/src/runtime/CL/functions/CLQLSTMLayer.cpp
index 88c5f77b9f..d9b5c7c64d 100644
--- a/src/runtime/CL/functions/CLQLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLQLSTMLayer.cpp
@@ -92,9 +92,6 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
ICLTensor *cell_state_out, ICLTensor *output_state_out,
const LSTMParams<ICLTensor> &lstm_params)
{
- ARM_COMPUTE_UNUSED(forget_gate_bias);
- ARM_COMPUTE_UNUSED(cell_bias);
- ARM_COMPUTE_UNUSED(output_gate_bias);
ARM_COMPUTE_ERROR_ON_NULLPTR(input, input_to_forget_weights, input_to_cell_weights, input_to_output_weights,
recurrent_to_forget_weights, recurrent_to_cell_weights, recurrent_to_output_weights,
forget_gate_bias, cell_bias, output_gate_bias, cell_state_in, output_state_in, cell_state_out, output_state_out);
@@ -125,6 +122,21 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_recurrent_to_output_weights = recurrent_to_output_weights;
_projection_weights = lstm_params.projection_weights();
+ // Layer normalization
+ _has_layer_norm = lstm_params.use_layer_norm();
+ if(_has_layer_norm)
+ {
+ set_layer_norm_weight(lstm_params.forget_layer_norm_weights(), LayerNormGate::Forget);
+ set_layer_norm_weight(lstm_params.cell_layer_norm_weights(), LayerNormGate::Cell);
+ set_layer_norm_weight(lstm_params.input_layer_norm_weights(), LayerNormGate::Input);
+ set_layer_norm_weight(lstm_params.output_layer_norm_weights(), LayerNormGate::Output);
+
+ set_layer_norm_bias(forget_gate_bias, LayerNormGate::Forget);
+ set_layer_norm_bias(cell_bias, LayerNormGate::Cell);
+ set_layer_norm_bias(lstm_params.input_gate_bias(), LayerNormGate::Input);
+ set_layer_norm_bias(output_gate_bias, LayerNormGate::Output);
+ }
+
_has_cifg = lstm_params.has_cifg_opt();
_has_projection = lstm_params.has_projection();
_has_peephole = lstm_params.has_peephole_opt();
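
The set_layer_norm_weight() and set_layer_norm_bias() helpers used above are declared in CLQLSTMLayer.h, which is outside this diff. A plausible minimal model, assuming they simply stash per-gate pointers in arrays indexed by the LayerNormGate enum, is sketched below with stand-in types; the real ICLTensor and gate enum live in the ComputeLibrary headers.

#include <array>
#include <cstddef>

struct ICLTensor  // stand-in for arm_compute::ICLTensor
{
};

enum class LayerNormGate : std::size_t
{
    Forget,
    Cell,
    Input,
    Output,
    Count
};

class LayerNormStore  // hypothetical model of the member storage
{
public:
    void set_layer_norm_weight(const ICLTensor *weight, LayerNormGate gate)
    {
        _weights[static_cast<std::size_t>(gate)] = weight;
    }
    void set_layer_norm_bias(const ICLTensor *bias, LayerNormGate gate)
    {
        _biases[static_cast<std::size_t>(gate)] = bias;
    }

private:
    static constexpr std::size_t num_gates = static_cast<std::size_t>(LayerNormGate::Count);
    std::array<const ICLTensor *, num_gates> _weights{};
    std::array<const ICLTensor *, num_gates> _biases{};
};

int main()
{
    ICLTensor forget_weights, forget_bias;
    LayerNormStore store;
    store.set_layer_norm_weight(&forget_weights, LayerNormGate::Forget);
    store.set_layer_norm_bias(&forget_bias, LayerNormGate::Forget);
    return 0;
}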
@@ -218,14 +230,23 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_cell_to_forget_outstage_res.allocator()->allocate();
}
+ CLTensor *forget_activation_input = &_recurrent_to_forget_outstage_res;
+
+ if(_has_layer_norm)
+ {
+ configure_layer_norm(LayerNormGate::Forget, &_recurrent_to_forget_outstage_res);
+ _recurrent_to_forget_outstage_res.allocator()->allocate();
+ forget_activation_input = &get_layer_norm_output(LayerNormGate::Forget);
+ }
+
// Output quantization info of Sigmoid and Tanh activations
const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
const TensorInfo forget_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
_memory_group.manage(&_forget_gate);
_forget_gate.allocator()->init(forget_gate_info);
- _forget_gate_sigmoid.configure(compile_context, &_recurrent_to_forget_outstage_res, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
- _recurrent_to_forget_outstage_res.allocator()->allocate();
+ _forget_gate_sigmoid.configure(compile_context, forget_activation_input, &_forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ forget_activation_input->allocator()->allocate();
// Modulation gate.
const TensorInfo cell_outstage_info(mm_out_info.tensor_shape(), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.cell_intermediate_scale(), 0));
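
The rewiring in the forget-gate hunk above is easy to miss: previously the out-stage tensor was allocated right after the sigmoid was configured on it; now, when layer norm is enabled, it is allocated as soon as the layer-norm step (its new last consumer) is configured, and the activation consumes the layer-norm output instead. A self-contained sketch of that ordering, with stand-in Tensor/Kernel types rather than the ComputeLibrary classes:

#include <cstdio>
#include <string>

struct Tensor  // stand-in; prints instead of managing CL memory
{
    std::string name;
    void allocate() { std::printf("allocate %s\n", name.c_str()); }
};

struct Kernel  // stand-in; records the dataflow set up by configure()
{
    void configure(const Tensor *in, Tensor *out)
    {
        std::printf("configure %s -> %s\n", in->name.c_str(), out->name.c_str());
    }
};

int main()
{
    Tensor outstage{"recurrent_to_forget_outstage_res"};
    Tensor norm_out{"layer_norm_out_forget"};
    Tensor gate{"forget_gate"};
    Kernel layer_norm;
    Kernel sigmoid;

    const bool has_layer_norm  = true;
    Tensor    *activation_input = &outstage;
    if (has_layer_norm)
    {
        layer_norm.configure(&outstage, &norm_out);
        outstage.allocate();          // last consumer (the norm) is now configured
        activation_input = &norm_out; // the sigmoid reads the normalized tensor
    }
    sigmoid.configure(activation_input, &gate);
    activation_input->allocate();     // whichever tensor actually feeds the sigmoid
    return 0;
}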
@@ -245,11 +266,20 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
ConvertPolicy::SATURATE);
_input_to_cell_outstage_res.allocator()->allocate();
+ CLTensor *cell_activation_input = &_recurrent_to_cell_outstage_res;
+
+ if(_has_layer_norm)
+ {
+ configure_layer_norm(LayerNormGate::Cell, &_recurrent_to_cell_outstage_res);
+ _recurrent_to_cell_outstage_res.allocator()->allocate();
+ cell_activation_input = &get_layer_norm_output(LayerNormGate::Cell);
+ }
+
const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
_memory_group.manage(&_cell_gate);
_cell_gate.allocator()->init(cell_gate_info);
- _cell_gate_tanh.configure(compile_context, &_recurrent_to_cell_outstage_res, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
- _recurrent_to_cell_outstage_res.allocator()->allocate();
+ _cell_gate_tanh.configure(compile_context, cell_activation_input, &_cell_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+ cell_activation_input->allocator()->allocate();
// Input gate.
const TensorInfo input_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
@@ -293,8 +323,17 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_cell_to_input_outstage_res.allocator()->allocate();
}
- _input_gate_tanh.configure(compile_context, &_recurrent_to_input_outstage_res, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
- _recurrent_to_input_outstage_res.allocator()->allocate();
+ CLTensor *input_activation_input = &_recurrent_to_input_outstage_res;
+
+ if(_has_layer_norm)
+ {
+ configure_layer_norm(LayerNormGate::Input, &_recurrent_to_input_outstage_res);
+ _recurrent_to_input_outstage_res.allocator()->allocate();
+ input_activation_input = &get_layer_norm_output(LayerNormGate::Input);
+ }
+
+ _input_gate_tanh.configure(compile_context, input_activation_input, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
+ input_activation_input->allocator()->allocate();
}
// Cell.
// TODO(COMPMID-3396): Perform multiplication in the quantized domain in CLPixelWiseMultiplicationKernel
@@ -344,11 +383,20 @@ void CLQLSTMLayer::configure(const CLCompileContext &compile_context, const ICLT
_mul_cell_to_output_res.allocator()->allocate();
}
+ CLTensor *output_activation_input = &_recurrent_to_output_outstage_res;
+
+ if(_has_layer_norm)
+ {
+ configure_layer_norm(LayerNormGate::Output, &_recurrent_to_output_outstage_res);
+ _recurrent_to_output_outstage_res.allocator()->allocate();
+ output_activation_input = &get_layer_norm_output(LayerNormGate::Output);
+ }
+
const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
_memory_group.manage(&_output_gate);
_output_gate.allocator()->init(output_gate_info);
- _output_gate_sigmoid.configure(compile_context, &_recurrent_to_output_outstage_res, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
- _recurrent_to_output_outstage_res.allocator()->allocate();
+ _output_gate_sigmoid.configure(compile_context, output_activation_input, &_output_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ output_activation_input->allocator()->allocate();
// Hidden.
_hidden_tanh.configure(compile_context, cell_state_out, &_input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f));
@@ -525,6 +573,8 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
gemmlowp_info.gemmlowp_max_bound = std::numeric_limits<int16_t>::max();
gemmlowp_info.output_data_type = DataType::QSYMM16;
+ const bool has_layer_norm = lstm_params.use_layer_norm();
+
// Forget gate.
const TensorInfo forget_outstage_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, QuantizationInfo(lstm_params.forget_intermediate_scale(), 0));
const TensorInfo mm_out_info(TensorShape(num_units, batch_size), 1, DataType::S32);
@@ -547,6 +597,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &forget_outstage_info, &forget_outstage_info, &forget_outstage_info, ConvertPolicy::SATURATE));
}
+ if(has_layer_norm)
+ {
+ const ITensorInfo *w_info = lstm_params.forget_layer_norm_weights();
+ const ITensorInfo *b_info = forget_gate_bias;
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(forget_outstage_info, *w_info, *b_info));
+ }
+
// Output quantization info of Sigmoid and Tanh activations
const QuantizationInfo sigmoid_tanh_outqinfo(1.f / 32768.f, 0);
@@ -563,6 +620,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &cell_outstage_info, &cell_outstage_info, &cell_outstage_info, ConvertPolicy::SATURATE));
+ if(has_layer_norm)
+ {
+ const ITensorInfo *w_info = lstm_params.cell_layer_norm_weights();
+ const ITensorInfo *b_info = cell_bias;
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(cell_outstage_info, *w_info, *b_info));
+ }
+
const TensorInfo cell_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_outstage_info, &cell_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
@@ -602,6 +666,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &input_outstage_info, &input_outstage_info, &input_outstage_info, ConvertPolicy::SATURATE));
}
+ if(has_layer_norm)
+ {
+ const ITensorInfo *w_info = lstm_params.input_layer_norm_weights();
+ const ITensorInfo *b_info = lstm_params.input_gate_bias();
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(input_outstage_info, *w_info, *b_info));
+ }
+
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_outstage_info, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::TANH, 1.f, 1.f)));
}
// Cell.
@@ -634,6 +705,13 @@ Status CLQLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, &output_outstage_info, &output_outstage_info, &output_outstage_info, ConvertPolicy::SATURATE));
}
+ if(has_layer_norm)
+ {
+ const ITensorInfo *w_info = lstm_params.output_layer_norm_weights();
+ const ITensorInfo *b_info = output_gate_bias;
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_layer_norm(output_outstage_info, *w_info, *b_info));
+ }
+
const TensorInfo output_gate_info(TensorShape(num_units, batch_size), 1, DataType::QSYMM16, sigmoid_tanh_outqinfo);
ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_outstage_info, &output_gate_info, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)));
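
Each of the validate() hunks above guards its new per-gate check with has_layer_norm and propagates failures through ARM_COMPUTE_RETURN_ON_ERROR. A minimal model of that early-return idiom, with a stand-in Status type and a stub in place of the real validate_layer_norm():

#include <cstdio>

struct Status  // stand-in for arm_compute::Status
{
    bool        ok;
    const char *msg;
};

// Behaves like ARM_COMPUTE_RETURN_ON_ERROR: bail out on the first failure.
#define RETURN_ON_ERROR(expr)     \
    do                            \
    {                             \
        const Status s_ = (expr); \
        if (!s_.ok)               \
        {                         \
            return s_;            \
        }                         \
    } while (false)

// Stub for validate_layer_norm(), which checks the gate's out-stage info
// against the per-gate norm weights and gate bias.
Status validate_layer_norm_stub(bool has_weights, bool has_bias)
{
    if (!has_weights)
    {
        return {false, "missing layer norm weights"};
    }
    if (!has_bias)
    {
        return {false, "missing layer norm bias"};
    }
    return {true, ""};
}

Status validate_gate(bool has_layer_norm)
{
    if (has_layer_norm)
    {
        RETURN_ON_ERROR(validate_layer_norm_stub(true, true));
    }
    return {true, ""};
}

int main()
{
    const Status status = validate_gate(true);
    std::printf("valid=%d msg=%s\n", status.ok ? 1 : 0, status.msg);
    return 0;
}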
@@ -715,6 +793,11 @@ void CLQLSTMLayer::run()
CLScheduler::get().enqueue(_accumulate_cell_forget);
}
+ if(_has_layer_norm)
+ {
+ CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Forget));
+ }
+
_forget_gate_sigmoid.run();
// Modulation gate.
@@ -725,6 +808,11 @@ void CLQLSTMLayer::run()
_recurrent_to_cell_outstage.run();
CLScheduler::get().enqueue(_accumulate_input_recurrent_modulation);
+ if(_has_layer_norm)
+ {
+ CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Cell));
+ }
+
_cell_gate_tanh.run();
// Input gate
@@ -747,6 +835,11 @@ void CLQLSTMLayer::run()
CLScheduler::get().enqueue(_accumulate_cell_input);
}
+ if(_has_layer_norm)
+ {
+ CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Input));
+ }
+
_input_gate_tanh.run();
}
@@ -771,6 +864,11 @@ void CLQLSTMLayer::run()
CLScheduler::get().enqueue(_accumulate_cell_to_output);
}
+ if(_has_layer_norm)
+ {
+ CLScheduler::get().enqueue(get_layer_norm(LayerNormGate::Output));
+ }
+
_output_gate_sigmoid.run();
// Hidden.
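
At run() time the new kernels slot in between the out-stage accumulation and the gate activation, so each activation sees normalized input. A sketch of that ordering for the forget gate, using a stand-in enqueue() that prints instead of dispatching through the CLScheduler:

#include <cstdio>

void enqueue(const char *kernel)  // stand-in for CLScheduler::get().enqueue(...)
{
    std::printf("enqueue %s\n", kernel);
}

// Forget-gate slice of run(): the layer norm kernel goes between the
// out-stage accumulation and the sigmoid, mirroring the hunks above.
void run_forget_gate(bool has_layer_norm)
{
    enqueue("recurrent_to_forget_outstage");
    if (has_layer_norm)
    {
        enqueue("layer_norm[Forget]"); // new in this commit
    }
    enqueue("forget_gate_sigmoid");
}

int main()
{
    run_forget_gate(true);
    return 0;
}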