author    Georgios Pinitas <georgios.pinitas@arm.com>  2018-09-10 15:07:45 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:54:54 +0000
commit    cbf39c63a6eb89a2c80b2338afc374081803d79d (patch)
tree      afe1c55d5e3bbf0e111ec0dce9a564304844a55f /tests/validation/fixtures/LSTMLayerFixture.h
parent    d7647d4ebd0f0b5253b7f31ffcd48a851ba62947 (diff)
download  ComputeLibrary-cbf39c63a6eb89a2c80b2338afc374081803d79d.tar.gz
COMPMID-1566: Add broadcast to CLArithmeticSubtraction
Change-Id: I05d21f9a92013ecfd1128d12cf1561cfd6e5c5e9
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/147983
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/fixtures/LSTMLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/LSTMLayerFixture.h | 21
1 file changed, 10 insertions(+), 11 deletions(-)
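
Note on the change: the commit adds broadcast support to CLArithmeticSubtraction, and this fixture only picks up the accompanying rename of the reference helpers, calling reference::arithmetic_operation(ArithmeticOperation::ADD/SUB, ...) instead of the separate arithmetic_addition/arithmetic_subtraction functions. Below is a minimal sketch (not part of this patch) of what the broadcast now permits at the CL level, assuming the configure(input1, input2, output, ConvertPolicy) signature of this release; the shapes and setup are purely illustrative.

    // Sketch: subtracting a per-column vector from a 2D tensor, relying on
    // the broadcast added by COMPMID-1566. Shapes are illustrative only.
    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLArithmeticSubtraction.h"

    using namespace arm_compute;

    int main()
    {
        CLScheduler::get().default_init();

        // x is 16x8; y has a size-1 second dimension that is broadcast over x's rows.
        CLTensor x, y, out;
        x.allocator()->init(TensorInfo(TensorShape(16U, 8U), 1, DataType::F32));
        y.allocator()->init(TensorInfo(TensorShape(16U, 1U), 1, DataType::F32));
        out.allocator()->init(TensorInfo(TensorShape(16U, 8U), 1, DataType::F32));

        // Mismatching input shapes are accepted as long as every differing
        // dimension is 1 in one of the inputs.
        CLArithmeticSubtraction sub;
        sub.configure(&x, &y, &out, ConvertPolicy::SATURATE);

        x.allocator()->allocate();
        y.allocator()->allocate();
        out.allocator()->allocate();

        sub.run();
        CLScheduler::get().sync();
        return 0;
    }

The ConvertPolicy::SATURATE argument mirrors the policy the fixture passes to the reference helpers in the diff below.
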
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index 20df855242..bc892bfecf 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -28,8 +28,7 @@
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/reference/ActivationLayer.h"
-#include "tests/validation/reference/ArithmeticAddition.h"
-#include "tests/validation/reference/ArithmeticSubtraction.h"
+#include "tests/validation/reference/ArithmeticOperations.h"
#include "tests/validation/reference/FullyConnectedLayer.h"
#include "tests/validation/reference/GEMM.h"
#include "tests/validation/reference/PixelWiseMultiplication.h"
@@ -333,12 +332,12 @@ protected:
SimpleTensor<T> fully_connected_forget = reference::fully_connected_layer(input, input_to_forget_w, forget_gate_bias, output_cell_shape);
SimpleTensor<T> transposed_weights = reference::transpose(recurrent_to_forget_w);
SimpleTensor<T> gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f);
- SimpleTensor<T> forget_gate = reference::arithmetic_addition(fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE);
+ SimpleTensor<T> forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_forget, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
SimpleTensor<T> pixelwise_mul_forget_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_forget_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- forget_gate = reference::arithmetic_addition(forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
+ forget_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, forget_gate, pixelwise_mul_forget_gate, data_type, ConvertPolicy::SATURATE);
}
forget_gate = reference::activation_layer(forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
@@ -349,18 +348,18 @@ protected:
{
SimpleTensor<T> ones{ cell_bias_shape, data_type };
fill_custom_val(ones, 1.f, 0);
- input_gate = reference::arithmetic_subtraction<T, T, T>(ones, forget_gate, data_type, ConvertPolicy::SATURATE);
+ input_gate = reference::arithmetic_operation<T>(reference::ArithmeticOperation::SUB, ones, forget_gate, data_type, ConvertPolicy::SATURATE);
}
else
{
SimpleTensor<T> fully_connected_input = reference::fully_connected_layer(input, input_to_input_w, input_gate_bias, output_cell_shape);
transposed_weights = reference::transpose(recurrent_to_input_w);
gemm = reference::gemm(output_state_in, transposed_weights, cell_state_in, 1.f, 0.f);
- input_gate = reference::arithmetic_addition(fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
+ input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_input, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
SimpleTensor<T> pixelwise_mul_input_gate = reference::pixel_wise_multiplication(cell_state_in, cell_to_input_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- input_gate = reference::arithmetic_addition(input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
+ input_gate = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, input_gate, pixelwise_mul_input_gate, data_type, ConvertPolicy::SATURATE);
}
input_gate = reference::activation_layer(input_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
}
@@ -370,10 +369,10 @@ protected:
transposed_weights = reference::transpose(recurrent_to_cell_w);
gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
SimpleTensor<T> pixelwise_mul = reference::pixel_wise_multiplication(cell_state_in, forget_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- cell_state_out = reference::arithmetic_addition(fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
+ cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_cell_state, gemm, data_type, ConvertPolicy::SATURATE);
cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
cell_state_out = reference::pixel_wise_multiplication(cell_state_out, input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- cell_state_out = reference::arithmetic_addition(cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
+ cell_state_out = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, cell_state_out, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
if(cell_threshold != 0.f)
{
cell_state_out = reference::activation_layer(cell_state_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, cell_threshold));
@@ -383,11 +382,11 @@ protected:
SimpleTensor<T> fully_connected_output = reference::fully_connected_layer(input, input_to_output_w, output_gate_bias, output_cell_shape);
transposed_weights = reference::transpose(recurrent_to_output_w);
gemm = reference::gemm(output_state_in, transposed_weights, cell_state_out, 1.f, 0.f);
- output = reference::arithmetic_addition(fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
+ output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, fully_connected_output, gemm, data_type, ConvertPolicy::SATURATE);
if(peephole_opt)
{
pixelwise_mul = reference::pixel_wise_multiplication(cell_state_out, cell_to_output_w, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_EVEN);
- output = reference::arithmetic_addition(output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
+ output = reference::arithmetic_operation(reference::ArithmeticOperation::ADD, output, pixelwise_mul, data_type, ConvertPolicy::SATURATE);
}
output = reference::activation_layer(output, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));