about summary refs log tree commit diff
path: root/tests/validation/fixtures/LSTMLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/LSTMLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/LSTMLayerFixture.h | 132
1 file changed, 68 insertions(+), 64 deletions(-)
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index 858ee07d3e..a32e9adfe5 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 ARM Limited.
+ * Copyright (c) 2018-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,7 +46,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class LSTMLayerValidationFixture : public framework::Fixture
{
public:
- template <typename...>
void setup(TensorShape input_shape, TensorShape input_weights_shape, TensorShape recurrent_weights_shape, TensorShape cell_bias_shape, TensorShape output_cell_shape, TensorShape output_shape,
TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt,
bool use_layer_norm)
@@ -61,13 +60,19 @@ protected:
template <typename U>
void fill(U &&tensor, int i)
{
- std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+ static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+ using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
+
+ DistributionType distribution{ T(-1.0f), T(1.0f) };
library->fill(tensor, distribution, i);
}
template <typename U>
void fill_custom_val(U &&tensor, float num, int i)
{
- std::uniform_real_distribution<> distribution(num, num);
+ static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+ using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
+
+ DistributionType distribution{ T(num), T(num) };
library->fill(tensor, distribution, i);
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
@@ -161,22 +166,22 @@ protected:
&scratch, &output_state_out, &cell_state_out, &output,
lstm_params, info, cell_threshold, projection_threshold);
- ARM_COMPUTE_EXPECT(input.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(input_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(input_to_cell_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(input_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(recurrent_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(recurrent_to_cell_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(recurrent_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(forget_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(cell_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(output_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(output_state_in.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(cell_state_in.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(scratch.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(output_state_out.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(cell_state_out.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(output.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(input.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(input_to_forget_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(input_to_cell_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(input_to_output_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(recurrent_to_forget_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(recurrent_to_cell_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(recurrent_to_output_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(forget_gate_bias.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(cell_bias.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(output_gate_bias.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(output_state_in.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(cell_state_in.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(scratch.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(output_state_out.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(cell_state_out.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(output.info()->is_resizable());
// Allocate tensors
input.allocator()->allocate();
@@ -196,22 +201,22 @@ protected:
cell_state_out.allocator()->allocate();
output.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!input.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!input_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!input_to_cell_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!input_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!recurrent_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!recurrent_to_cell_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!recurrent_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!forget_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!cell_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!output_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!output_state_in.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!cell_state_in.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!scratch.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!output_state_out.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!cell_state_out.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!output.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!input.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!input_to_forget_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!input_to_cell_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!input_to_output_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!recurrent_to_forget_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!recurrent_to_cell_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!recurrent_to_output_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!forget_gate_bias.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!cell_bias.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!output_gate_bias.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!output_state_in.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!cell_state_in.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!scratch.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!output_state_out.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!cell_state_out.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!output.info()->is_resizable());
// Fill tensors
fill(AccessorType(input), 0);
@@ -230,18 +235,18 @@ protected:
if(!cifg_opt)
{
- ARM_COMPUTE_EXPECT(input_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(recurrent_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(cell_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(input_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(input_to_input_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(recurrent_to_input_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(cell_to_input_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(input_gate_bias.info()->is_resizable());
input_to_input_w.allocator()->allocate();
recurrent_to_input_w.allocator()->allocate();
cell_to_input_w.allocator()->allocate();
input_gate_bias.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!input_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!recurrent_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!cell_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!input_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!input_to_input_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!recurrent_to_input_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!cell_to_input_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!input_gate_bias.info()->is_resizable());
fill(AccessorType(input_to_input_w), 13);
fill(AccessorType(recurrent_to_input_w), 14);
if(peephole_opt)
@@ -254,26 +259,26 @@ protected:
if(peephole_opt)
{
- ARM_COMPUTE_EXPECT(cell_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(cell_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(cell_to_forget_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(cell_to_output_w.info()->is_resizable());
cell_to_forget_w.allocator()->allocate();
cell_to_output_w.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!cell_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!cell_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!cell_to_forget_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!cell_to_output_w.info()->is_resizable());
fill(AccessorType(cell_to_forget_w), 18);
fill(AccessorType(cell_to_output_w), 19);
}
if(projection_opt)
{
- ARM_COMPUTE_EXPECT(projection_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(projection_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(projection_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(projection_bias.info()->is_resizable());
projection_w.allocator()->allocate();
projection_bias.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!projection_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!projection_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!projection_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!projection_bias.info()->is_resizable());
fill(AccessorType(projection_w), 20);
fill(AccessorType(projection_bias), 21);
@@ -283,25 +288,25 @@ protected:
{
if(!cifg_opt)
{
- ARM_COMPUTE_EXPECT(input_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(input_layer_norm_w.info()->is_resizable());
input_layer_norm_w.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!input_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!input_layer_norm_w.info()->is_resizable());
fill(AccessorType(input_layer_norm_w), 22);
}
- ARM_COMPUTE_EXPECT(forget_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(cell_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(output_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(forget_layer_norm_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(cell_layer_norm_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(output_layer_norm_w.info()->is_resizable());
forget_layer_norm_w.allocator()->allocate();
cell_layer_norm_w.allocator()->allocate();
output_layer_norm_w.allocator()->allocate();
- ARM_COMPUTE_EXPECT(!forget_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!cell_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!output_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_ASSERT(!forget_layer_norm_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!cell_layer_norm_w.info()->is_resizable());
+ ARM_COMPUTE_ASSERT(!output_layer_norm_w.info()->is_resizable());
fill(AccessorType(forget_layer_norm_w), 23);
fill(AccessorType(cell_layer_norm_w), 24);
@@ -452,7 +457,6 @@ protected:
}
input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
}
-
// Compute cell_state
SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape);
transposed_weights = reference::transpose(recurrent_to_cell_w);
@@ -468,12 +472,13 @@ protected:
fill(cell_bias, 8);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE);
}
- cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ cell_state_out = reference::activation_layer(cell_state_out, info);
cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
+
if(cell_threshold != 0.f)
{
- cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
+ cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold));
}
// Compute output
@@ -509,7 +514,6 @@ protected:
output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
}
}
-
std::vector<SimpleTensor<T>> scratch_inputs;
if(!cifg_opt)
{