author    Michele Di Giorgio <michele.digiorgio@arm.com>    2019-06-04 12:41:45 +0100
committer Michele Di Giorgio <michele.digiorgio@arm.com>    2019-06-13 13:06:49 +0000
commit    39438b427b293c6d2e7066c68d3c3d3cb6d98a15 (patch)
tree      d5de918ca90dfe5641c7e0c3c854724f7de746d4 /tests
parent    c86633eb8865d8d2292cc44a8c30d09aee091ece (diff)
COMPMID-2342: Add layer normalization support in CLLSTMLayer
Change-Id: I25d974aa94e69c5f79a0bd99d5869a351d6d954d
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1324
Reviewed-by: Manuel Bottini <manuel.bottini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
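
This patch adds a UseLayerNorm axis to the LSTM test suites and teaches the shared fixture to exercise the layer-normalised gate path. In that path each gate pre-activation is normalised to zero mean and unit variance, scaled by per-gate norm weights, and only then biased and activated; a minimal sketch of one gate using the same reference helpers the fixture calls below (gate, norm_w and bias are illustrative names, not identifiers from the patch):

    // A minimal sketch, assuming SimpleTensor<T> gate holds the gate
    // pre-activation (W*x + R*h, plus the peephole term if enabled) and
    // norm_w / bias are 1-D [num_cells] tensors filled elsewhere.
    gate = reference::mean_std_normalization_layer(gate);
    gate = reference::pixel_wise_multiplication(gate, norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
    gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, gate, bias, data_type, ConvertPolicy::SATURATE);
    gate = reference::activation_layer(gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));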
Diffstat (limited to 'tests')
-rw-r--r--  tests/validation/CL/LSTMLayer.cpp             |  15
-rw-r--r--  tests/validation/NEON/LSTMLayer.cpp           |  15
-rw-r--r--  tests/validation/fixtures/LSTMLayerFixture.h  | 131
3 files changed, 135 insertions(+), 26 deletions(-)
diff --git a/tests/validation/CL/LSTMLayer.cpp b/tests/validation/CL/LSTMLayer.cpp
index 71a9383d93..69ac61dcf4 100644
--- a/tests/validation/CL/LSTMLayer.cpp
+++ b/tests/validation/CL/LSTMLayer.cpp
@@ -153,10 +153,11 @@ template <typename T>
using CLLSTMLayerFixture = LSTMLayerValidationFixture<CLTensor, CLAccessor, CLLSTMLayer, LSTMParams<ICLTensor>, T>;
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
DataType::F32)),
- framework::dataset::make("ProjectionOpt", { true, false })),
- framework::dataset::make("PeepholeOpt", { true, false })))
+ framework::dataset::make("ProjectionOpt", { true, false })),
+ framework::dataset::make("PeepholeOpt", { true, false })),
+ framework::dataset::make("UseLayerNorm", { true, false })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -165,9 +166,11 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<float>, framework::DatasetMo
TEST_SUITE_END() // FP32
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("ProjectionOpt", { true, false })),
- framework::dataset::make("PeepholeOpt", { true, false })))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLLSTMLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
+ DataType::F16)),
+ framework::dataset::make("ProjectionOpt", { true, false })),
+ framework::dataset::make("PeepholeOpt", { true, false })),
+ framework::dataset::make("UseLayerNorm", { true, false })))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_f16);
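
Each CL suite now runs every ProjectionOpt/PeepholeOpt combination twice, once per UseLayerNorm value. The dataset axes line up positionally with the fixture's setup() parameters, so the new axis sits last; a sketch of the correspondence (comments only, not code from the patch):

    // SmallLSTMLayerDataset -> shapes, info, cell_threshold, projection_threshold
    // "DataType"            -> data_type
    // "ProjectionOpt"       -> projection_opt
    // "PeepholeOpt"         -> peephole_opt
    // "UseLayerNorm"        -> use_layer_norm   (new trailing parameter, see fixture below)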
diff --git a/tests/validation/NEON/LSTMLayer.cpp b/tests/validation/NEON/LSTMLayer.cpp
index b27dfae8fa..c503972ba9 100644
--- a/tests/validation/NEON/LSTMLayer.cpp
+++ b/tests/validation/NEON/LSTMLayer.cpp
@@ -153,10 +153,11 @@ template <typename T>
using NELSTMLayerFixture = LSTMLayerValidationFixture<Tensor, Accessor, NELSTMLayer, LSTMParams<ITensor>, T>;
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NELSTMLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
+FIXTURE_DATA_TEST_CASE(RunSmall, NELSTMLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
DataType::F32)),
- framework::dataset::make("ProjectionOpt", { true, false })),
- framework::dataset::make("PeepholeOpt", { true, false })))
+ framework::dataset::make("ProjectionOpt", { true, false })),
+ framework::dataset::make("PeepholeOpt", { true, false })),
+ framework::dataset::make("UseLayerNorm", { false })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f32);
@@ -166,9 +167,11 @@ TEST_SUITE_END() // FP32
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NELSTMLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("ProjectionOpt", { true, false })),
- framework::dataset::make("PeepholeOpt", { true, false })))
+FIXTURE_DATA_TEST_CASE(RunSmall, NELSTMLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallLSTMLayerDataset(), framework::dataset::make("DataType",
+ DataType::F16)),
+ framework::dataset::make("ProjectionOpt", { true, false })),
+ framework::dataset::make("PeepholeOpt", { true, false })),
+ framework::dataset::make("UseLayerNorm", { false })))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f16);
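
Note that the NEON suites gain the same UseLayerNorm axis but pin it to { false }: this patch adds layer-normalisation support to CLLSTMLayer only, while the fixture below implements the reference path shared by both backends.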
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index 2cf83b8b3d..9260686d56 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -32,6 +32,7 @@
#include "tests/validation/reference/ConcatenateLayer.h"
#include "tests/validation/reference/FullyConnectedLayer.h"
#include "tests/validation/reference/GEMM.h"
+#include "tests/validation/reference/MeanStdDevNormalizationLayer.h"
#include "tests/validation/reference/PixelWiseMultiplication.h"
#include "tests/validation/reference/Transpose.h"
@@ -47,12 +48,13 @@ class LSTMLayerValidationFixture : public framework::Fixture
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape input_weights_shape, TensorShape recurrent_weights_shape, TensorShape cell_bias_shape, TensorShape output_cell_shape, TensorShape output_shape,
- TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt)
+ TensorShape scratch_shape, ActivationLayerInfo info, float cell_threshold, float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt,
+ bool use_layer_norm)
{
_target = compute_target(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold,
- data_type, projection_opt, peephole_opt);
+ data_type, projection_opt, peephole_opt, use_layer_norm);
_reference = compute_reference(input_shape, input_weights_shape, recurrent_weights_shape, cell_bias_shape, output_cell_shape, output_shape, scratch_shape, info, cell_threshold, projection_threshold,
- data_type, projection_opt, peephole_opt);
+ data_type, projection_opt, peephole_opt, use_layer_norm);
}
protected:
@@ -70,7 +72,7 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
- float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt)
+ float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm)
{
const unsigned int num_cells = input_weights_shape.y();
const unsigned int num_outputs = recurrent_weights_shape.x();
@@ -100,6 +102,10 @@ protected:
TensorType cell_to_output_w;
TensorType projection_w;
TensorType projection_bias;
+ TensorType input_layer_norm_w;
+ TensorType forget_layer_norm_w;
+ TensorType cell_layer_norm_w;
+ TensorType output_layer_norm_w;
bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
@@ -131,6 +137,22 @@ protected:
lstm_params.set_projection_params(&projection_w, &projection_bias);
}
+ if(use_layer_norm)
+ {
+ forget_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
+ cell_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
+ output_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
+ if(!cifg_opt)
+ {
+ input_layer_norm_w = create_tensor<TensorType>(TensorShape(num_cells), data_type);
+ lstm_params.set_layer_normalization_params(&input_layer_norm_w, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w);
+ }
+ else
+ {
+ lstm_params.set_layer_normalization_params(nullptr, &forget_layer_norm_w, &cell_layer_norm_w, &output_layer_norm_w);
+ }
+ }
+
// Create and configure function
FunctionType lstm;
lstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w, &recurrent_to_forget_w,
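
Outside the fixture, a caller enables the new path the same way; a minimal usage sketch, assuming four 1-D [num_cells] weight tensors have already been created (names are illustrative):

    // With CIFG enabled the input-gate slot must be nullptr, exactly as
    // the fixture does above.
    LSTMParams<ICLTensor> lstm_params;
    lstm_params.set_layer_normalization_params(&input_layer_norm_w, &forget_layer_norm_w,
                                               &cell_layer_norm_w, &output_layer_norm_w);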
@@ -257,6 +279,35 @@ protected:
fill(AccessorType(projection_bias), 21);
}
+ if(use_layer_norm)
+ {
+ if(!cifg_opt)
+ {
+ ARM_COMPUTE_EXPECT(input_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ input_layer_norm_w.allocator()->allocate();
+
+ ARM_COMPUTE_EXPECT(!input_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ fill(AccessorType(input_layer_norm_w), 22);
+ }
+ ARM_COMPUTE_EXPECT(forget_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(cell_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(output_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ forget_layer_norm_w.allocator()->allocate();
+ cell_layer_norm_w.allocator()->allocate();
+ output_layer_norm_w.allocator()->allocate();
+
+ ARM_COMPUTE_EXPECT(!forget_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!cell_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!output_layer_norm_w.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ fill(AccessorType(forget_layer_norm_w), 23);
+ fill(AccessorType(cell_layer_norm_w), 24);
+ fill(AccessorType(output_layer_norm_w), 25);
+ }
+
// Compute function
lstm.run();
@@ -266,7 +317,7 @@ protected:
SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &input_weights_shape, const TensorShape &recurrent_weights_shape, const TensorShape &cell_bias_shape,
const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
- float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt)
+ float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt, bool use_layer_norm)
{
const unsigned int num_cells = input_weights_shape.y();
const unsigned int num_outputs = recurrent_weights_shape.x();
@@ -306,6 +357,8 @@ protected:
SimpleTensor<T> cell_state_out{ output_cell_shape, data_type };
SimpleTensor<T> output{ output_shape, data_type };
+ bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
+
// Fill reference
fill(input, 0);
fill(input_to_forget_w, 1);
@@ -314,9 +367,18 @@ protected:
fill(recurrent_to_forget_w, 4);
fill(recurrent_to_cell_w, 5);
fill(recurrent_to_output_w, 6);
- fill(forget_gate_bias, 7);
- fill(cell_bias, 8);
- fill(output_gate_bias, 9);
+ if(use_layer_norm)
+ {
+ fill_custom_val(forget_gate_bias, 0.f, 7);
+ fill_custom_val(cell_bias, 0.f, 8);
+ fill_custom_val(output_gate_bias, 0.f, 9);
+ }
+ else
+ {
+ fill(forget_gate_bias, 7);
+ fill(cell_bias, 8);
+ fill(output_gate_bias, 9);
+ }
fill(output_state_in, 10);
fill(cell_state_in, 11);
fill(scratch, 12);
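
Note the bias handling here: when use_layer_norm is set, the reference zero-fills the gate biases for the initial fully-connected pass (and, in the next hunk, input_gate_bias as well), because each bias is applied after normalisation instead; the fill(..., 7/8/9/17) calls inside the per-gate layer-norm blocks further down restore the original values before the post-norm add.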
@@ -324,14 +386,19 @@ protected:
fill(recurrent_to_input_w, 14);
fill(cell_to_input_w, 15);
fill(recurrent_to_input_w, 16);
- fill(input_gate_bias, 17);
+ if(!cifg_opt && use_layer_norm)
+ {
+ fill_custom_val(input_gate_bias, 0.f, 17);
+ }
+ else
+ {
+ fill(input_gate_bias, 17);
+ }
fill(cell_to_forget_w, 18);
fill(cell_to_output_w, 19);
fill(projection_w, 20);
fill(projection_bias, 21);
- bool cifg_opt = scratch_shape.x() == cell_bias_shape.x() * 4 ? false : true;
-
// Compute forget_gate
SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape);
SimpleTensor<T> transposed_weights = reference::transpose(recurrent_to_forget_w);
@@ -344,6 +411,15 @@ protected:
forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
}
+ if(use_layer_norm)
+ {
+ SimpleTensor<T> forget_layer_norm_w{ cell_bias_shape, data_type };
+ fill(forget_layer_norm_w, 23);
+ forget_gate = reference::mean_std_normalization_layer(forget_gate);
+ forget_gate = reference::pixel_wise_multiplication(forget_gate, forget_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ fill(forget_gate_bias, 7);
+ forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, forget_gate_bias, data_type, ConvertPolicy::SATURATE);
+ }
forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
// Compute input_gate
@@ -365,6 +441,15 @@ protected:
SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
}
+ if(use_layer_norm)
+ {
+ SimpleTensor<T> input_layer_norm_w{ cell_bias_shape, data_type };
+ fill(input_layer_norm_w, 22);
+ input_gate = reference::mean_std_normalization_layer(input_gate);
+ input_gate = reference::pixel_wise_multiplication(input_gate, input_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ fill(input_gate_bias, 17);
+ input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, input_gate_bias, data_type, ConvertPolicy::SATURATE);
+ }
input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
}
@@ -374,9 +459,18 @@ protected:
gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
- cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
- cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
+ if(use_layer_norm)
+ {
+ SimpleTensor<T> cell_layer_norm_w{ cell_bias_shape, data_type };
+ fill(cell_layer_norm_w, 24);
+ cell_state_out = reference::mean_std_normalization_layer(cell_state_out);
+ cell_state_out = reference::pixel_wise_multiplication(cell_state_out, cell_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ fill(cell_bias, 8);
+ cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, cell_bias, data_type, ConvertPolicy::SATURATE);
+ }
+ cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
if(cell_threshold != 0.f)
{
cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
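
In summary, the cell-state path above now computes (a comment-form recap of the reference code, with * denoting element-wise multiplication):

    // cell_state_out = W_c*x + R_c*h                      (FC + GEMM above)
    // if use_layer_norm:
    //     cell_state_out = LN(cell_state_out) * gamma_c + cell_bias
    // cell_state_out = act(cell_state_out) * input_gate
    //                  + cell_state_in * forget_gate
    // if cell_threshold != 0: clip to [-cell_threshold, cell_threshold]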
@@ -392,6 +486,15 @@ protected:
pixelwise_mul = reference::pixel_wise_multiplication(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
}
+ if(use_layer_norm)
+ {
+ SimpleTensor<T> output_layer_norm_w{ cell_bias_shape, data_type };
+ fill(output_layer_norm_w, 25);
+ output = reference::mean_std_normalization_layer(output);
+ output = reference::pixel_wise_multiplication(output, output_layer_norm_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ fill(output_gate_bias, 9);
+ output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, output_gate_bias, data_type, ConvertPolicy::SATURATE);
+ }
output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
// Compute output state