diff options
Diffstat (limited to 'tests/validation')
-rw-r--r-- | tests/validation/fixtures/ScaleFixture.h | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h index 6fa810aa96..476985e066 100644 --- a/tests/validation/fixtures/ScaleFixture.h +++ b/tests/validation/fixtures/ScaleFixture.h @@ -46,15 +46,22 @@ public: template <typename...> void setup(TensorShape shape, DataType data_type, InterpolationPolicy policy, BorderMode border_mode) { + constexpr float max_width = 8192.0f; + constexpr float max_height = 6384.0f; + _shape = shape; _policy = policy; _border_mode = border_mode; _data_type = data_type; - std::mt19937 generator(library->seed()); - std::uniform_real_distribution<float> distribution_float(0.25, 4); - const float scale_x = distribution_float(generator); - const float scale_y = distribution_float(generator); + std::mt19937 generator(library->seed()); + std::uniform_real_distribution<float> distribution_float(0.25, 3); + float scale_x = distribution_float(generator); + float scale_y = distribution_float(generator); + + scale_x = ((shape.x() * scale_x) > max_width) ? (max_width / shape.x()) : scale_x; + scale_y = ((shape.y() * scale_y) > max_height) ? (max_height / shape.y()) : scale_y; + std::uniform_int_distribution<uint8_t> distribution_u8(0, 255); T constant_border_value = static_cast<T>(distribution_u8(generator)); |