diff options
Diffstat (limited to 'arm_compute/core/utils/misc/InfoHelpers.h')
-rw-r--r-- | arm_compute/core/utils/misc/InfoHelpers.h | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/InfoHelpers.h b/arm_compute/core/utils/misc/InfoHelpers.h index 8cf701c124..6ecda7a0dd 100644 --- a/arm_compute/core/utils/misc/InfoHelpers.h +++ b/arm_compute/core/utils/misc/InfoHelpers.h @@ -90,6 +90,23 @@ inline void build_lstm_params_tensor_info(const LSTMParams<T> &lstm_params, lstm_params_info->set_cifg_params(lstm_params.input_to_input_weights()->info(), lstm_params.recurrent_to_input_weights()->info(), cell_to_input_weights_info, lstm_params.input_gate_bias()->info()); } + if(lstm_params.use_layer_norm()) + { + ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), + lstm_params.output_layer_norm_weights(), + lstm_params.cell_layer_norm_weights()); + if(!lstm_params.has_cifg_opt()) + { + ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights()); + } + + const ITensorInfo *forget_info = lstm_params.forget_layer_norm_weights()->info(); + const ITensorInfo *cell_info = lstm_params.cell_layer_norm_weights()->info(); + const ITensorInfo *output_info = lstm_params.output_layer_norm_weights()->info(); + const ITensorInfo *input_info = lstm_params.has_cifg_opt() ? nullptr : lstm_params.input_layer_norm_weights()->info(); + + lstm_params_info->set_layer_normalization_params(input_info, forget_info, cell_info, output_info); + } } } // namespace info_helpers } // namespace utils |