diff options
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py new file mode 100644 index 0000000..ca115a2 --- /dev/null +++ b/verif/generator/tosa_utils.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021-2022, ARM Limited. +# SPDX-License-Identifier: Apache-2.0 +from tosa.DType import DType + + +def valueToName(item, value): + """Get the name of an attribute with the given value. + + This convenience function is needed to print meaningful names for + the values of the tosa.Op.Op and tosa.DType.DType classes. + This would not be necessary if they were subclasses of Enum, or + IntEnum, which, sadly, they are not. + + Args: + item: The class, or object, to find the value in + value: The value to find + + Example, to get the name of a DType value: + + name = valueToName(DType, DType.INT8) # returns 'INT8' + name = valueToName(DType, 4) # returns 'INT8' + + Returns: + The name of the first attribute found with a matching value, + + Raises: + ValueError if the value is not found + """ + for attr in dir(item): + if getattr(item, attr) == value: + return attr + raise ValueError(f"value ({value}) not found") + + +def allDTypes(*, excludes=None): + """Get a set of all DType values, optionally excluding some values. + + This convenience function is needed to provide a sequence of DType values. + This would be much easier if DType was a subclass of Enum, or IntEnum, + as we could then iterate over the values directly, instead of using + dir() to find the attributes and then check if they are what we want. + + Args: + excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL]) + + Returns: + A set of DType values + """ + excludes = () if not excludes else excludes + return { + getattr(DType, t) + for t in dir(DType) + if not callable(getattr(DType, t)) + and not t.startswith("__") + and getattr(DType, t) not in excludes + } + + +def usableDTypes(*, excludes=None): + """Get a set of usable DType values, optionally excluding some values. + + Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes + specified by the caller, as the serializer lib does not support them. + If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead. + + Args: + excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]) + + Returns: + A set of DType values + """ + omit = {DType.UNKNOWN, DType.UINT8} + omit.update(excludes if excludes else ()) + return allDTypes(excludes=omit) + + +def product(shape): + value = 1 + for n in shape: + value *= n + return value |