aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py51
1 files changed, 34 insertions, 17 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index d82f919..ae689b4 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -357,6 +357,12 @@ class TosaTestGen:
elif "compliance" in op and "ulp" in op["compliance"]:
mode = gtu.ComplianceMode.ULP
compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
+ elif "compliance" in op and "relative" in op["compliance"]:
+ mode = gtu.ComplianceMode.RELATIVE
+ compliance_tens["relative_info"] = {
+ "max": argsDict["max_abs_value"],
+ "scale": op["compliance"]["relative"],
+ }
elif op["op"] == Op.REDUCE_PRODUCT:
mode = gtu.ComplianceMode.REDUCE_PRODUCT
compliance_tens["reduce_product_info"] = {"n": argsDict["n"]}
@@ -1933,17 +1939,21 @@ class TosaTestGen:
def build_resize(
self,
op,
- input,
- mode,
- scale,
- offset,
- border,
- input_dtype,
- output_dtype,
+ inputs,
+ args_dict,
validator_fcns,
error_name=None,
+ qinfo=None,
):
- result_tens = OutputShaper.resizeOp(
+ assert len(inputs) == 1
+ input = inputs[0]
+ mode = args_dict["mode"]
+ scale = args_dict["scale"]
+ offset = args_dict["offset"]
+ border = args_dict["border"]
+ output_dtype = args_dict["output_dtype"]
+
+ result_tensor = OutputShaper.resizeOp(
self.ser,
self.rng,
input,
@@ -1951,14 +1961,14 @@ class TosaTestGen:
scale,
offset,
border,
- input_dtype,
+ input.dtype,
output_dtype,
error_name,
)
# Invalidate Input/Output list for error if checks.
input_list = [input.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -1972,25 +1982,28 @@ class TosaTestGen:
op=op,
mode=mode,
scale=scale,
- input_dtype=input_dtype,
+ input_dtype=input.dtype,
output_dtype=output_dtype,
input_shape=input.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
offset=offset,
border=border,
input_list=input_list,
output_list=output_list,
- result_tensors=[result_tens],
+ result_tensors=[result_tensor],
num_operands=num_operands,
):
return None
attr = ts.TosaSerializerAttribute()
-
attr.ResizeAttribute(scale, offset, border, mode)
-
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, input.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
@@ -4610,7 +4623,7 @@ class TosaTestGen:
"build_fcn": (
build_resize,
TosaTensorGen.tgNHWC,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgResize,
TosaArgGen.agResize,
),
"types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
@@ -4636,6 +4649,10 @@ class TosaTestGen:
TosaErrorValidator.evResizeOutputShapeMismatch,
TosaErrorValidator.evResizeOutputShapeNonInteger,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
+ "compliance": {"relative": 0.006},
},
# Type conversion
"cast": {