aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/GEMMLowpAssemblyFixture.h
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2017-11-21 17:52:12 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:41:04 +0000
commitf3dfa279d536906dac3e618244b2c1d33e5ff28a (patch)
tree6fdf1bf52ad5ce8fc33e18d5a011633c592b7958 /tests/validation/fixtures/GEMMLowpAssemblyFixture.h
parentf202e50a8b89f143f74c393e33e0154817bd3c1d (diff)
downloadComputeLibrary-f3dfa279d536906dac3e618244b2c1d33e5ff28a.tar.gz
COMPMID-632 Assembly: Integrate gemmlowp assembly version
Integrate generic gemmlowp assembly version for u8. Change-Id: I17ed4494c25a132b2bac581febe1544e49b4f352 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110114 Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com> Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'tests/validation/fixtures/GEMMLowpAssemblyFixture.h')
-rw-r--r--tests/validation/fixtures/GEMMLowpAssemblyFixture.h42
1 files changed, 31 insertions, 11 deletions
diff --git a/tests/validation/fixtures/GEMMLowpAssemblyFixture.h b/tests/validation/fixtures/GEMMLowpAssemblyFixture.h
index a2587440fb..38e08f7992 100644
--- a/tests/validation/fixtures/GEMMLowpAssemblyFixture.h
+++ b/tests/validation/fixtures/GEMMLowpAssemblyFixture.h
@@ -42,7 +42,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T2>
class GEMMLowpAssemblyFixture : public framework::Fixture
{
public:
@@ -66,9 +66,11 @@ protected:
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c)
{
+ DataType dt_in = std::is_same<T2, int8_t>::value ? DataType::S8 : DataType::U8;
+
// Create tensors
- TensorType a = create_tensor<TensorType>(shape_a, DataType::S8, 1);
- TensorType b = create_tensor<TensorType>(shape_b, DataType::S8, 1);
+ TensorType a = create_tensor<TensorType>(shape_a, dt_in, 1);
+ TensorType b = create_tensor<TensorType>(shape_b, dt_in, 1);
TensorType c = create_tensor<TensorType>(shape_c, DataType::S32, 1);
// Create and configure function
@@ -89,8 +91,16 @@ protected:
ARM_COMPUTE_EXPECT(!c.info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
- fill(AccessorType(a), 0, -128, 127);
- fill(AccessorType(b), 1, -128, 127);
+ if(dt_in == DataType::S8)
+ {
+ fill(AccessorType(a), 0, -128, 127);
+ fill(AccessorType(b), 1, -128, 127);
+ }
+ else
+ {
+ fill(AccessorType(a), 0, 0, 128);
+ fill(AccessorType(b), 1, 0, 128);
+ }
fill(AccessorType(c), 2, 0, 0);
// Compute GEMM function
@@ -100,15 +110,25 @@ protected:
SimpleTensor<int32_t> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c)
{
+ DataType dt = std::is_same<T2, int8_t>::value ? DataType::S8 : DataType::U8;
+
// Create reference
- SimpleTensor<int8_t> a{ shape_a, DataType::S8, 1 };
- SimpleTensor<int8_t> b{ shape_b, DataType::S8, 1 };
+ SimpleTensor<T2> a{ shape_a, dt, 1 };
+ SimpleTensor<T2> b{ shape_b, dt, 1 };
// Fill reference
- fill(a, 0, -128, 127);
- fill(b, 1, -128, 127);
-
- return reference::gemmlowp(a, b);
+ if(dt == DataType::S8)
+ {
+ fill(a, 0, -128, 127);
+ fill(b, 1, -128, 127);
+ }
+ else
+ {
+ fill(a, 0, 0, 128);
+ fill(b, 1, 0, 128);
+ }
+
+ return reference::gemmlowp<int32_t, T2>(a, b);
}
TensorType _target{};