diff options
Diffstat (limited to 'tests/validation/Helpers.h')
-rw-r--r-- | tests/validation/Helpers.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h index cbaea4b894..c0c5865e56 100644 --- a/tests/validation/Helpers.h +++ b/tests/validation/Helpers.h @@ -25,6 +25,7 @@ #define __ARM_COMPUTE_TEST_VALIDATION_HELPERS_H__ #include "Types.h" +#include "ValidationUserConfiguration.h" #include <type_traits> #include <utility> @@ -117,6 +118,56 @@ std::pair<T, T> get_batchnormalization_layer_test_bounds(int fixed_point_positio return bounds; } + +/** Fill mask with the corresponding given pattern. + * + * @param[in,out] mask Mask to be filled according to pattern + * @param[in] cols Columns (width) of mask + * @param[in] rows Rows (height) of mask + * @param[in] pattern Pattern to fill the mask according to + */ +inline void fill_mask_from_pattern(uint8_t *mask, int cols, int rows, MatrixPattern pattern) +{ + unsigned int v = 0; + std::mt19937 gen(user_config.seed.get()); + std::bernoulli_distribution dist(0.5); + + for(int r = 0; r < rows; ++r) + { + for(int c = 0; c < cols; ++c, ++v) + { + uint8_t val = 0; + + switch(pattern) + { + case MatrixPattern::BOX: + val = 255; + break; + case MatrixPattern::CROSS: + val = ((r == (rows / 2)) || (c == (cols / 2))) ? 255 : 0; + break; + case MatrixPattern::DISK: + val = (((r - rows / 2.0f + 0.5f) * (r - rows / 2.0f + 0.5f)) / ((rows / 2.0f) * (rows / 2.0f)) + ((c - cols / 2.0f + 0.5f) * (c - cols / 2.0f + 0.5f)) / ((cols / 2.0f) * + (cols / 2.0f))) <= 1.0f ? 255 : 0; + break; + case MatrixPattern::OTHER: + val = (dist(gen) ? 0 : 255); + break; + default: + return; + } + + mask[v] = val; + } + } + + if(pattern == MatrixPattern::OTHER) + { + std::uniform_int_distribution<uint8_t> distribution_u8(0, ((cols * rows) - 1)); + mask[distribution_u8(gen)] = 255; + } +} + } // namespace validation } // namespace test } // namespace arm_compute |