diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-08 22:19:41 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2024-03-17 19:56:21 -0700 |
commit | 60dc48c4ddf30f2a76d4cfcf1b40ca57b6f3bf95 (patch) | |
tree | e3d229a2d596e1a0788dfd75d77b996263055496 /verif | |
parent | e67115ef82bcba0718dcbd75cc8411985001b7cc (diff) | |
download | reference_model-60dc48c4ddf30f2a76d4cfcf1b40ca57b6f3bf95.tar.gz |
[ref model] Change Clamp and Pad attribute fields
This implements changes due to ClampAttribute and PadAttribute
field changes.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: Ide01e2a27fe3c1ea7794e7a4b6780b7eae436caf
Diffstat (limited to 'verif')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 16 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 24 | ||||
-rw-r--r-- | verif/generator/tosa_utils.py | 5 |
3 files changed, 23 insertions, 22 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 20572e8..a2ef5bf 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1813,13 +1813,7 @@ class TosaArgGen: and "data_gen" in testGen.TOSA_OP_LIST[opName] and gtu.dtypeIsSupportedByCompliance(dtype) ): - if dtype in [ - DType.FP16, - DType.FP32, - DType.BF16, - DType.FP8E4M3, - DType.FP8E5M2, - ]: + if gtu.dtypeIsFloat(dtype): dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"] else: dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"] @@ -2462,13 +2456,7 @@ class TosaArgGen: if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 - elif dtype in ( - DType.FP16, - DType.BF16, - DType.FP32, - DType.FP8E4M3, - DType.FP8E5M2, - ): + elif gtu.dtypeIsFloat(dtype): pad_const_int = 0 pad_const_fp = testGen.getRandNumberDType(dtype) else: diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index e7704f1..3173906 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -3,6 +3,7 @@ import json import logging import os +import struct from copy import deepcopy from datetime import datetime from pathlib import Path @@ -1428,13 +1429,17 @@ class TosaTestGen: # Non-tensor fp16 ops take fp16 values as fp32 in reference_model min_val = min_val.astype(np.float32) max_val = max_val.astype(np.float32) - - attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val) + min_val_as_bytes = struct.pack("<f", min_val) + max_val_as_bytes = struct.pack("<f", max_val) elif a.dtype in (DType.INT8, DType.INT16): - attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0) + min_val_as_bytes = struct.pack("<i", min_val) + max_val_as_bytes = struct.pack("<i", max_val) else: # to avoid internal error for incorrect input types - attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0) + min_val_as_bytes = struct.pack("<i", 0) + max_val_as_bytes = struct.pack("<i", 0) + + attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes) self.ser.addOperator(op["op"], input_list, output_list, attr) @@ -1578,9 +1583,14 @@ class TosaTestGen: result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name) - # write empty padding into PadAttribute to ensure inputs[1] is used + # get pad_const_val_as_bytes from either pad_const_float or pad_const_int + if gtu.dtypeIsFloat(a.dtype): + pad_const_val_as_bytes = struct.pack("<f", pad_const_float) + else: + pad_const_val_as_bytes = struct.pack("<i", pad_const_int) + attr = ts.TosaSerializerAttribute() - attr.PadAttribute(self.ser.builder, [], pad_const_int, pad_const_float) + attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes) # Invalidate Input/Output list for error if checks. input_list = [a.name, pad_input.name] @@ -2271,8 +2281,6 @@ class TosaTestGen: attr.RescaleAttribute( input_zp, output_zp, - [], - [], scale32, double_round, per_channel, diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 6558bf8..cfe7cc6 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -64,6 +64,11 @@ def dtypeWidth(dtype): raise Exception(f"Unknown dtype, cannot determine width: {dtype}") +def dtypeIsFloat(dtype): + """Is floating point data type""" + return dtype in (DType.BF16, DType.FP16, DType.FP32, DType.FP8E4M3, DType.FP8E5M2) + + def dtypeIsSupportedByCompliance(dtype): """Types supported by the new data generation and compliance flow.""" if isinstance(dtype, list) or isinstance(dtype, tuple): |