diff options
Diffstat (limited to 'verif')
-rw-r--r-- | verif/generator/datagenerator.py | 24 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 4 | ||||
-rw-r--r-- | verif/generator/tosa_utils.py | 2 |
3 files changed, 21 insertions, 9 deletions
diff --git a/verif/generator/datagenerator.py b/verif/generator/datagenerator.py index 0d59084..9de421b 100644 --- a/verif/generator/datagenerator.py +++ b/verif/generator/datagenerator.py @@ -68,19 +68,33 @@ class GenerateLibrary: def _create_buffer(self, dtype: str, shape: tuple): """Helper to create a buffer of the required type.""" - size = 1 - for dim in shape: - size *= dim + size = np.prod(shape) if dtype == "FP32": # Create buffer and initialize to zero buffer = (ct.c_float * size)(0) size_bytes = size * 4 + elif dtype == "FP16": + size_bytes = size * 2 + # Create buffer of bytes and initialize to zero + buffer = (ct.c_ubyte * size_bytes)(0) else: raise GenerateError(f"Unsupported data type {dtype}") return buffer, size_bytes + def _convert_buffer(self, buffer, dtype: str, shape: tuple): + """Helper to convert a buffer to a numpy array.""" + arr = np.ctypeslib.as_array(buffer) + + if dtype == "FP16": + # Convert from bytes back to FP16 + arr = np.frombuffer(arr, np.float16) + + arr = np.reshape(arr, shape) + + return arr + def _data_gen_array(self, json_config: str, tensor_name: str): """Generate the named tensor data and return a numpy array.""" try: @@ -106,9 +120,7 @@ class GenerateLibrary: if not result: raise GenerateError("Data generate failed") - arr = np.ctypeslib.as_array(buffer) - arr = np.reshape(arr, shape) - + arr = self._convert_buffer(buffer, dtype, shape) return arr def _data_gen_write( diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 8e88390..193da73 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1415,9 +1415,9 @@ class TosaTensorValuesGen: if ( error_name is None and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT - and dtype in (DType.FP16, DType.BF16) + and dtype in (DType.BF16,) ): - # TODO - Remove once FP16 and BF16 enabled for DOT_PRODUCT compliance + # TODO - Remove once BF16 enabled for DOT_PRODUCT compliance # Limit ranges for (non error & non compliance) FP tests by using # values that can be multiplied on any axis to not hit infinity/NaN IC = shapeList[0][1] diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 318f296..3d733f4 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -55,7 +55,7 @@ def dtypeIsSupportedByCompliance(dtype): """Types supported by the new data generation and compliance flow.""" if isinstance(dtype, list) or isinstance(dtype, tuple): dtype = dtype[0] - return dtype in (DType.FP32,) + return dtype in (DType.FP32, DType.FP16) def getOpNameFromOpListName(opName): |