diff options
Diffstat (limited to 'arm_compute/core/utils/misc/InfoHelpers.h')
-rw-r--r-- | arm_compute/core/utils/misc/InfoHelpers.h | 54 |
1 files changed, 29 insertions, 25 deletions
diff --git a/arm_compute/core/utils/misc/InfoHelpers.h b/arm_compute/core/utils/misc/InfoHelpers.h index ced0d24b56..1d1b4ea8d7 100644 --- a/arm_compute/core/utils/misc/InfoHelpers.h +++ b/arm_compute/core/utils/misc/InfoHelpers.h @@ -53,10 +53,12 @@ inline bool is_relu(ActivationLayerInfo activation_info) */ inline bool is_relu6(ActivationLayerInfo activation_info) { - const bool is_lu_bounded_relu = activation_info.activation() == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU - && activation_info.a() == 6.f && activation_info.b() == 0.f; - const bool is_bounded_relu = activation_info.activation() == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU - && activation_info.a() == 6.f; + const bool is_lu_bounded_relu = + activation_info.activation() == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU && + activation_info.a() == 6.f && activation_info.b() == 0.f; + const bool is_bounded_relu = + activation_info.activation() == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && + activation_info.a() == 6.f; return activation_info.enabled() && (is_lu_bounded_relu || is_bounded_relu); } @@ -68,34 +70,37 @@ inline bool is_relu6(ActivationLayerInfo activation_info) * */ template <typename T> -inline void build_lstm_params_tensor_info(const LSTMParams<T> &lstm_params, - LSTMParams<ITensorInfo> *lstm_params_info) +inline void build_lstm_params_tensor_info(const LSTMParams<T> &lstm_params, LSTMParams<ITensorInfo> *lstm_params_info) { - if(lstm_params.has_peephole_opt()) + if (lstm_params.has_peephole_opt()) { ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.cell_to_forget_weights(), lstm_params.cell_to_output_weights()); - lstm_params_info->set_peephole_params(lstm_params.cell_to_forget_weights()->info(), lstm_params.cell_to_output_weights()->info()); + lstm_params_info->set_peephole_params(lstm_params.cell_to_forget_weights()->info(), + lstm_params.cell_to_output_weights()->info()); } - if(lstm_params.has_projection()) + if (lstm_params.has_projection()) { ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.projection_weights()); - lstm_params_info->set_projection_params(lstm_params.projection_weights()->info(), - lstm_params.projection_bias() != nullptr ? lstm_params.projection_bias()->info() : nullptr); + lstm_params_info->set_projection_params( + lstm_params.projection_weights()->info(), + lstm_params.projection_bias() != nullptr ? lstm_params.projection_bias()->info() : nullptr); } - if(!lstm_params.has_cifg_opt()) + if (!lstm_params.has_cifg_opt()) { - ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), lstm_params.input_gate_bias()); + ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.input_to_input_weights(), lstm_params.recurrent_to_input_weights(), + lstm_params.input_gate_bias()); - ITensorInfo *cell_to_input_weights_info = (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr; - lstm_params_info->set_cifg_params(lstm_params.input_to_input_weights()->info(), lstm_params.recurrent_to_input_weights()->info(), - cell_to_input_weights_info, lstm_params.input_gate_bias()->info()); + ITensorInfo *cell_to_input_weights_info = + (lstm_params.has_peephole_opt()) ? lstm_params.cell_to_input_weights()->info() : nullptr; + lstm_params_info->set_cifg_params(lstm_params.input_to_input_weights()->info(), + lstm_params.recurrent_to_input_weights()->info(), cell_to_input_weights_info, + lstm_params.input_gate_bias()->info()); } - if(lstm_params.use_layer_norm()) + if (lstm_params.use_layer_norm()) { - ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), - lstm_params.output_layer_norm_weights(), + ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.forget_layer_norm_weights(), lstm_params.output_layer_norm_weights(), lstm_params.cell_layer_norm_weights()); - if(!lstm_params.has_cifg_opt()) + if (!lstm_params.has_cifg_opt()) { ARM_COMPUTE_ERROR_ON_NULLPTR(lstm_params.input_layer_norm_weights()); } @@ -103,15 +108,14 @@ inline void build_lstm_params_tensor_info(const LSTMParams<T> &lstm_params, ITensorInfo *forget_info = lstm_params.forget_layer_norm_weights()->info(); ITensorInfo *cell_info = lstm_params.cell_layer_norm_weights()->info(); ITensorInfo *output_info = lstm_params.output_layer_norm_weights()->info(); - ITensorInfo *input_info = lstm_params.has_cifg_opt() ? nullptr : lstm_params.input_layer_norm_weights()->info(); + ITensorInfo *input_info = lstm_params.has_cifg_opt() ? nullptr : lstm_params.input_layer_norm_weights()->info(); lstm_params_info->set_layer_normalization_params(input_info, forget_info, cell_info, output_info); } - lstm_params_info->set_matmul_scale_params(lstm_params.input_intermediate_scale(), - lstm_params.forget_intermediate_scale(), - lstm_params.cell_intermediate_scale(), - lstm_params.output_intermediate_scale()); + lstm_params_info->set_matmul_scale_params( + lstm_params.input_intermediate_scale(), lstm_params.forget_intermediate_scale(), + lstm_params.cell_intermediate_scale(), lstm_params.output_intermediate_scale()); lstm_params_info->set_hidden_state_params(lstm_params.hidden_state_zero(), lstm_params.hidden_state_scale()); } |