aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index d799eb0..c29763b 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1072,10 +1072,15 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- if a.dtype in (DType.FP16, DType.BF16, DType.FP32):
- attr.ClampAttribute(0, 0, min_val, max_val)
+ if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
+ if a.dtype == DType.FP16:
+ # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
+ min_val = min_val.astype(np.float32)
+ max_val = max_val.astype(np.float32)
+
+ attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
else:
- attr.ClampAttribute(min_val, max_val, 0, 0)
+ attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -1221,7 +1226,9 @@ class TosaTestGen:
result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
attr = ts.TosaSerializerAttribute()
- attr.PadAttribute(padding.flatten(), pad_const_int, pad_const_float)
+ attr.PadAttribute(
+ self.ser.builder, padding.flatten(), pad_const_int, pad_const_float
+ )
# Invalidate Input/Output list for error if checks.
input_list = [a.name]