aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/Helpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/Helpers.h')
-rw-r--r--tests/validation/Helpers.h50
1 files changed, 50 insertions, 0 deletions
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index 30959161bb..5fdd57f7ba 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -26,6 +26,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
+#include "tests/Globals.h"
#include "tests/validation/half.h"
#include <random>
@@ -129,6 +130,55 @@ std::pair<T, T> get_activation_layer_test_bounds(ActivationLayerInfo::Activation
return bounds;
}
+/** Fill mask with the corresponding given pattern.
+ *
+ * @param[in,out] mask Mask to be filled according to pattern
+ * @param[in] cols Columns (width) of mask
+ * @param[in] rows Rows (height) of mask
+ * @param[in] pattern Pattern to fill the mask according to
+ */
+inline void fill_mask_from_pattern(uint8_t *mask, int cols, int rows, MatrixPattern pattern)
+{
+ unsigned int v = 0;
+ std::mt19937 gen(library->seed());
+ std::bernoulli_distribution dist(0.5);
+
+ for(int r = 0; r < rows; ++r)
+ {
+ for(int c = 0; c < cols; ++c, ++v)
+ {
+ uint8_t val = 0;
+
+ switch(pattern)
+ {
+ case MatrixPattern::BOX:
+ val = 255;
+ break;
+ case MatrixPattern::CROSS:
+ val = ((r == (rows / 2)) || (c == (cols / 2))) ? 255 : 0;
+ break;
+ case MatrixPattern::DISK:
+ val = (((r - rows / 2.0f + 0.5f) * (r - rows / 2.0f + 0.5f)) / ((rows / 2.0f) * (rows / 2.0f)) + ((c - cols / 2.0f + 0.5f) * (c - cols / 2.0f + 0.5f)) / ((cols / 2.0f) *
+ (cols / 2.0f))) <= 1.0f ? 255 : 0;
+ break;
+ case MatrixPattern::OTHER:
+ val = (dist(gen) ? 0 : 255);
+ break;
+ default:
+ return;
+ }
+
+ mask[v] = val;
+ }
+ }
+
+ if(pattern == MatrixPattern::OTHER)
+ {
+ std::uniform_int_distribution<uint8_t> distribution_u8(0, ((cols * rows) - 1));
+ mask[distribution_u8(gen)] = 255;
+ }
+}
+
/** Calculate output tensor shape give a vector of input tensor to concatenate
*
* @param[in] input_shapes Shapes of the tensors to concatenate across depth.