diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 35 |
1 files changed, 18 insertions, 17 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index c5ac0f9..38ab3f4 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -3,7 +3,6 @@ import json import logging import os -import struct from copy import deepcopy from datetime import datetime from pathlib import Path @@ -1390,20 +1389,14 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - if a.dtype in (DType.BF16, DType.FP16, DType.FP32): - if a.dtype == DType.FP16: - # Non-tensor fp16 ops take fp16 values as fp32 in reference_model - min_val = min_val.astype(np.float32) - max_val = max_val.astype(np.float32) - min_val_as_bytes = struct.pack("<f", min_val) - max_val_as_bytes = struct.pack("<f", max_val) - elif a.dtype in (DType.INT8, DType.INT16): - min_val_as_bytes = struct.pack("<i", min_val) - max_val_as_bytes = struct.pack("<i", max_val) - else: - # to avoid internal error for incorrect input types - min_val_as_bytes = struct.pack("<i", 0) - max_val_as_bytes = struct.pack("<i", 0) + min_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(a.dtype, [min_val]) + max_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(a.dtype, [max_val]) + + # align to 8 bytes + while (len(min_val_as_bytes) % 8) != 0: + min_val_as_bytes.append(0) + while (len(max_val_as_bytes) % 8) != 0: + max_val_as_bytes.append(0) attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes) @@ -1550,9 +1543,17 @@ class TosaTestGen: # get pad_const_val_as_bytes from either pad_const_float or pad_const_int if gtu.dtypeIsFloat(a.dtype): - pad_const_val_as_bytes = struct.pack("<f", pad_const_float) + pad_const_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec( + a.dtype, [pad_const_float] + ) else: - pad_const_val_as_bytes = struct.pack("<i", pad_const_int) + pad_const_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec( + a.dtype, [pad_const_int] + ) + + # align to 8 bytes + while (len(pad_const_val_as_bytes) % 8) != 0: + pad_const_val_as_bytes.append(0) attr = ts.TosaSerializerAttribute() attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes) |