aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_random_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-09-27 14:59:43 +0100
committerEric Kunze <eric.kunze@arm.com>2024-03-19 20:28:57 +0000
commit0a6d1deef02f2bd76b3068d615565f20c46075a5 (patch)
treea90e8a17bb167e83419733d20c5e23f2c9c50af2 /verif/generator/tosa_random_gen.py
parent60dc48c4ddf30f2a76d4cfcf1b40ca57b6f3bf95 (diff)
downloadreference_model-0a6d1deef02f2bd76b3068d615565f20c46075a5.tar.gz
Updated build_tests to support different random generators
All generator functions now take RNG argument to allow different random number generators, rather than relying on global RNG Default behaviour is the same as before using global RNG Added stable random generation mode * shape rng based on operator, rank and datatype * arguments rng based on operator, shape and datatype * build operands and data rng based on op, shape, datatype and args Add optional stable RNG test generation to conformance_generator Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I5ee4ff85575a81177fd74ed1617e946bfa3a0769
Diffstat (limited to 'verif/generator/tosa_random_gen.py')
-rw-r--r--verif/generator/tosa_random_gen.py174
1 files changed, 174 insertions, 0 deletions
diff --git a/verif/generator/tosa_random_gen.py b/verif/generator/tosa_random_gen.py
new file mode 100644
index 0000000..ae8ae5c
--- /dev/null
+++ b/verif/generator/tosa_random_gen.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2024, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import hashlib
+import logging
+
+import generator.tosa_utils as gtu
+import numpy as np
+from tosa.DType import DType
+
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
+
+class TosaRandomGenerator(np.random.Generator):
+ """Equivalent to numpy.default_rng, with support for TOSA data types"""
+
+ def __init__(self, seed, restrict_range_by_type={}):
+ """Create random generator with TOSA type support.
+
+ seed: integer seed
+ restrict_range_by_type: see TosaHashRandomGenerator.__init__()
+ """
+ self._restrict_range_by_type = restrict_range_by_type
+ self._seed = int(seed)
+ self._bitgen = np.random.PCG64(self._seed)
+ super().__init__(self._bitgen)
+
+ @property
+ def seed(self):
+ return self._seed
+
+ @property
+ def hexSeed(self):
+ return hex(self._seed)
+
+ def dTypeRange(self, dtype, high_inclusive=False):
+ """Returns range tuple for given dtype.
+
+ dtype: DType
+ high_inclusive: True for inclusive high values
+ Returns: dtype value range boundaries tuple (low, high)
+ The high boundary is excluded in the range unless high_inclusive is True
+ """
+ if dtype in self._restrict_range_by_type:
+ rng = self._restrict_range_by_type[dtype]
+ elif dtype == DType.BOOL:
+ rng = (0, 2)
+ elif dtype == DType.UINT8:
+ rng = (0, 256)
+ elif dtype == DType.UINT16:
+ rng = (0, 65536)
+ elif dtype == DType.INT4:
+ # TOSA specific INT4 weight range from -7 to 7
+ rng = (-7, 8)
+ elif dtype == DType.INT8:
+ rng = (-128, 128)
+ elif dtype == DType.INT16:
+ rng = (-32768, 32768)
+ elif dtype == DType.INT32:
+ rng = (-(1 << 31), (1 << 31))
+ elif dtype == DType.INT48:
+ rng = (-(1 << 47), (1 << 47))
+ else:
+ # Float types and SHAPE should be in _restrict_range_by_type dict
+ raise Exception("Unknown supported dtype: {}".format(dtype))
+
+ if dtype in (DType.FP16, DType.BF16, DType.FP32, DType.FP8E4M3, DType.FP8E5M2):
+ # Floating point - range is always inclusive
+ return rng
+ else:
+ # Integer
+ if not high_inclusive:
+ # Exclusive high: low <= range < high
+ return rng
+ else:
+ # Inclusive range: low <= range <= high
+ return (rng[0], rng[1] - 1)
+
+ def randInt(self, low=0, high=256):
+ return np.int32(self.integers(low=low, high=high, size=1))[0]
+
+ def randNumberDType(self, dtype):
+ low, high = self.dTypeRange(dtype)
+
+ if dtype == DType.FP32:
+ return np.float32(self.uniform(low=low, high=high))
+ elif dtype == DType.FP16:
+ return np.float16(self.uniform(low=low, high=high))
+ elif dtype == DType.BF16:
+ rand_f32 = np.float32(self.uniform(low=low, high=high))
+ return gtu.vect_f32_to_bf16(rand_f32)
+ elif dtype == DType.FP8E4M3:
+ rand_f32 = np.float32(self.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e4m3(rand_f32)
+ elif dtype == DType.FP8E5M2:
+ rand_f32 = np.float32(self.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e5m2(rand_f32)
+ elif dtype == DType.BOOL:
+ return self.choice([False, True])
+ elif dtype == DType.INT48 or dtype == DType.SHAPE:
+ # Special size
+ return np.int64(self.integers(low, high, size=1))[0]
+
+ return np.int32(self.integers(low, high, size=1))[0]
+
+ def randTensor(self, shape, dtype, data_range=None):
+ if data_range is None:
+ low, high = self.dTypeRange(dtype)
+ else:
+ low, high = data_range
+
+ if dtype == DType.BOOL:
+ return np.bool_(self.choice(a=[False, True], size=shape))
+ elif dtype == DType.INT4:
+ return np.int8(self.integers(low=low, high=high, size=shape))
+ elif dtype == DType.INT8:
+ return np.int8(self.integers(low=low, high=high, size=shape))
+ elif dtype == DType.UINT8:
+ return np.uint8(self.integers(low=low, high=high, size=shape))
+ elif dtype == DType.INT16:
+ return np.int16(self.integers(low=low, high=high, size=shape))
+ elif dtype == DType.UINT16:
+ return np.uint16(self.integers(low=low, high=high, size=shape))
+ elif dtype in (DType.INT48, DType.SHAPE):
+ return np.int64(self.integers(low=low, high=high, size=shape))
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
+ f_tensor = self.uniform(low=low, high=high, size=shape)
+
+ if dtype == DType.FP16:
+ return np.float16(f_tensor)
+ else:
+ f32_tensor = np.float32(f_tensor)
+ if dtype == DType.BF16:
+ # Floor the last 16 bits of each f32 value
+ return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+ elif dtype == DType.FP8E4M3:
+ return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
+ elif dtype == DType.FP8E5M2:
+ return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
+ else:
+ return f32_tensor
+ else:
+ # All other integer types
+ return np.int32(self.integers(low=low, high=high, size=shape))
+
+
+class TosaHashRandomGenerator(TosaRandomGenerator):
+ """Hash seeded TOSA random number generator."""
+
+ def __init__(self, seed, seed_list, restrict_range_by_type={}):
+ """Create TOSA random generator seeding it with a hashable list.
+
+ seed: integer starting seed
+ seed_list: list of hashable items to add to starting seed
+ restrict_range_by_type: dictionary of DTypes with (low, high) range tuples
+ This must contain entries for SHAPE and all Floating Point data types.
+ NOTE: For integers, the high value must be the exclusive value
+ """
+ # Convert seed_list to strings
+ seed_strings_list = [str(s) for s in seed_list]
+ # Create a single string and create hash
+ self._seed_string = "__".join(seed_strings_list)
+ self._hash = hashlib.md5(bytes(self._seed_string, "utf-8"))
+ # Add the hash value to the given seed
+ seed += int(self._hash.hexdigest(), 16)
+
+ logger.debug(f"Seed={seed} Seed string={self._seed_string}")
+ super().__init__(seed, restrict_range_by_type)