aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/GEMMFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMFixture.h24
1 files changed, 20 insertions, 4 deletions
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 854cc4a22b..bf919c9b09 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -667,7 +667,7 @@ protected:
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false>
class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture
{
public:
@@ -734,6 +734,7 @@ protected:
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = broadcast_bias;
kernel_info.activation_info = act_info;
+ kernel_info.fp_mixed_precision = fp_mixed_precision;
// The output tensor will be auto-initialized within the function
@@ -807,14 +808,21 @@ protected:
}
}
- return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ if(fp_mixed_precision)
+ {
+ return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
+ else
+ {
+ return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
}
TensorType _target{};
SimpleTensor<T> _reference{};
};
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false>
class GEMMMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture
{
public:
@@ -879,6 +887,7 @@ protected:
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = true;
kernel_info.activation_info = act_info;
+ kernel_info.fp_mixed_precision = fp_mixed_precision;
// The output tensor will be auto-initialized within the function
@@ -951,7 +960,14 @@ protected:
memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
}
- return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ if(fp_mixed_precision)
+ {
+ return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
+ else
+ {
+ return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+ }
}
TensorType _target{};