From 9454cf7654a2aac3f9e624d910237517eb7c8a59 Mon Sep 17 00:00:00 2001 From: Pablo Marquez Tello Date: Wed, 16 Feb 2022 11:15:58 +0000 Subject: Fixed threshould argument order in NE/CL/LSTM * Fixed hardcoded LOGISTIC activation in ACL reference * Partially resolves MLCE-60 * Resolves COMPMID-5139 Change-Id: I50e75339084ea53bf75acf18aa3e5cdafcf34c15 Signed-off-by: Pablo Marquez Tello Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7150 Tested-by: Arm Jenkins Reviewed-by: TeresaARM Reviewed-by: Giorgio Arena Comments-Addressed: Arm Jenkins --- src/runtime/CL/functions/CLLSTMLayer.cpp | 6 +++--- src/runtime/NEON/functions/NELSTMLayer.cpp | 9 +++++---- tests/validation/fixtures/LSTMLayerFixture.h | 9 ++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp index 9f17a52812..ea08beca75 100644 --- a/src/runtime/CL/functions/CLLSTMLayer.cpp +++ b/src/runtime/CL/functions/CLLSTMLayer.cpp @@ -286,7 +286,7 @@ void CLLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTe if(cell_threshold != 0.f) { _perform_cell_clipping = true; - _cell_clip.configure(compile_context, &_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold)); + _cell_clip.configure(compile_context, &_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold)); } // Configure block that calculates the output @@ -569,8 +569,8 @@ Status CLLSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE)); if(cell_threshold != 0.f) { - ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, - cell_threshold))); + ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, + -cell_threshold))); } std::vector in_out_weights; diff --git a/src/runtime/NEON/functions/NELSTMLayer.cpp b/src/runtime/NEON/functions/NELSTMLayer.cpp index 2d6be06499..428cdf8c04 100644 --- a/src/runtime/NEON/functions/NELSTMLayer.cpp +++ b/src/runtime/NEON/functions/NELSTMLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -263,7 +263,7 @@ void NELSTMLayer::configure(const ITensor *input, if(cell_threshold != 0.f) { _perform_cell_clipping = true; - _cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold)); + _cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold)); } // Configure block that calculates the output @@ -542,8 +542,8 @@ Status NELSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE)); if(cell_threshold != 0.f) { - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, - cell_threshold))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, + -cell_threshold))); } // Validate output gate tmp @@ -665,6 +665,7 @@ void NELSTMLayer::run() _pixelwise_mul_cell_gate_coeff.run(); _accum_cell_gate_bias.run(); } + _activation_cell_state.run(); _pixelwise_mul_cell_state1.run(); _pixelwise_mul_cell_state2.run(); diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h index f4bae86d30..c3a54726de 100644 --- a/tests/validation/fixtures/LSTMLayerFixture.h +++ b/tests/validation/fixtures/LSTMLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -458,7 +458,6 @@ protected: } input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); } - // Compute cell_state SimpleTensor fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape); transposed_weights = reference::transpose(recurrent_to_cell_w); @@ -474,12 +473,13 @@ protected: fill(cell_bias, 8); cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE); } - cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); + cell_state_out = reference::activation_layer(cell_state_out, info); cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE); + if(cell_threshold != 0.f) { - cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold)); + cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold)); } // Compute output @@ -515,7 +515,6 @@ protected: output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)); } } - std::vector> scratch_inputs; if(!cifg_opt) { -- cgit v1.2.1