diff options
Diffstat (limited to 'tests/validation/fixtures/ScatterLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/ScatterLayerFixture.h | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/tests/validation/fixtures/ScatterLayerFixture.h b/tests/validation/fixtures/ScatterLayerFixture.h index 4fb2d7f127..35e6b647f3 100644 --- a/tests/validation/fixtures/ScatterLayerFixture.h +++ b/tests/validation/fixtures/ScatterLayerFixture.h @@ -63,13 +63,30 @@ public: protected: template <typename U> - void fill(U &&tensor, int i, float lo = -10.f, float hi = 10.f) + void fill(U &&tensor, int i) { switch(tensor.data_type()) { case DataType::F32: + case DataType::F16: { - std::uniform_real_distribution<float> distribution(lo, hi); + std::uniform_real_distribution<float> distribution(-10.f, 10.f); + library->fill(tensor, distribution, i); + break; + } + case DataType::S32: + case DataType::S16: + case DataType::S8: + { + std::uniform_int_distribution<int32_t> distribution(-100, 100); + library->fill(tensor, distribution, i); + break; + } + case DataType::U32: + case DataType::U16: + case DataType::U8: + { + std::uniform_int_distribution<uint32_t> distribution(0, 200); library->fill(tensor, distribution, i); break; } |