aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/runtime/CL/functions/CLLSTMLayer.cpp6
-rw-r--r--src/runtime/NEON/functions/NELSTMLayer.cpp9
-rw-r--r--tests/validation/fixtures/LSTMLayerFixture.h9
3 files changed, 12 insertions, 12 deletions
diff --git a/src/runtime/CL/functions/CLLSTMLayer.cpp b/src/runtime/CL/functions/CLLSTMLayer.cpp
index 9f17a52812..ea08beca75 100644
--- a/src/runtime/CL/functions/CLLSTMLayer.cpp
+++ b/src/runtime/CL/functions/CLLSTMLayer.cpp
@@ -286,7 +286,7 @@ void CLLSTMLayer::configure(const CLCompileContext &compile_context, const ICLTe
if(cell_threshold != 0.f)
{
_perform_cell_clipping = true;
- _cell_clip.configure(compile_context, &_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
+ _cell_clip.configure(compile_context, &_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold));
}
// Configure block that calculates the output
@@ -569,8 +569,8 @@ Status CLLSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
if(cell_threshold != 0.f)
{
- ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
- cell_threshold)));
+ ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold,
+ -cell_threshold)));
}
std::vector<const ITensorInfo *> in_out_weights;
diff --git a/src/runtime/NEON/functions/NELSTMLayer.cpp b/src/runtime/NEON/functions/NELSTMLayer.cpp
index 2d6be06499..428cdf8c04 100644
--- a/src/runtime/NEON/functions/NELSTMLayer.cpp
+++ b/src/runtime/NEON/functions/NELSTMLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -263,7 +263,7 @@ void NELSTMLayer::configure(const ITensor *input,
if(cell_threshold != 0.f)
{
_perform_cell_clipping = true;
- _cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
+ _cell_clip.configure(&_cell_state_out1, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold));
}
// Configure block that calculates the output
@@ -542,8 +542,8 @@ Status NELSTMLayer::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE));
if(cell_threshold != 0.f)
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold,
- cell_threshold)));
+ ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold,
+ -cell_threshold)));
}
// Validate output gate tmp
@@ -665,6 +665,7 @@ void NELSTMLayer::run()
_pixelwise_mul_cell_gate_coeff.run();
_accum_cell_gate_bias.run();
}
+
_activation_cell_state.run();
_pixelwise_mul_cell_state1.run();
_pixelwise_mul_cell_state2.run();
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index f4bae86d30..c3a54726de 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -458,7 +458,6 @@ protected:
}
input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
}
-
// Compute cell_state
SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape);
transposed_weights = reference::transpose(recurrent_to_cell_w);
@@ -474,12 +473,13 @@ protected:
fill(cell_bias, 8);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE);
}
- cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ cell_state_out = reference::activation_layer(cell_state_out, info);
cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
+
if(cell_threshold != 0.f)
{
- cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
+ cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold));
}
// Compute output
@@ -515,7 +515,6 @@ protected:
output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
}
}
-
std::vector<SimpleTensor<T>> scratch_inputs;
if(!cifg_opt)
{