aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/LSTMLayerFixture.h
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-09 14:35:32 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:10 +0000
commit42a31723ebe79895c9bb2297a9c2ef22c01a6f26 (patch)
tree640e7727372f0543f966cc1fc8e0f075aab18cf9 /tests/validation/fixtures/LSTMLayerFixture.h
parent1d2f267934cb617a2dede585c2e83523777136ab (diff)
downloadComputeLibrary-42a31723ebe79895c9bb2297a9c2ef22c01a6f26.tar.gz
COMPMID-1124 : Fixes in CLLSTM layer
Change-Id: Ifc8e12c296d3ef2bf8e0f0bf1b87b7fd47a1fad7 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139248 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Ruomei Yan <ruomei.yan@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests/validation/fixtures/LSTMLayerFixture.h')
-rw-r--r--tests/validation/fixtures/LSTMLayerFixture.h25
1 files changed, 11 insertions, 14 deletions
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index bff2f375cd..3c1a560b24 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -315,9 +315,8 @@ protected:
if(peephole_opt)
{
- transposed_weights = reference::transpose(cell_to_forget_w);
- gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
- forget_gate = reference::arithmetic_addition(forget_gate, gemm, data_type, ConvertPolicy::SATURATE);
+ SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ forget_gate = reference::arithmetic_addition(forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
}
forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
@@ -332,14 +331,13 @@ protected:
}
else
{
- SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
- transposed_weights = reference::transpose(recurrent_to_input_w);
- gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
- input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
- transposed_weights = reference::transpose(cell_to_input_w);
- gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
- input_gate = reference::arithmetic_addition(input_gate, gemm, data_type, ConvertPolicy::SATURATE);
- input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
+ transposed_weights = reference::transpose(recurrent_to_input_w);
+ gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
+ input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
+ SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ input_gate = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
+ input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
}
// Compute cell_state
@@ -363,9 +361,8 @@ protected:
output = reference::arithmetic_addition(fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
- transposed_weights = reference::transpose(cell_to_output_w);
- gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
- output = reference::arithmetic_addition(output, gemm, data_type, ConvertPolicy::SATURATE);
+ pixelwise_mul = reference::pixel_wise_multiplication(cell_state, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ output = reference::arithmetic_addition(output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
}
output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));