aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
Diffstat (limited to 'verif')
-rw-r--r--verif/generator/datagenerator.py24
-rw-r--r--verif/generator/tosa_arg_gen.py4
-rw-r--r--verif/generator/tosa_utils.py2
3 files changed, 21 insertions, 9 deletions
diff --git a/verif/generator/datagenerator.py b/verif/generator/datagenerator.py
index 0d59084..9de421b 100644
--- a/verif/generator/datagenerator.py
+++ b/verif/generator/datagenerator.py
@@ -68,19 +68,33 @@ class GenerateLibrary:
def _create_buffer(self, dtype: str, shape: tuple):
"""Helper to create a buffer of the required type."""
- size = 1
- for dim in shape:
- size *= dim
+ size = np.prod(shape)
if dtype == "FP32":
# Create buffer and initialize to zero
buffer = (ct.c_float * size)(0)
size_bytes = size * 4
+ elif dtype == "FP16":
+ size_bytes = size * 2
+ # Create buffer of bytes and initialize to zero
+ buffer = (ct.c_ubyte * size_bytes)(0)
else:
raise GenerateError(f"Unsupported data type {dtype}")
return buffer, size_bytes
+ def _convert_buffer(self, buffer, dtype: str, shape: tuple):
+ """Helper to convert a buffer to a numpy array."""
+ arr = np.ctypeslib.as_array(buffer)
+
+ if dtype == "FP16":
+ # Convert from bytes back to FP16
+ arr = np.frombuffer(arr, np.float16)
+
+ arr = np.reshape(arr, shape)
+
+ return arr
+
def _data_gen_array(self, json_config: str, tensor_name: str):
"""Generate the named tensor data and return a numpy array."""
try:
@@ -106,9 +120,7 @@ class GenerateLibrary:
if not result:
raise GenerateError("Data generate failed")
- arr = np.ctypeslib.as_array(buffer)
- arr = np.reshape(arr, shape)
-
+ arr = self._convert_buffer(buffer, dtype, shape)
return arr
def _data_gen_write(
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 8e88390..193da73 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1415,9 +1415,9 @@ class TosaTensorValuesGen:
if (
error_name is None
and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
- and dtype in (DType.FP16, DType.BF16)
+ and dtype in (DType.BF16,)
):
- # TODO - Remove once FP16 and BF16 enabled for DOT_PRODUCT compliance
+ # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance
# Limit ranges for (non error & non compliance) FP tests by using
# values that can be multiplied on any axis to not hit infinity/NaN
IC = shapeList[0][1]
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 318f296..3d733f4 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -55,7 +55,7 @@ def dtypeIsSupportedByCompliance(dtype):
"""Types supported by the new data generation and compliance flow."""
if isinstance(dtype, list) or isinstance(dtype, tuple):
dtype = dtype[0]
- return dtype in (DType.FP32,)
+ return dtype in (DType.FP32, DType.FP16)
def getOpNameFromOpListName(opName):