diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-07-09 14:35:32 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:10 +0000 |
commit | 42a31723ebe79895c9bb2297a9c2ef22c01a6f26 (patch) | |
tree | 640e7727372f0543f966cc1fc8e0f075aab18cf9 /tests | |
parent | 1d2f267934cb617a2dede585c2e83523777136ab (diff) | |
download | ComputeLibrary-42a31723ebe79895c9bb2297a9c2ef22c01a6f26.tar.gz |
COMPMID-1124 : Fixes in CLLSTM layer
Change-Id: Ifc8e12c296d3ef2bf8e0f0bf1b87b7fd47a1fad7
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139248
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Ruomei Yan <ruomei.yan@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Diffstat (limited to 'tests')
-rw-r--r-- | tests/validation/fixtures/LSTMLayerFixture.h | 25 |
1 files changed, 11 insertions, 14 deletions
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h index bff2f375cd..3c1a560b24 100644 --- a/tests/validation/fixtures/LSTMLayerFixture.h +++ b/tests/validation/fixtures/LSTMLayerFixture.h @@ -315,9 +315,8 @@ protected: if(peephole_opt) { - transposed_weights = reference::transpose(cell_to_forget_w); - gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f); - forget_gate = reference::arithmetic_addition(forget_gate, gemm, data_type, ConvertPolicy::SATURATE); + SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + forget_gate = reference::arithmetic_addition(forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE); } forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); @@ -332,14 +331,13 @@ protected: } else { - SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape); - transposed_weights = reference::transpose(recurrent_to_input_w); - gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f); - input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); - transposed_weights = reference::transpose(cell_to_input_w); - gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f); - input_gate = reference::arithmetic_addition(input_gate, gemm, data_type, ConvertPolicy::SATURATE); - input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); + SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape); + transposed_weights = reference::transpose(recurrent_to_input_w); + gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f); + input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); + SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + input_gate = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE); + input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); } // Compute cell_state @@ -363,9 +361,8 @@ protected: output = reference::arithmetic_addition(fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE); if(peephole_opt) { - transposed_weights = reference::transpose(cell_to_output_w); - gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f); - output = reference::arithmetic_addition(output, gemm, data_type, ConvertPolicy::SATURATE); + pixelwise_mul = reference::pixel_wise_multiplication(cell_state, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + output = reference::arithmetic_addition(output, pixelwise_mul, data_type, ConvertPolicy::SATURATE); } output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); |