diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-04-07 11:29:20 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-04-11 10:25:51 +0100 |
commit | 9a66abbd1da6547fd2cba1512d2f07fd1525de4d (patch) | |
tree | e06d3d728708f0ed4a57c369d4d1a48abdb5f607 /verif/generator/tosa_utils.py | |
parent | 7bebea8c086dc406d774e5a4419914748912089e (diff) | |
download | reference_model-9a66abbd1da6547fd2cba1512d2f07fd1525de4d.tar.gz |
Refactor verif/generator/tosa_test_gen.py into different files
Move all error & validation into tosa_error_if.py
Move all argument and tensor generation into tosa_arg_gen.py
Move utility functions into tosa_utils.py
Create new TosaTensorValuesGen class for specialising tensor
value generation.
Change-Id: Ib9ac65e2308b14471a567c6f11d775c76585bc5b
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r-- | verif/generator/tosa_utils.py | 81 |
1 files changed, 81 insertions, 0 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py new file mode 100644 index 0000000..ca115a2 --- /dev/null +++ b/verif/generator/tosa_utils.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021-2022, ARM Limited. +# SPDX-License-Identifier: Apache-2.0 +from tosa.DType import DType + + +def valueToName(item, value): + """Get the name of an attribute with the given value. + + This convenience function is needed to print meaningful names for + the values of the tosa.Op.Op and tosa.DType.DType classes. + This would not be necessary if they were subclasses of Enum, or + IntEnum, which, sadly, they are not. + + Args: + item: The class, or object, to find the value in + value: The value to find + + Example, to get the name of a DType value: + + name = valueToName(DType, DType.INT8) # returns 'INT8' + name = valueToName(DType, 4) # returns 'INT8' + + Returns: + The name of the first attribute found with a matching value, + + Raises: + ValueError if the value is not found + """ + for attr in dir(item): + if getattr(item, attr) == value: + return attr + raise ValueError(f"value ({value}) not found") + + +def allDTypes(*, excludes=None): + """Get a set of all DType values, optionally excluding some values. + + This convenience function is needed to provide a sequence of DType values. + This would be much easier if DType was a subclass of Enum, or IntEnum, + as we could then iterate over the values directly, instead of using + dir() to find the attributes and then check if they are what we want. + + Args: + excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL]) + + Returns: + A set of DType values + """ + excludes = () if not excludes else excludes + return { + getattr(DType, t) + for t in dir(DType) + if not callable(getattr(DType, t)) + and not t.startswith("__") + and getattr(DType, t) not in excludes + } + + +def usableDTypes(*, excludes=None): + """Get a set of usable DType values, optionally excluding some values. + + Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes + specified by the caller, as the serializer lib does not support them. + If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead. + + Args: + excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL]) + + Returns: + A set of DType values + """ + omit = {DType.UNKNOWN, DType.UINT8} + omit.update(excludes if excludes else ()) + return allDTypes(excludes=omit) + + +def product(shape): + value = 1 + for n in shape: + value *= n + return value |