diff options
Diffstat (limited to 'tests/validation/CPP')
-rw-r--r-- | tests/validation/CPP/GEMMLowp.cpp | 8 | ||||
-rw-r--r-- | tests/validation/CPP/GEMMLowp.h | 2 |
2 files changed, 5 insertions, 5 deletions
diff --git a/tests/validation/CPP/GEMMLowp.cpp b/tests/validation/CPP/GEMMLowp.cpp index 06926e631e..e1d76503cd 100644 --- a/tests/validation/CPP/GEMMLowp.cpp +++ b/tests/validation/CPP/GEMMLowp.cpp @@ -34,7 +34,7 @@ namespace validation { namespace reference { -SimpleTensor<uint32_t> gemmlowp(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, SimpleTensor<uint32_t> &c) +SimpleTensor<int32_t> gemmlowp(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, SimpleTensor<int32_t> &c) { ARM_COMPUTE_UNUSED(a); ARM_COMPUTE_UNUSED(b); @@ -99,15 +99,15 @@ SimpleTensor<T> gemmlowp(const SimpleTensor<T> &a, const SimpleTensor<T> &b, Sim for(int j = 0; j < cols; ++j) { const int32_t result = ((c_offset + acc[j]) * c_mult_int) >> out_shift; - c[j + i * cols] = static_cast<uint8_t>(std::min(255, std::max(0, result))); + c[j + i * cols] = static_cast<int8_t>(std::min(127, std::max(-128, result))); } } return c; } -template SimpleTensor<uint8_t> gemmlowp(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, SimpleTensor<uint8_t> &c, - int32_t a_offset, int32_t b_offset, int32_t c_offset, int32_t c_mult_int, int32_t out_shift); +template SimpleTensor<int8_t> gemmlowp(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, SimpleTensor<int8_t> &c, + int32_t a_offset, int32_t b_offset, int32_t c_offset, int32_t c_mult_int, int32_t out_shift); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/CPP/GEMMLowp.h b/tests/validation/CPP/GEMMLowp.h index 0428e9e34f..2f903f2fe2 100644 --- a/tests/validation/CPP/GEMMLowp.h +++ b/tests/validation/CPP/GEMMLowp.h @@ -35,7 +35,7 @@ namespace validation { namespace reference { -SimpleTensor<uint32_t> gemmlowp(const SimpleTensor<uint8_t> &a, const SimpleTensor<uint8_t> &b, SimpleTensor<uint32_t> &c); +SimpleTensor<int32_t> gemmlowp(const SimpleTensor<int8_t> &a, const SimpleTensor<int8_t> &b, SimpleTensor<int32_t> &c); template <typename T> SimpleTensor<T> gemmlowp(const SimpleTensor<T> &a, const SimpleTensor<T> &b, SimpleTensor<T> &c, |