aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r--verif/generator/tosa_utils.py81
1 files changed, 81 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
new file mode 100644
index 0000000..ca115a2
--- /dev/null
+++ b/verif/generator/tosa_utils.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2021-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+from tosa.DType import DType
+
+
+def valueToName(item, value):
+ """Get the name of an attribute with the given value.
+
+ This convenience function is needed to print meaningful names for
+ the values of the tosa.Op.Op and tosa.DType.DType classes.
+ This would not be necessary if they were subclasses of Enum, or
+ IntEnum, which, sadly, they are not.
+
+ Args:
+ item: The class, or object, to find the value in
+ value: The value to find
+
+ Example, to get the name of a DType value:
+
+ name = valueToName(DType, DType.INT8) # returns 'INT8'
+ name = valueToName(DType, 4) # returns 'INT8'
+
+ Returns:
+ The name of the first attribute found with a matching value,
+
+ Raises:
+ ValueError if the value is not found
+ """
+ for attr in dir(item):
+ if getattr(item, attr) == value:
+ return attr
+ raise ValueError(f"value ({value}) not found")
+
+
+def allDTypes(*, excludes=None):
+ """Get a set of all DType values, optionally excluding some values.
+
+ This convenience function is needed to provide a sequence of DType values.
+ This would be much easier if DType was a subclass of Enum, or IntEnum,
+ as we could then iterate over the values directly, instead of using
+ dir() to find the attributes and then check if they are what we want.
+
+ Args:
+ excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL])
+
+ Returns:
+ A set of DType values
+ """
+ excludes = () if not excludes else excludes
+ return {
+ getattr(DType, t)
+ for t in dir(DType)
+ if not callable(getattr(DType, t))
+ and not t.startswith("__")
+ and getattr(DType, t) not in excludes
+ }
+
+
+def usableDTypes(*, excludes=None):
+ """Get a set of usable DType values, optionally excluding some values.
+
+ Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
+ specified by the caller, as the serializer lib does not support them.
+ If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.
+
+ Args:
+ excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])
+
+ Returns:
+ A set of DType values
+ """
+ omit = {DType.UNKNOWN, DType.UINT8}
+ omit.update(excludes if excludes else ())
+ return allDTypes(excludes=omit)
+
+
+def product(shape):
+ value = 1
+ for n in shape:
+ value *= n
+ return value