aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/fixtures/LSTMLayerFixture.h25
1 files changed, 11 insertions, 14 deletions
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index bff2f375cd..3c1a560b24 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -315,9 +315,8 @@ protected:
if(peephole_opt)
{
- transposed_weights = reference::transpose(cell_to_forget_w);
- gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
- forget_gate = reference::arithmetic_addition(forget_gate, gemm, data_type, ConvertPolicy::SATURATE);
+ SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ forget_gate = reference::arithmetic_addition(forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
}
forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
@@ -332,14 +331,13 @@ protected:
}
else
{
- SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
- transposed_weights = reference::transpose(recurrent_to_input_w);
- gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
- input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
- transposed_weights = reference::transpose(cell_to_input_w);
- gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
- input_gate = reference::arithmetic_addition(input_gate, gemm, data_type, ConvertPolicy::SATURATE);
- input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+ SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
+ transposed_weights = reference::transpose(recurrent_to_input_w);
+ gemm = reference::gemm(output_state, transposed_weights, cell_state, 1.f, 0.f);
+ input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
+ SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ input_gate = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
+ input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
}
// Compute cell_state
@@ -363,9 +361,8 @@ protected:
output = reference::arithmetic_addition(fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
- transposed_weights = reference::transpose(cell_to_output_w);
- gemm = reference::gemm(cell_state, transposed_weights, gemm_out, 1.f, 0.f);
- output = reference::arithmetic_addition(output, gemm, data_type, ConvertPolicy::SATURATE);
+ pixelwise_mul = reference::pixel_wise_multiplication(cell_state, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
+ output = reference::arithmetic_addition(output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
}
output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));