aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py24
1 files changed, 16 insertions, 8 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index e7704f1..3173906 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -3,6 +3,7 @@
import json
import logging
import os
+import struct
from copy import deepcopy
from datetime import datetime
from pathlib import Path
@@ -1428,13 +1429,17 @@ class TosaTestGen:
# Non-tensor fp16 ops take fp16 values as fp32 in reference_model
min_val = min_val.astype(np.float32)
max_val = max_val.astype(np.float32)
-
- attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
+ min_val_as_bytes = struct.pack("<f", min_val)
+ max_val_as_bytes = struct.pack("<f", max_val)
elif a.dtype in (DType.INT8, DType.INT16):
- attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+ min_val_as_bytes = struct.pack("<i", min_val)
+ max_val_as_bytes = struct.pack("<i", max_val)
else:
# to avoid internal error for incorrect input types
- attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)
+ min_val_as_bytes = struct.pack("<i", 0)
+ max_val_as_bytes = struct.pack("<i", 0)
+
+ attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -1578,9 +1583,14 @@ class TosaTestGen:
result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
- # write empty padding into PadAttribute to ensure inputs[1] is used
+ # get pad_const_val_as_bytes from either pad_const_float or pad_const_int
+ if gtu.dtypeIsFloat(a.dtype):
+ pad_const_val_as_bytes = struct.pack("<f", pad_const_float)
+ else:
+ pad_const_val_as_bytes = struct.pack("<i", pad_const_int)
+
attr = ts.TosaSerializerAttribute()
- attr.PadAttribute(self.ser.builder, [], pad_const_int, pad_const_float)
+ attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, pad_input.name]
@@ -2271,8 +2281,6 @@ class TosaTestGen:
attr.RescaleAttribute(
input_zp,
output_zp,
- [],
- [],
scale32,
double_round,
per_channel,