aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-11-15 11:00:49 +0000
committerEric Kunze <eric.kunze@arm.com>2023-11-16 21:24:34 +0000
commit30a41db0c89f8209ab710c1d312fd6697107a41b (patch)
tree3ebff243aa29fa2f538064852c3fa2950c4f96a3 /verif/generator/tosa_arg_gen.py
parent9a758384d1066ade713311940f3d15c860f90866 (diff)
downloadreference_model-30a41db0c89f8209ab710c1d312fd6697107a41b.tar.gz
Fix FP16, BF16 data ranges for conformance tests
Enable use of data ranges for old data gen path as well as the new generate library path, so that FP16 and BF16 test data is produced within the correct ranges. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I749870a3112f8c3a75f4d16b8322c813fbf977cd
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py29
1 files changed, 12 insertions, 17 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 454013a..6675025 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -654,28 +654,24 @@ class TosaTensorValuesGen:
):
# Variable inputs versus constants
pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
+ tens_ser_list = []
if (
error_name is not None
or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
or "data_gen" not in testGen.TOSA_OP_LIST[opName]
):
- # Fall back to original path when dealing with unsupported types or ops
-
- # First turn off lazy data gen so we always produce data
- lazy_data_gen = testGen.args.lazy_data_gen
- testGen.args.lazy_data_gen = False
-
- tens_ser_list = TosaTensorValuesGen.tvgDefault(
- testGen,
- testGen.TOSA_OP_LIST[opName],
- dtypeList,
- shapeList,
- [],
- error_name,
- )
- # Restore lazy data gen setting
- testGen.args.lazy_data_gen = lazy_data_gen
+ # Fall back to internal data gen when dealing with unsupported types or ops
+ data_range = argsDict["data_range"] if "data_range" in argsDict else None
+ for idx, info in enumerate(zip(shapeList, dtypeList)):
+ shape, dtype = info
+ # Ignore lazy data gen option and create data array using any range limits
+ arr = testGen.getRandTensor(shape, dtype, data_range)
+ if idx < pCount:
+ tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
+ else:
+ tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
+
return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
# Create data generator meta-data
@@ -685,7 +681,6 @@ class TosaTensorValuesGen:
"tensors": {},
}
dg_tens_meta = tens_data["tensors"]
- tens_ser_list = []
for idx, shape in enumerate(shapeList):
tens_meta = {}