about | summary | refs | log | tree | commit | diff
path: root/tests/validation/fixtures/LSTMLayerFixture.h
diff options
context:
space:
mode:
author: Georgios Pinitas <georgios.pinitas@arm.com> 2018-07-18 19:51:24 +0100
committer: Anthony Barbier <anthony.barbier@arm.com> 2018-11-02 16:54:54 +0000
commit: 8bc745dfd133f46e68edad511e2933a590602a24 (patch)
tree: 440126c9bb569e4eda74a580da9b041afe216eef /tests/validation/fixtures/LSTMLayerFixture.h
parent: 201cea1b40597c226bf2c8e59d90bebdf9817dd3 (diff)
download: ComputeLibrary-8bc745dfd133f46e68edad511e2933a590602a24.tar.gz
COMPMID-1124: Validate CLLSTM
- Enables cell-to-input weights when !cifg and peephole
- Makes projection bias conditional

Change-Id: Iee866db9f5d8479c2dfd95d74a2d42492bf07a8d
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140543
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Les Bell <les.bell@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/fixtures/LSTMLayerFixture.h')
-rw-r--r-- tests/validation/fixtures/LSTMLayerFixture.h 125
1 file changed, 75 insertions, 50 deletions
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index 3c1a560b24..20df855242 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -72,9 +72,8 @@ protected:
const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt)
{
- // Create projection bias shape
- TensorShape projection_bias_shape{};
- projection_bias_shape.set(0, output_shape.x());
+ const unsigned int num_cells = input_weights_shape.y();
+ const unsigned int num_outputs = recurrent_weights_shape.x();
// Create tensors
TensorType input = create_tensor<TensorType>(input_shape, data_type);
@@ -87,9 +86,11 @@ protected:
TensorType forget_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
TensorType cell_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
TensorType output_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
- TensorType output_state = create_tensor<TensorType>(output_shape, data_type);
- TensorType cell_state = create_tensor<TensorType>(output_cell_shape, data_type);
+ TensorType output_state_in = create_tensor<TensorType>(output_shape, data_type);
+ TensorType cell_state_in = create_tensor<TensorType>(output_cell_shape, data_type);
TensorType scratch = create_tensor<TensorType>(scratch_shape, data_type);
+ TensorType output_state_out = create_tensor<TensorType>(output_shape, data_type);
+ TensorType cell_state_out = create_tensor<TensorType>(output_cell_shape, data_type);
TensorType output = create_tensor<TensorType>(output_shape, data_type);
TensorType input_to_input_w;
TensorType recurrent_to_input_w;
@@ -108,8 +109,11 @@ protected:
{
input_to_input_w = create_tensor<TensorType>(input_weights_shape, data_type);
recurrent_to_input_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
- cell_to_input_w = create_tensor<TensorType>(cell_bias_shape, data_type);
- input_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
+ if(peephole_opt)
+ {
+ cell_to_input_w = create_tensor<TensorType>(cell_bias_shape, data_type);
+ }
+ input_gate_bias = create_tensor<TensorType>(cell_bias_shape, data_type);
lstm_params.set_cifg_params(&input_to_input_w, &recurrent_to_input_w, &cell_to_input_w, &input_gate_bias);
}
@@ -122,16 +126,18 @@ protected:
if(projection_opt)
{
- projection_w = create_tensor<TensorType>(recurrent_weights_shape, data_type);
- projection_bias = create_tensor<TensorType>(projection_bias_shape, data_type);
+ projection_w = create_tensor<TensorType>(TensorShape(num_cells, num_outputs), data_type);
+ projection_bias = create_tensor<TensorType>(TensorShape(num_outputs), data_type);
lstm_params.set_projection_params(&projection_w, &projection_bias);
}
// Create and configure function
FunctionType lstm;
lstm.configure(&input, &input_to_forget_w, &input_to_cell_w, &input_to_output_w, &recurrent_to_forget_w,
- &recurrent_to_cell_w, &recurrent_to_output_w, &forget_gate_bias, &cell_bias, &output_gate_bias, &output_state, &cell_state,
- &scratch, &output, lstm_params, info, cell_threshold, projection_threshold);
+ &recurrent_to_cell_w, &recurrent_to_output_w, &forget_gate_bias, &cell_bias, &output_gate_bias,
+ &output_state_in, &cell_state_in,
+ &scratch, &output_state_out, &cell_state_out, &output,
+ lstm_params, info, cell_threshold, projection_threshold);
ARM_COMPUTE_EXPECT(input.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(input_to_forget_w.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -143,9 +149,11 @@ protected:
ARM_COMPUTE_EXPECT(forget_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(cell_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(output_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(output_state.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(cell_state.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(output_state_in.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(cell_state_in.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(scratch.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(output_state_out.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(cell_state_out.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(output.info()->is_resizable(), framework::LogLevel::ERRORS);
// Allocate tensors
@@ -159,9 +167,11 @@ protected:
forget_gate_bias.allocator()->allocate();
cell_bias.allocator()->allocate();
output_gate_bias.allocator()->allocate();
- output_state.allocator()->allocate();
- cell_state.allocator()->allocate();
+ output_state_in.allocator()->allocate();
+ cell_state_in.allocator()->allocate();
scratch.allocator()->allocate();
+ output_state_out.allocator()->allocate();
+ cell_state_out.allocator()->allocate();
output.allocator()->allocate();
ARM_COMPUTE_EXPECT(!input.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -174,9 +184,11 @@ protected:
ARM_COMPUTE_EXPECT(!forget_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!cell_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!output_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!output_state.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!cell_state.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!output_state_in.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!cell_state_in.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!scratch.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!output_state_out.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!cell_state_out.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!output.info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
@@ -190,8 +202,8 @@ protected:
fill(AccessorType(forget_gate_bias), 7);
fill(AccessorType(cell_bias), 8);
fill(AccessorType(output_gate_bias), 9);
- fill(AccessorType(output_state), 10);
- fill(AccessorType(cell_state), 11);
+ fill(AccessorType(output_state_in), 10);
+ fill(AccessorType(cell_state_in), 11);
fill(AccessorType(scratch), 12);
if(!cifg_opt)
@@ -210,7 +222,10 @@ protected:
ARM_COMPUTE_EXPECT(!input_gate_bias.info()->is_resizable(), framework::LogLevel::ERRORS);
fill(AccessorType(input_to_input_w), 13);
fill(AccessorType(recurrent_to_input_w), 14);
- fill(AccessorType(cell_to_input_w), 15);
+ if(peephole_opt)
+ {
+ fill(AccessorType(cell_to_input_w), 15);
+ }
fill(AccessorType(recurrent_to_input_w), 16);
fill(AccessorType(input_gate_bias), 17);
}
@@ -251,9 +266,14 @@ protected:
const TensorShape &output_cell_shape, const TensorShape &output_shape, const TensorShape &scratch_shape, ActivationLayerInfo info, float cell_threshold,
float projection_threshold, DataType data_type, bool projection_opt, bool peephole_opt)
{
+ const unsigned int num_cells = input_weights_shape.y();
+ const unsigned int num_outputs = recurrent_weights_shape.x();
+
+ // Create projection weights shape
+ TensorShape projection_weights_shape(num_cells, num_outputs);
+
// Create projection bias shape
- TensorShape projection_bias_shape{};
- projection_bias_shape.set(0, output_shape.x());
+ TensorShape projection_bias_shape(num_outputs);
TensorShape gemm_shape{ 1, output_shape.y() };
SimpleTensor<T> gemm_out{ gemm_shape, data_type };
@@ -275,11 +295,13 @@ protected:
SimpleTensor<T> forget_gate_bias{ cell_bias_shape, data_type };
SimpleTensor<T> cell_bias{ cell_bias_shape, data_type };
SimpleTensor<T> output_gate_bias{ cell_bias_shape, data_type };
- SimpleTensor<T> projection_w{ recurrent_weights_shape, data_type };
+ SimpleTensor<T> projection_w{ projection_weights_shape, data_type };
SimpleTensor<T> projection_bias{ projection_bias_shape, data_type };
- SimpleTensor<T> output_state{ output_shape, data_type };
- SimpleTensor<T> cell_state{ output_cell_shape, data_type };
+ SimpleTensor<T> output_state_in{ output_shape, data_type };
+ SimpleTensor<T> cell_state_in{ output_cell_shape, data_type };
SimpleTensor<T> scratch{ scratch_shape, data_type };
+ SimpleTensor<T> output_state_out{ output_shape, data_type };
+ SimpleTensor<T> cell_state_out{ output_cell_shape, data_type };
SimpleTensor<T> output{ output_shape, data_type };
// Fill reference
@@ -293,8 +315,8 @@ protected:
fill(forget_gate_bias, 7);
fill(cell_bias, 8);
fill(output_gate_bias, 9);
- fill(output_state, 10);
- fill(cell_state, 11);
+ fill(output_state_in, 10);
+ fill(cell_state_in, 11);
fill(scratch, 12);
fill(input_to_input_w, 13);
fill(recurrent_to_input_w, 14);
@@ -310,12 +332,12 @@ protected:
// Compute forget_gate
SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape);
SimpleTensor<T> transposed_weights = reference::transpose(recurrent_to_forget_w);
- SimpleTensor<T> gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
+ SimpleTensor<T> gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f);
SimpleTensor<T> forget_gate = reference::arithmetic_addition(fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
- SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
forget_gate = reference::arithmetic_addition(forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
}
@@ -331,54 +353,57 @@ protected:
}
else
{
- SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
- transposed_weights = reference::transpose(recurrent_to_input_w);
- gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
- input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
- SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- input_gate = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
- input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
+ transposed_weights = reference::transpose(recurrent_to_input_w);
+ gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f);
+ input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
+ if(peephole_opt)
+ {
+ SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ input_gate = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
+ }
+ input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
}
// Compute cell_state
SimpleTensor<T> fully_connected_cell_state = reference::fully_connected_layer(input, input_to_cell_w, cell_bias, output_cell_shape);
transposed_weights = reference::transpose(recurrent_to_cell_w);
- gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
- SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication(cell_state, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- cell_state = reference::arithmetic_addition(fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
- cell_state = reference::activation_layer(cell_state, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
- cell_state = reference::pixel_wise_multiplication(cell_state, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- cell_state = reference::arithmetic_addition(cell_state, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
+ gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
+ SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ cell_state_out = reference::arithmetic_addition(fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
+ cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ cell_state_out = reference::arithmetic_addition(cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
if(cell_threshold != 0.f)
{
- cell_state = reference::activation_layer(cell_state, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
+ cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
}
// Compute output
SimpleTensor<T> fully_connected_output = reference::fully_connected_layer(input, input_to_output_w, output_gate_bias, output_cell_shape);
transposed_weights = reference::transpose(recurrent_to_output_w);
- gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
+ gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
output = reference::arithmetic_addition(fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
- pixelwise_mul = reference::pixel_wise_multiplication(cell_state, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ pixelwise_mul = reference::pixel_wise_multiplication(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
output = reference::arithmetic_addition(output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
}
output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
// Compute output state
- SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state, info);
- output_state = reference::pixel_wise_multiplication(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ SimpleTensor<T> cell_state_activation = reference::activation_layer(cell_state_out, info);
+ output_state_out = reference::pixel_wise_multiplication(output, cell_state_activation, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
if(projection_opt)
{
- SimpleTensor<T> fully_connected_projection = reference::fully_connected_layer(output_state, projection_w, projection_bias, output_cell_shape);
+ SimpleTensor<T> fully_connected_projection = reference::fully_connected_layer(output_state_out, projection_w, projection_bias, output_cell_shape);
if(projection_threshold != 0.f)
{
- output_state = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
+ output_state_out = reference::activation_layer(fully_connected_projection, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold));
}
}
- return output_state;
+ return output_state_out;
}
TensorType _target{};