diff options
Diffstat (limited to 'verif')
-rw-r--r-- | verif/conformance/test_select.py | 12 | ||||
-rw-r--r-- | verif/conformance/tosa_main_profile_ops_info.json | 124 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 48 |
3 files changed, 184 insertions, 0 deletions
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py index 55eef58..58d3f9f 100644 --- a/verif/conformance/test_select.py +++ b/verif/conformance/test_select.py @@ -848,6 +848,18 @@ class RsqrtOperator(Operator): name = "rsqrt" +class CosOperator(Operator): + """Test selector for the COS operator.""" + + name = "cos" + + +class SinOperator(Operator): + """Test selector for the SIN operator.""" + + name = "sin" + + class ScatterOperator(Operator): """Test selector for the SCATTER operator.""" diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index 18e078a..9d68dbf 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -2269,6 +2269,130 @@ } } }, + "cos": { + "group": "ew_unary", + "profile": [ + "tosa-mi" + ], + "support_for": [ "lazy_data_gen" ], + "generation": { + "standard": { + "generator_args": [ + [ + "--target-dtype", + "fp32", + "--target-dtype", + "fp16", + "--target-dtype", + "bf16", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "15,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3" + ], + [ + "--target-dtype", + "fp16", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "1,15", + "--target-rank", + "4", + "--target-rank", + "5" + ], + [ + "--target-dtype", + "fp32", + "--fp-values-range", + "-max,max", + "--target-shape", + "2,1,65537,1", + "--target-shape", + "3,1,65539,2,1" + ] + ] + } + }, + "selection": { + "default": { + "params": {}, + "permutes": [ + "shape", + "type" + ] + } + } + }, + "sin": { + "group": "ew_unary", + "profile": [ + "tosa-mi" + ], + "support_for": [ "lazy_data_gen" ], + "generation": { + "standard": { + "generator_args": [ + [ + "--target-dtype", + "fp32", + "--target-dtype", + "fp16", + "--target-dtype", + "bf16", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "15,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3" + ], + [ + "--target-dtype", + "fp32", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "1,15", + "--target-rank", + "4", + "--target-rank", + "5" + ], + [ + "--target-dtype", + "fp16", + "--fp-values-range", + "-max,max", + "--target-shape", + "3,1,65534,2", + "--target-shape", + "65533,1,3,2,1" + ] + ] + } + }, + "selection": { + "default": { + "params": {}, + "permutes": [ + "shape", + "type" + ] + } + } + }, "rsqrt": { "group": "ew_unary", "profile": [ diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index b472087..978e735 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -392,6 +392,12 @@ class TosaTestGen: compliance_tens["abs_error_info"] = { "lower_bound": op["compliance"]["abs_error_lower_bound"] } + elif op["op"] in (Op.SIN, Op.COS): + mode = gtu.ComplianceMode.ABS_ERROR + if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]: + compliance_tens["abs_error_info"] = { + "normal_divisor": op["compliance"]["abs_error_normal_divisor"] + } else: mode = gtu.ComplianceMode.EXACT compliance_tens["mode"] = gtu.ComplianceMode(mode).name @@ -4036,6 +4042,27 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), }, + "cos": { + "op": Op.COS, + "operands": (1, 0), + "build_fcn": ( + build_unary, + TosaTensorGen.tgBasic, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, + ), + "types": TYPE_FP, + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"abs_error_normal_divisor": 2}, + }, "exp": { "op": Op.EXP, "operands": (1, 0), @@ -4180,6 +4207,27 @@ class TosaTestGen: }, "compliance": {"ulp": 2}, }, + "sin": { + "op": Op.SIN, + "operands": (1, 0), + "build_fcn": ( + build_unary, + TosaTensorGen.tgBasic, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, + ), + "types": TYPE_FP, + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"abs_error_normal_divisor": 2}, + }, # Elementwise Ternary operators "select": { "op": Op.SELECT, |