aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/common/LSTMParams.h
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-03-04 18:08:47 +0000
committerMichele Di Giorgio <michele.digiorgio@arm.com>2020-03-09 09:41:27 +0000
commit25d9775b08c59be8d17700f75026a8457aab9838 (patch)
tree449f2b1a450bd688dc2b925bd6e2cb9f28f9dc5b /arm_compute/runtime/common/LSTMParams.h
parentc28d42837b2aea09738a7df00653d623c3c53420 (diff)
downloadComputeLibrary-25d9775b08c59be8d17700f75026a8457aab9838.tar.gz
COMPMID-3225: Extend LSTMParams with parameters for Enhanced Quantized LSTM
Change-Id: I9732c7e7a7a89537cb046b973fd3a14f10caa06c Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2836 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/common/LSTMParams.h')
-rw-r--r--arm_compute/runtime/common/LSTMParams.h133
1 files changed, 129 insertions, 4 deletions
diff --git a/arm_compute/runtime/common/LSTMParams.h b/arm_compute/runtime/common/LSTMParams.h
index b9e4014ff8..f16945730e 100644
--- a/arm_compute/runtime/common/LSTMParams.h
+++ b/arm_compute/runtime/common/LSTMParams.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,9 +40,30 @@ class LSTMParams
public:
/** Constructor */
LSTMParams()
- : _input_to_input_weights(nullptr), _recurrent_to_input_weights(nullptr), _cell_to_input_weights(nullptr), _input_gate_bias(nullptr), _cell_to_forget_weights(nullptr),
- _cell_to_output_weights(nullptr), _projection_weights(nullptr), _projection_bias(nullptr), _input_layer_norm_weights(nullptr), _forget_layer_norm_weights(nullptr), _cell_layer_norm_weights(nullptr),
- _output_layer_norm_weights(nullptr), _has_peephole_opt(false), _has_projection(false), _has_cifg_opt(true), _use_layer_norm(false)
+ : _input_to_input_weights(nullptr),
+ _recurrent_to_input_weights(nullptr),
+ _cell_to_input_weights(nullptr),
+ _input_gate_bias(nullptr),
+ _cell_to_forget_weights(nullptr),
+ _cell_to_output_weights(nullptr),
+ _projection_weights(nullptr),
+ _projection_bias(nullptr),
+ _input_layer_norm_weights(nullptr),
+ _forget_layer_norm_weights(nullptr),
+ _cell_layer_norm_weights(nullptr),
+ _output_layer_norm_weights(nullptr),
+ _cell_clip(0.f),
+ _projection_clip(0.0f),
+ _input_gate_matmul_scale(0.0f),
+ _forget_gate_matmul_scale(0.0f),
+ _cell_gate_matmul_scale(0.0f),
+ _output_gate_matmul_scale(0.0f),
+ _hidden_state_zero(0.0f),
+ _hidden_state_scale(0),
+ _has_peephole_opt(false),
+ _has_projection(false),
+ _has_cifg_opt(true),
+ _use_layer_norm(false)
{
}
/** Prevent instances of this class from being copied (As this class contains pointers) */
@@ -117,6 +138,62 @@ public:
return *this;
}
+ /** Set cell clip value.
+ *
+ * @param[in] cell_clip Value to be used to clip the cell state prior to the cell output activation.
+ *
+ * @return Reference to this LSTMParams object
+ */
+ LSTMParams &set_cell_clip_params(float cell_clip)
+ {
+ _cell_clip = cell_clip;
+ return *this;
+ }
+
+ /** Set projection clip value.
+ *
+ * @param[in] projection_clip Value to be used to clip the projection, in case projection is enabled.
+ *
+ * @return Reference to this LSTMParams object
+ */
+ LSTMParams &set_projection_clip_params(float projection_clip)
+ {
+ _projection_clip = projection_clip;
+ return *this;
+ }
+
+ /** Set scale of the intermediate results of matmul of each layer parameters.
+ *
+ * @param[in] input_gate_matmul_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
+ * @param[in] forget_gate_matmul_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
+ * @param[in] cell_gate_matmul_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
+ * @param[in] output_gate_matmul_scale Scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
+ *
+ * @return Reference to this LSTMParams object
+ */
+ LSTMParams &set_matmul_scale_params(float input_gate_matmul_scale, float forget_gate_matmul_scale, float cell_gate_matmul_scale, float output_gate_matmul_scale)
+ {
+ _input_gate_matmul_scale = input_gate_matmul_scale;
+ _forget_gate_matmul_scale = forget_gate_matmul_scale;
+ _cell_gate_matmul_scale = cell_gate_matmul_scale;
+ _output_gate_matmul_scale = output_gate_matmul_scale;
+ return *this;
+ }
+
+ /** Set hidden state zero and scale parameters.
+ *
+ * @param[in] hidden_state_zero The zero point of the hidden state.
+ * @param[in] hidden_state_scale The scale of the hidden state.
+ *
+ * @return Reference to this LSTMParams object
+ */
+ LSTMParams &set_matmul_scale_params(int32_t hidden_state_zero, float hidden_state_scale)
+ {
+ _hidden_state_zero = hidden_state_zero;
+ _hidden_state_scale = hidden_state_scale;
+ return *this;
+ }
+
const T *input_to_input_weights() const
{
return _input_to_input_weights;
@@ -177,6 +254,46 @@ public:
return _output_layer_norm_weights;
}
+ float cell_clip() const
+ {
+ return _cell_clip;
+ }
+
+ float projection_clip() const
+ {
+ return _projection_clip;
+ }
+
+ float input_gate_matmul_scale() const
+ {
+ return _input_gate_matmul_scale;
+ }
+
+ float forget_gate_matmul_scale() const
+ {
+ return _forget_gate_matmul_scale;
+ }
+
+ float cell_gate_matmul_scale() const
+ {
+ return _cell_gate_matmul_scale;
+ }
+
+ float output_gate_matmul_scale() const
+ {
+ return _output_gate_matmul_scale;
+ }
+
+ int32_t hidden_state_zero() const
+ {
+ return _hidden_state_zero;
+ }
+
+ float hidden_state_scale() const
+ {
+ return _hidden_state_scale;
+ }
+
bool has_peephole_opt() const
{
return _has_peephole_opt;
@@ -210,6 +327,14 @@ private:
const T *_forget_layer_norm_weights;
const T *_cell_layer_norm_weights;
const T *_output_layer_norm_weights;
+ float _cell_clip;
+ float _projection_clip;
+ float _input_gate_matmul_scale;
+ float _forget_gate_matmul_scale;
+ float _cell_gate_matmul_scale;
+ float _output_gate_matmul_scale;
+ float _hidden_state_zero;
+ int32_t _hidden_state_scale;
bool _has_peephole_opt;
bool _has_projection;
bool _has_cifg_opt;