aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/common/LSTMParams.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/common/LSTMParams.h')
-rw-r--r--arm_compute/runtime/common/LSTMParams.h36
1 files changed, 18 insertions, 18 deletions
diff --git a/arm_compute/runtime/common/LSTMParams.h b/arm_compute/runtime/common/LSTMParams.h
index 82fca7e7a2..ffb4ddd9d3 100644
--- a/arm_compute/runtime/common/LSTMParams.h
+++ b/arm_compute/runtime/common/LSTMParams.h
@@ -81,7 +81,7 @@ public:
*
* @return Reference to this LSTMParams object
*/
- LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, const T *cell_to_input_weights, const T *input_gate_bias)
+ LSTMParams &set_cifg_params(const T *input_to_input_weights, const T *recurrent_to_input_weights, T *cell_to_input_weights, const T *input_gate_bias)
{
_input_to_input_weights = input_to_input_weights;
_recurrent_to_input_weights = recurrent_to_input_weights;
@@ -111,7 +111,7 @@ public:
*
* @return Reference to this LSTMParams object
*/
- LSTMParams &set_peephole_params(const T *cell_to_forget_weights, const T *cell_to_output_weights)
+ LSTMParams &set_peephole_params(T *cell_to_forget_weights, T *cell_to_output_weights)
{
_cell_to_forget_weights = cell_to_forget_weights;
_cell_to_output_weights = cell_to_output_weights;
@@ -127,8 +127,8 @@ public:
*
* @return Reference to this LSTMParams object
*/
- LSTMParams &set_layer_normalization_params(const T *input_layer_norm_weights, const T *forget_layer_norm_weights,
- const T *cell_layer_norm_weights, const T *output_layer_norm_weights)
+ LSTMParams &set_layer_normalization_params(T *input_layer_norm_weights, T *forget_layer_norm_weights,
+ T *cell_layer_norm_weights, T *output_layer_norm_weights)
{
_input_layer_norm_weights = input_layer_norm_weights;
_forget_layer_norm_weights = forget_layer_norm_weights;
@@ -204,7 +204,7 @@ public:
return _recurrent_to_input_weights;
}
- const T *cell_to_input_weights() const
+ T *cell_to_input_weights() const
{
return _cell_to_input_weights;
}
@@ -214,12 +214,12 @@ public:
return _input_gate_bias;
}
- const T *cell_to_forget_weights() const
+ T *cell_to_forget_weights() const
{
return _cell_to_forget_weights;
}
- const T *cell_to_output_weights() const
+ T *cell_to_output_weights() const
{
return _cell_to_output_weights;
}
@@ -234,22 +234,22 @@ public:
return _projection_bias;
}
- const T *input_layer_norm_weights() const
+ T *input_layer_norm_weights() const
{
return _input_layer_norm_weights;
}
- const T *forget_layer_norm_weights() const
+ T *forget_layer_norm_weights() const
{
return _forget_layer_norm_weights;
}
- const T *cell_layer_norm_weights() const
+ T *cell_layer_norm_weights() const
{
return _cell_layer_norm_weights;
}
- const T *output_layer_norm_weights() const
+ T *output_layer_norm_weights() const
{
return _output_layer_norm_weights;
}
@@ -317,16 +317,16 @@ public:
private:
const T *_input_to_input_weights;
const T *_recurrent_to_input_weights;
- const T *_cell_to_input_weights;
+ T *_cell_to_input_weights;
const T *_input_gate_bias;
- const T *_cell_to_forget_weights;
- const T *_cell_to_output_weights;
+ T *_cell_to_forget_weights;
+ T *_cell_to_output_weights;
const T *_projection_weights;
const T *_projection_bias;
- const T *_input_layer_norm_weights;
- const T *_forget_layer_norm_weights;
- const T *_cell_layer_norm_weights;
- const T *_output_layer_norm_weights;
+ T *_input_layer_norm_weights;
+ T *_forget_layer_norm_weights;
+ T *_cell_layer_norm_weights;
+ T *_output_layer_norm_weights;
float _cell_clip;
float _projection_clip;
float _input_intermediate_scale;