diff options
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpAssemblyFixture.h')
-rw-r--r-- | tests/validation/fixtures/GEMMLowpAssemblyFixture.h | 42 |
1 files changed, 31 insertions, 11 deletions
diff --git a/tests/validation/fixtures/GEMMLowpAssemblyFixture.h b/tests/validation/fixtures/GEMMLowpAssemblyFixture.h index a2587440fb..38e08f7992 100644 --- a/tests/validation/fixtures/GEMMLowpAssemblyFixture.h +++ b/tests/validation/fixtures/GEMMLowpAssemblyFixture.h @@ -42,7 +42,7 @@ namespace test { namespace validation { -template <typename TensorType, typename AccessorType, typename FunctionType> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T2> class GEMMLowpAssemblyFixture : public framework::Fixture { public: @@ -66,9 +66,11 @@ protected: TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c) { + DataType dt_in = std::is_same<T2, int8_t>::value ? DataType::S8 : DataType::U8; + // Create tensors - TensorType a = create_tensor<TensorType>(shape_a, DataType::S8, 1); - TensorType b = create_tensor<TensorType>(shape_b, DataType::S8, 1); + TensorType a = create_tensor<TensorType>(shape_a, dt_in, 1); + TensorType b = create_tensor<TensorType>(shape_b, dt_in, 1); TensorType c = create_tensor<TensorType>(shape_c, DataType::S32, 1); // Create and configure function @@ -89,8 +91,16 @@ protected: ARM_COMPUTE_EXPECT(!c.info()->is_resizable(), framework::LogLevel::ERRORS); // Fill tensors - fill(AccessorType(a), 0, -128, 127); - fill(AccessorType(b), 1, -128, 127); + if(dt_in == DataType::S8) + { + fill(AccessorType(a), 0, -128, 127); + fill(AccessorType(b), 1, -128, 127); + } + else + { + fill(AccessorType(a), 0, 0, 128); + fill(AccessorType(b), 1, 0, 128); + } fill(AccessorType(c), 2, 0, 0); // Compute GEMM function @@ -100,15 +110,25 @@ protected: SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c) { + DataType dt = std::is_same<T2, int8_t>::value ? DataType::S8 : DataType::U8; + // Create reference - SimpleTensor<int8_t> a{ shape_a, DataType::S8, 1 }; - SimpleTensor<int8_t> b{ shape_b, DataType::S8, 1 }; + SimpleTensor<T2> a{ shape_a, dt, 1 }; + SimpleTensor<T2> b{ shape_b, dt, 1 }; // Fill reference - fill(a, 0, -128, 127); - fill(b, 1, -128, 127); - - return reference::gemmlowp(a, b); + if(dt == DataType::S8) + { + fill(a, 0, -128, 127); + fill(b, 1, -128, 127); + } + else + { + fill(a, 0, 0, 128); + fill(b, 1, 0, 128); + } + + return reference::gemmlowp<int32_t, T2>(a, b); } TensorType _target{}; |