aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_utils.py')
-rw-r--r--verif/generator/tosa_utils.py59
1 files changed, 46 insertions, 13 deletions
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 3cd0370..75a0df5 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -1,7 +1,8 @@
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import struct
import sys
+from enum import IntEnum
import numpy as np
from tosa.DType import DType
@@ -9,22 +10,54 @@ from tosa.DType import DType
# Maximum dimension size for output and inputs for RESIZE
MAX_RESIZE_DIMENSION = 16384
+# Data type information dictionary
+# - str: filename abbreviation
+# - width: number of bytes needed for type
+# - json: JSON type string
DTYPE_ATTRIBUTES = {
- DType.BOOL: {"str": "b", "width": 1},
- DType.INT4: {"str": "i4", "width": 4},
- DType.INT8: {"str": "i8", "width": 8},
- DType.UINT8: {"str": "u8", "width": 8},
- DType.INT16: {"str": "i16", "width": 16},
- DType.UINT16: {"str": "u16", "width": 16},
- DType.INT32: {"str": "i32", "width": 32},
- DType.INT48: {"str": "i48", "width": 48},
- DType.SHAPE: {"str": "i64", "width": 64},
- DType.FP16: {"str": "f16", "width": 16},
- DType.BF16: {"str": "bf16", "width": 16},
- DType.FP32: {"str": "f32", "width": 32},
+ DType.BOOL: {"str": "b", "width": 1, "json": "BOOL"},
+ DType.INT4: {"str": "i4", "width": 4, "json": "INT4"},
+ DType.INT8: {"str": "i8", "width": 8, "json": "INT8"},
+ DType.UINT8: {"str": "u8", "width": 8, "json": "UINT8"},
+ DType.INT16: {"str": "i16", "width": 16, "json": "INT16"},
+ DType.UINT16: {"str": "u16", "width": 16, "json": "UINT16"},
+ DType.INT32: {"str": "i32", "width": 32, "json": "INT32"},
+ DType.INT48: {"str": "i48", "width": 48, "json": "INT48"},
+ DType.SHAPE: {"str": "s", "width": 64, "json": "SHAPE"},
+ DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
+ DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
+ DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
}
+class ComplianceMode(IntEnum):
+ """Compliance mode types."""
+
+ EXACT = 0
+ DOT_PRODUCT = 1
+ ULP = 2
+ FP_SPECIAL = 3
+ REDUCE_PRODUCT = 4
+
+
+class DataGenType(IntEnum):
+ """Data generator types."""
+
+ PSEUDO_RANDOM = 0
+ DOT_PRODUCT = 1
+ OP_BOUNDARY = 2
+ OP_FULLSET = 3
+ OP_SPECIAL = 4
+
+
+# Additional (optional) data for dot product data generator
+DG_DOT_PRODUCT_OPTIONAL_INFO = ("acc_type", "kernel", "axis")
+
+
+def dtypeIsFloat(dtype):
+ return dtype in (DType.FP16, DType.BF16, DType.FP32)
+
+
def valueToName(item, value):
"""Get the name of an attribute with the given value.