about | summary | refs | log | tree | commit | diff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-09-27 09:23:15 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2019-09-30 08:28:43 +0000
commit0c17aa25a4f7bc812707150b91930f0cf8e75294 (patch)
tree29088e00bd7ba443dc122ad3436b0a4ef369a102 /tests/validation/fixtures
parent40958adf8bad8fd9fefe591ee55a381f7bbb6fea (diff)
downloadComputeLibrary-0c17aa25a4f7bc812707150b91930f0cf8e75294.tar.gz
COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16
Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1991
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- tests/validation/fixtures/GEMMFixture.h | 24 ++++++++++++++++----
1 file changed, 20 insertions(+), 4 deletions(-)
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 854cc4a22b..bf919c9b09 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -667,7 +667,7 @@ protected:
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false>
class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture
{
public:
@@ -734,6 +734,7 @@ protected:
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = broadcast_bias;
kernel_info.activation_info = act_info;
+ kernel_info.fp_mixed_precision = fp_mixed_precision;
// The output tensor will be auto-initialized within the function
@@ -807,14 +808,21 @@ protected:
}
}
- return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ if(fp_mixed_precision)
+ {
+ return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
+ else
+ {
+ return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false>
class GEMMMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture
{
public:
@@ -879,6 +887,7 @@ protected:
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = true;
kernel_info.activation_info = act_info;
+ kernel_info.fp_mixed_precision = fp_mixed_precision;
// The output tensor will be auto-initialized within the function
@@ -951,7 +960,14 @@ protected:
memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
}
- return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ if(fp_mixed_precision)
+ {
+ return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
+ else
+ {
+ return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
}
TensorType _target{};