aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-26 13:53:14 +0100
committerEric Kunze <eric.kunze@arm.com>2023-11-02 23:22:09 +0000
commita4d907e8686791dd84ed987d0d79325c4d908b73 (patch)
tree9748ef39183b7548a9ff50d457920eace3a6fdec /verif/generator/tosa_test_gen.py
parentd1a08ce27ef8d0f6cf77e1b864610aade06edc5c (diff)
downloadreference_model-a4d907e8686791dd84ed987d0d79325c4d908b73.tar.gz
Main compliance testing support for MUL
Update verify ULP mode to allow fractions (e.g. 0.5). Update pseudo generator to accept ranges. Fix up pseudo random distribution based on ranges. Change-Id: I9168c5f7d37722678c0f1f9e906953c8cec367b1 Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py62
1 files changed, 45 insertions, 17 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 54b624e..1995cbc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -51,15 +51,31 @@ class TosaTestGen:
self.quantGen = TosaQuantGen()
# Force makeShape to do a specific starting shape
self.targetted_shape = None
- # Work out floating point range
- self.random_fp_low = min(args.tensor_fp_value_range)
- self.random_fp_high = max(args.tensor_fp_value_range)
# JSON schema validation
self.descSchemaValidator = TestDescSchemaValidator()
# Data generator library is sometimes needed for compliance set up
# even if we are generating the data later (lazy_data_generation)
self.dgl = GenerateLibrary(args.generate_lib_path)
+ # Work out floating point range
+ def convertFPRange(rangeFP, maxFP):
+ # Converts program arguments of max/-max to FP max
+ vals = []
+ for v in rangeFP:
+ if v == "max":
+ v = maxFP
+ elif v == "-max":
+ v = -maxFP
+ vals.append(v)
+ return tuple(sorted(vals))
+
+ self.random_float_range = {}
+ for dtype in (DType.FP32, DType.FP16, DType.BF16):
+ self.random_float_range[dtype] = convertFPRange(
+ args.tensor_fp_value_range,
+ TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
+ )
+
def createSerializer(self, opName, testPath):
self.testPath = os.path.join(opName, testPath)
@@ -130,9 +146,8 @@ class TosaTestGen:
# Returns dtype value range boundaries (low, high)
# The high boundary is excluded in the range
# unless high_inclusive is True
-
if dtype in (DType.FP32, DType.FP16, DType.BF16):
- return (self.random_fp_low, self.random_fp_high)
+ return self.random_float_range[dtype]
elif dtype == DType.BOOL:
rng = (0, 2)
elif dtype == DType.UINT8:
@@ -318,8 +333,6 @@ class TosaTestGen:
compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
elif op["op"] == Op.REDUCE_PRODUCT:
mode = gtu.ComplianceMode.REDUCE_PRODUCT
- elif op["op"] in (Op.ADD, Op.MUL, Op.SUB, Op.CEIL, Op.FLOOR, Op.CAST):
- mode = gtu.ComplianceMode.ROUND
else:
mode = gtu.ComplianceMode.EXACT
compliance_tens["mode"] = gtu.ComplianceMode(mode).name
@@ -466,23 +479,29 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
- def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.binaryBroadcastOp(
+ def build_mul(
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ assert len(inputs) == 2
+ a, b = inputs
+ shift = args_dict["shift"]
+
+ result_tensor = OutputShaper.binaryBroadcastOp(
self.ser, self.rng, a, b, error_name
)
- # Special for multiply:
- # Force the result to INT32 for INT types
+ # Special for multiply: Force the result to INT32 for INT types
if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
- result_tens.setDtype(DType.INT32)
+ result_tensor.setDtype(DType.INT32)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
outputDType = self.rng.choice(all_dtypes)
- result_tens.setDtype(outputDType)
+ result_tensor.setDtype(outputDType)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -497,8 +516,8 @@ class TosaTestGen:
input1=a,
input2=b,
input_dtype=a.dtype,
- output_dtype=result_tens.dtype,
- result_tensors=[result_tens],
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -509,7 +528,12 @@ class TosaTestGen:
attr.MulAttribute(shift)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
@@ -3456,6 +3480,10 @@ class TosaTestGen:
TosaErrorValidator.evDimensionMismatch,
TosaErrorValidator.evBroadcastShapesMismatch,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
+ "compliance": {"ulp": 0.5},
},
"pow": {
"op": Op.POW,