diff options
Diffstat (limited to 'tests/validation/fixtures/LSTMLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/LSTMLayerFixture.h | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h index 20df855242..bc892bfecf 100644 --- a/tests/validation/fixtures/LSTMLayerFixture.h +++ b/tests/validation/fixtures/LSTMLayerFixture.h @@ -28,8 +28,7 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" #include "tests/validation/reference/ActivationLayer.h" -#include "tests/validation/reference/ArithmeticAddition.h" -#include "tests/validation/reference/ArithmeticSubtraction.h" +#include "tests/validation/reference/ArithmeticOperations.h" #include "tests/validation/reference/FullyConnectedLayer.h" #include "tests/validation/reference/GEMM.h" #include "tests/validation/reference/PixelWiseMultiplication.h" @@ -333,12 +332,12 @@ protected: SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape); SimpleTensor<T> transposed_weights = reference::transpose(recurrent_to_forget_w); SimpleTensor<T> gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f); - SimpleTensor<T> forget_gate = reference::arithmetic_addition(fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE); + SimpleTensor<T> forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE); if(peephole_opt) { SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); - forget_gate = reference::arithmetic_addition(forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE); + forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE); } forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); @@ -349,18 +348,18 @@ protected: { SimpleTensor<T> ones{ cell_bias_shape, data_type }; fill_custom_val(ones, 1.f, 0); - input_gate = reference::arithmetic_subtraction<T, T, T>(ones, forget_gate, data_type, ConvertPolicy::SATURATE); + input_gate = reference::arithmetic_operation<T>(reference::ArithmeticOperation::SUB, ones, forget_gate, data_type, ConvertPolicy::SATURATE); } else { SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape); transposed_weights = reference::transpose(recurrent_to_input_w); gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f); - input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); + input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE); if(peephole_opt) { SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); - input_gate = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE); + input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE); } input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); } @@ -370,10 +369,10 @@ protected: transposed_weights = reference::transpose(recurrent_to_cell_w); gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f); SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); - cell_state_out = reference::arithmetic_addition(fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE); + cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE); cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); - cell_state_out = reference::arithmetic_addition(cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE); + cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE); if(cell_threshold != 0.f) { cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold)); @@ -383,11 +382,11 @@ protected: SimpleTensor<T> fully_connected_output = reference::fully_connected_layer(input, input_to_output_w, output_gate_bias, output_cell_shape); transposed_weights = reference::transpose(recurrent_to_output_w); gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f); - output = reference::arithmetic_addition(fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE); + output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE); if(peephole_opt) { pixelwise_mul = reference::pixel_wise_multiplication(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN); - output = reference::arithmetic_addition(output, pixelwise_mul, data_type, ConvertPolicy::SATURATE); + output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE); } output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC)); |