From ab23dd0fbc632063235a6ad408241dc79a35d3e4 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Mon, 6 Jul 2020 14:57:36 +0100 Subject: COMPMID-3387: Support memory injection in CLActivationLayer Signed-off-by: Georgios Pinitas Change-Id: I31f9620607b372fc3340c71e748a5ea177d9da62 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3520 Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/runtime/CL/functions/CLLSTMLayer.cpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) (limited to 'src/runtime/CL/functions/CLLSTMLayer.cpp') diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp index 56f22e2fe0..e63a9cceb0 100644 --- a/src/runtime/CL/functions/CLLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLLSTMLayer.cpp @@ -499,7 +499,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, RoundingPolicy::TO_NEAREST_EVEN)); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); // Validate input gate if(!lstm_params.has_cifg_opt()) @@ -534,7 +534,7 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN)); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); } else { @@ -552,14 +552,14 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, RoundingPolicy::TO_NEAREST_EVEN)); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, activation_info)); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, activation_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN)); ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN)); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE)); if(cell_threshold != 0.f) { - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, - cell_threshold))); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, + cell_threshold))); } std::vector in_out_weights; @@ -584,18 +584,18 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, RoundingPolicy::TO_NEAREST_EVEN)); ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); // Validate output state - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(&cell_state_tmp, &cell_state_tmp, activation_info)); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info)); ARM_COMPUTE_RETURN_ON_ERROR(CLPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN)); if(lstm_params.has_projection()) { ARM_COMPUTE_RETURN_ON_ERROR(CLFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out)); if(projection_threshold != 0.f) { - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayerKernel::validate(output_state_out, output_state_out, - ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold))); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output_state_out, output_state_out, + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold))); } } @@ -638,7 +638,7 @@ void CLLSTMLayer::run() CLScheduler::get().enqueue(_pixelwise_mul_forget_gate_coeff); CLScheduler::get().enqueue(_accum_forget_gate_bias); } - CLScheduler::get().enqueue(_activation_forget_gate); + _activation_forget_gate.run(); if(_run_cifg_opt) { @@ -661,7 +661,7 @@ void CLLSTMLayer::run() CLScheduler::get().enqueue(_pixelwise_mul_input_gate_coeff); CLScheduler::get().enqueue(_accum_input_gate_bias); } - CLScheduler::get().enqueue(_activation_input_gate); + _activation_input_gate.run(); } _fully_connected_cell_state.run(); @@ -674,14 +674,14 @@ void CLLSTMLayer::run() CLScheduler::get().enqueue(_pixelwise_mul_cell_gate_coeff); CLScheduler::get().enqueue(_accum_cell_gate_bias); } - CLScheduler::get().enqueue(_activation_cell_state); + _activation_cell_state.run(); CLScheduler::get().enqueue(_pixelwise_mul_cell_state1); CLScheduler::get().enqueue(_pixelwise_mul_cell_state2); CLScheduler::get().enqueue(_accum_cell_state2); if(_perform_cell_clipping) { - CLScheduler::get().enqueue(_cell_clip); + _cell_clip.run(); } _fully_connected_output.run(); @@ -697,9 +697,9 @@ void CLLSTMLayer::run() CLScheduler::get().enqueue(_pixelwise_mul_output_gate_coeff); CLScheduler::get().enqueue(_accum_output_gate_bias); } - CLScheduler::get().enqueue(_activation_output); + _activation_output.run(); - CLScheduler::get().enqueue(_activation_output_state); + _activation_output_state.run(); CLScheduler::get().enqueue(_pixelwise_mul_output_state2); if(_has_projection_weights) @@ -707,7 +707,7 @@ void CLLSTMLayer::run() _fully_connected_output_state.run(); if(_perform_projection_clipping) { - CLScheduler::get().enqueue(_projection_clip); + _projection_clip.run(); } } -- cgit v1.2.1