aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMLowpAssemblyFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpAssemblyFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMLowpAssemblyFixture.h42
1 files changed, 31 insertions, 11 deletions
diff --git a/tests/validation/fixtures/GEMMLowpAssemblyFixture.h b/tests/validation/fixtures/GEMMLowpAssemblyFixture.h
index a2587440fb..38e08f7992 100644
--- a/tests/validation/fixtures/GEMMLowpAssemblyFixture.h
+++ b/tests/validation/fixtures/GEMMLowpAssemblyFixture.h
@@ -42,7 +42,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T2>
class GEMMLowpAssemblyFixture : public framework::Fixture
{
public:
@@ -66,9 +66,11 @@ protected:
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c)
{
+ DataType dt_in = std::is_same<T2, int8_t>::value ? DataType::S8 : DataType::U8;
+
// Create tensors
- TensorType a = create_tensor<TensorType>(shape_a, DataType::S8, 1);
- TensorType b = create_tensor<TensorType>(shape_b, DataType::S8, 1);
+ TensorType a = create_tensor<TensorType>(shape_a, dt_in, 1);
+ TensorType b = create_tensor<TensorType>(shape_b, dt_in, 1);
TensorType c = create_tensor<TensorType>(shape_c, DataType::S32, 1);
// Create and configure function
@@ -89,8 +91,16 @@ protected:
ARM_COMPUTE_EXPECT(!c.info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
- fill(AccessorType(a), 0, -128, 127);
- fill(AccessorType(b), 1, -128, 127);
+ if(dt_in == DataType::S8)
+ {
+ fill(AccessorType(a), 0, -128, 127);
+ fill(AccessorType(b), 1, -128, 127);
+ }
+ else
+ {
+ fill(AccessorType(a), 0, 0, 128);
+ fill(AccessorType(b), 1, 0, 128);
+ }
fill(AccessorType(c), 2, 0, 0);
// Compute GEMM function
@@ -100,15 +110,25 @@ protected:
SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c)
{
+ DataType dt = std::is_same<T2, int8_t>::value ? DataType::S8 : DataType::U8;
+
// Create reference
- SimpleTensor<int8_t> a{ shape_a, DataType::S8, 1 };
- SimpleTensor<int8_t> b{ shape_b, DataType::S8, 1 };
+ SimpleTensor<T2> a{ shape_a, dt, 1 };
+ SimpleTensor<T2> b{ shape_b, dt, 1 };
// Fill reference
- fill(a, 0, -128, 127);
- fill(b, 1, -128, 127);
-
- return reference::gemmlowp(a, b);
+ if(dt == DataType::S8)
+ {
+ fill(a, 0, -128, 127);
+ fill(b, 1, -128, 127);
+ }
+ else
+ {
+ fill(a, 0, 0, 128);
+ fill(b, 1, 0, 128);
+ }
+
+ return reference::gemmlowp<int32_t, T2>(a, b);
}
TensorType _target{};