diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/LSTMLayerFixture.h | 25 |
1 files changed, 11 insertions, 14 deletions
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h index bff2f375cd..3c1a560b24 100644 --- a/tests/validation/fixtures/LSTMLayerFixture.h +++ b/tests/validation/fixtures/LSTMLayerFixture.h @@ -315,9 +315,8 @@ protected: if(peephole_opt) { - transposed_weights = reference::transpose(cell_to_forget_w); - gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f); - forget_gate = reference::arithmetic_addition(forget_gate, gemm, data_type, ConvertPolicy::SATURATE); + SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + forget_gate = reference::arithmetic_addition(forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE); } forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); @@ -332,14 +331,13 @@ protected: } else { - SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape); - transposed_weights = reference::transpose(recurrent_to_input_w); - gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f); - input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); - transposed_weights = reference::transpose(cell_to_input_w); - gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f); - input_gate = reference::arithmetic_addition(input_gate, gemm, data_type, ConvertPolicy::SATURATE); - input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); + SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape); + transposed_weights = reference::transpose(recurrent_to_input_w); + gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f); + input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); + SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + input_gate = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE); + input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); } // Compute cell_state @@ -363,9 +361,8 @@ protected: output = reference::arithmetic_addition(fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE); if(peephole_opt) { - transposed_weights = reference::transpose(cell_to_output_w); - gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f); - output = reference::arithmetic_addition(output, gemm, data_type, ConvertPolicy::SATURATE); + pixelwise_mul = reference::pixel_wise_multiplication(cell_state, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); + output = reference::arithmetic_addition(output, pixelwise_mul, data_type, ConvertPolicy::SATURATE); } output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); |