diff options
Diffstat (limited to 'tests/validation/fixtures/LSTMLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/LSTMLayerFixture.h | 132 |
1 files changed, 68 insertions, 64 deletions
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h index 858ee07d3e..a32e9adfe5 100644 --- a/tests/validation/fixtures/LSTMLayerFixture.h +++ b/tests/validation/fixtures/LSTMLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -46,7 +46,6 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ class LSTMLayerValidationFixture : public framework::Fixture { public: - template <typename...> void setup(TensorShape input_shape, TensorShape input_weights_shape, TensorShape recurrent_weights_shape, TensorShape cell_bias_shape, TensorShape output_cell_shape, TensorShape output_shape, TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm) @@ -61,13 +60,19 @@ protected: template <typename U> void fill(U &&tensor, int i) { - std::uniform_real_distribution<> distribution(-1.0f, 1.0f); + static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported."); + using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; + + DistributionType distribution{ T(-1.0f), T(1.0f) }; library->fill(tensor, distribution, i); } template <typename U> void fill_custom_val(U &&tensor, float num, int i) { - std::uniform_real_distribution<> distribution(num, num); + static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported."); + using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type; + + DistributionType distribution{ T(num), T(num) }; library->fill(tensor, distribution, i); } TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape, @@ -161,22 +166,22 @@ protected: &scratch, &output_state_out, &cell_state_out, &output, lstm_params, info, cell_threshold, projection_threshold); - ARM_COMPUTE_EXPECT(input.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(input_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(input_to_cell_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(input_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(recurrent_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(recurrent_to_cell_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(recurrent_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(forget_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(cell_bias.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(output_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(output_state_in.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(cell_state_in.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(scratch.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(output_state_out.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(cell_state_out.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(output.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(input.info()->is_resizable()); + ARM_COMPUTE_ASSERT(input_to_forget_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(input_to_cell_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(input_to_output_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(recurrent_to_forget_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(recurrent_to_cell_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(recurrent_to_output_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(forget_gate_bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(cell_bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(output_gate_bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(output_state_in.info()->is_resizable()); + ARM_COMPUTE_ASSERT(cell_state_in.info()->is_resizable()); + ARM_COMPUTE_ASSERT(scratch.info()->is_resizable()); + ARM_COMPUTE_ASSERT(output_state_out.info()->is_resizable()); + ARM_COMPUTE_ASSERT(cell_state_out.info()->is_resizable()); + ARM_COMPUTE_ASSERT(output.info()->is_resizable()); // Allocate tensors input.allocator()->allocate(); @@ -196,22 +201,22 @@ protected: cell_state_out.allocator()->allocate(); output.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!input.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!input_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!input_to_cell_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!input_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!recurrent_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!recurrent_to_cell_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!recurrent_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!forget_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!cell_bias.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!output_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!output_state_in.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!cell_state_in.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!scratch.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!output_state_out.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!cell_state_out.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!output.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!input.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!input_to_forget_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!input_to_cell_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!input_to_output_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!recurrent_to_forget_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!recurrent_to_cell_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!recurrent_to_output_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!forget_gate_bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!cell_bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!output_gate_bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!output_state_in.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!cell_state_in.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!scratch.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!output_state_out.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!cell_state_out.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!output.info()->is_resizable()); // Fill tensors fill(AccessorType(input), 0); @@ -230,18 +235,18 @@ protected: if(!cifg_opt) { - ARM_COMPUTE_EXPECT(input_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(recurrent_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(cell_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(input_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(input_to_input_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(recurrent_to_input_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(cell_to_input_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(input_gate_bias.info()->is_resizable()); input_to_input_w.allocator()->allocate(); recurrent_to_input_w.allocator()->allocate(); cell_to_input_w.allocator()->allocate(); input_gate_bias.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!input_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!recurrent_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!cell_to_input_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!input_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!input_to_input_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!recurrent_to_input_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!cell_to_input_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!input_gate_bias.info()->is_resizable()); fill(AccessorType(input_to_input_w), 13); fill(AccessorType(recurrent_to_input_w), 14); if(peephole_opt) @@ -254,26 +259,26 @@ protected: if(peephole_opt) { - ARM_COMPUTE_EXPECT(cell_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(cell_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(cell_to_forget_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(cell_to_output_w.info()->is_resizable()); cell_to_forget_w.allocator()->allocate(); cell_to_output_w.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!cell_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!cell_to_output_w.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!cell_to_forget_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!cell_to_output_w.info()->is_resizable()); fill(AccessorType(cell_to_forget_w), 18); fill(AccessorType(cell_to_output_w), 19); } if(projection_opt) { - ARM_COMPUTE_EXPECT(projection_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(projection_bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(projection_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(projection_bias.info()->is_resizable()); projection_w.allocator()->allocate(); projection_bias.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!projection_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!projection_bias.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!projection_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!projection_bias.info()->is_resizable()); fill(AccessorType(projection_w), 20); fill(AccessorType(projection_bias), 21); @@ -283,25 +288,25 @@ protected: { if(!cifg_opt) { - ARM_COMPUTE_EXPECT(input_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(input_layer_norm_w.info()->is_resizable()); input_layer_norm_w.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!input_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!input_layer_norm_w.info()->is_resizable()); fill(AccessorType(input_layer_norm_w), 22); } - ARM_COMPUTE_EXPECT(forget_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(cell_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(output_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(forget_layer_norm_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(cell_layer_norm_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(output_layer_norm_w.info()->is_resizable()); forget_layer_norm_w.allocator()->allocate(); cell_layer_norm_w.allocator()->allocate(); output_layer_norm_w.allocator()->allocate(); - ARM_COMPUTE_EXPECT(!forget_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!cell_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(!output_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_ASSERT(!forget_layer_norm_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!cell_layer_norm_w.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!output_layer_norm_w.info()->is_resizable()); fill(AccessorType(forget_layer_norm_w), 23); fill(AccessorType(cell_layer_norm_w), 24); @@ -452,7 +457,6 @@ protected: } input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); } - // Compute cell_state SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape); transposed_weights = reference::transpose(recurrent_to_cell_w); @@ -468,12 +472,13 @@ protected: fill(cell_bias, 8); cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE); } - cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); + cell_state_out = reference::activation_layer(cell_state_out, info); cell_state_out = reference::pixel_wise_multiplication<T, T, T>(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN, data_type); cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE); + if(cell_threshold != 0.f) { - cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold)); + cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, cell_threshold, -cell_threshold)); } // Compute output @@ -509,7 +514,6 @@ protected: output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold)); } } - std::vector<SimpleTensor<T>> scratch_inputs; if(!cifg_opt) { |