aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2023-01-23 17:13:37 +0000
committerJames Ward <james.ward@arm.com>2023-01-31 11:45:21 +0000
commit736fd1a7e4083153ccc4cf360b44dd07b6788494 (patch)
tree42f34388dddb504650be0e17dbda8c9073223313 /verif
parent2138a19ae830ec7d9ce5b15f15cbd7a22864bb8f (diff)
downloadreference_model-736fd1a7e4083153ccc4cf360b44dd07b6788494.tar.gz
Create MI tests for Type Conversion: CAST
* Add exclusion regex's to conformance generation Signed-off-by: James Ward <james.ward@arm.com> Change-Id: I15bef7451efd5662065060242d35bd7fa3381487
Diffstat (limited to 'verif')
-rw-r--r--verif/conformance/test_select.py61
-rw-r--r--verif/conformance/tosa_main_profile_ops_info.json76
-rw-r--r--verif/conformance/tosa_verif_conformance_generator.py4
-rw-r--r--verif/generator/tosa_arg_gen.py33
-rw-r--r--verif/generator/tosa_error_if.py23
5 files changed, 161 insertions, 36 deletions
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index 05f6db8..817d0b6 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -5,6 +5,7 @@ import argparse
import itertools
import json
import logging
+import re
from pathlib import Path
from typing import Any
from typing import Dict
@@ -129,20 +130,28 @@ class Operator:
test_dir: Path,
config: Dict[str, Dict[str, List[Any]]],
negative=False,
- exclude_types=None,
):
"""Initialise the selection parameters for an operator.
- test_dir: the directory where the tests for all operators can be found
+ test_dir: the directory where the tests for all operators can
+ be found
config: a dictionary with:
- "params" - a dictionary with mappings of parameter names to the values
- to select (a sub-set of expected values for instance)
+ "params" - a dictionary with mappings of parameter
+ names to the values to select (a sub-set of
+ expected values for instance)
"permutes" - a list of parameter names to be permuted
- "preselected" - a list of dictionaries containing parameter names and
- pre-chosen values
- "sparsity" - a dictionary of parameter names with a sparsity value
- "errorifs" - list of ERRORIF case names to be selected (negative test)
- negative: bool indicating if negative testing is being selected (ERRORIF tests)
+ "preselected" - a list of dictionaries containing
+ parameter names and pre-chosen values
+ "sparsity" - a dictionary of parameter names with a
+ sparsity value
+ "exclude_patterns" - a list of regex's whereby each
+ match will not be considered for selection.
+ Exclusion happens BEFORE test selection (i.e.
+ before permutes are applied).
+ "errorifs" - list of ERRORIF case names to be selected
+ (negative test)
+ negative: bool indicating if negative testing is being selected
+ (ERRORIF tests)
EXAMPLE CONFIG:
"params": {
@@ -165,6 +174,9 @@ class Operator:
"pad": "pad00"
}
],
+ "exclude_patterns": [
+ ".*_(i8|i16|i32|b)_out(i8|i16|i32|b)"
+ ],
"errorifs": [
"InputZeroPointNotZero"
]
@@ -187,23 +199,34 @@ class Operator:
)
config["permutes"] = []
config["preselected"] = {}
+ config["exclude_patterns"] = []
self.params = config["params"] if "params" in config else {}
self.permutes = config["permutes"] if "permutes" in config else []
self.sparsity = config["sparsity"] if "sparsity" in config else {}
self.preselected = config["preselected"] if "preselected" in config else {}
+ self.exclude_patterns = (
+ config["exclude_patterns"] if "exclude_patterns" in config else []
+ )
self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
logger.info(f"{self.name}: permutes={self.permutes}")
logger.info(f"{self.name}: non_permutes={self.non_permutes}")
+ logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")
+
+ self.test_paths = []
+ excluded_paths = []
+ for path in self.get_test_paths(test_dir, self.negative):
+ pattern_match = False
+ for pattern in self.exclude_patterns:
+ if re.fullmatch(pattern, path.name):
+ excluded_paths.append(path)
+ pattern_match = True
+ break
+ if not pattern_match:
+ self.test_paths.append(path)
+
+ logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")
- if exclude_types is None:
- exclude_types = []
- self.test_paths = [
- p
- for p in self.get_test_paths(test_dir, self.negative)
- # exclusion of types if requested
- if self.path_params(p)["type"] not in exclude_types
- ]
if not self.test_paths:
logger.error(f"no tests found for {self.name} in {test_dir}")
logger.debug(f"{self.name}: paths={self.test_paths}")
@@ -861,9 +884,7 @@ def main():
for op_name in Operator.registry:
if not args.operators or op_name in args.operators:
op_params = config[op_name] if op_name in config else {}
- op = Operator.registry[op_name](
- args.test_dir, op_params, negative, exclude_types=["float"]
- )
+ op = Operator.registry[op_name](args.test_dir, op_params, negative)
for test_path in op.select_tests():
print(test_path.resolve() if args.full_path else test_path.name)
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 4cf2b57..f31fa71 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -216,6 +216,82 @@
"tosa-mi"
]
},
+ "cast": {
+ "group": "type_conversion",
+ "generator_negative_dim_range": "1,10",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp32",
+ "--target-dtype",
+ "fp16",
+ "--target-dtype",
+ "bf16",
+ "--target-dtype",
+ "int8",
+ "--target-dtype",
+ "int16",
+ "--target-dtype",
+ "int32",
+ "--fp-values-range",
+ "-2.0,2.0",
+ "--tensor-dim-range",
+ "16,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3"
+ ],
+ [
+ "--target-dtype",
+ "fp32",
+ "--target-dtype",
+ "fp16",
+ "--target-dtype",
+ "bf16",
+ "--target-dtype",
+ "int8",
+ "--target-dtype",
+ "int16",
+ "--target-dtype",
+ "int32",
+ "--fp-values-range",
+ "-2.0,2.0",
+ "--tensor-dim-range",
+ "1,16",
+ "--target-rank",
+ "4",
+ "--target-rank",
+ "5"
+ ],
+ [
+ "--target-dtype",
+ "fp16",
+ "--target-shape",
+ "1,1,1,65533,1",
+ "--target-shape",
+ "2,65538,1,1"
+ ]
+ ],
+ "params": {
+ "shape": [],
+ "type": [],
+ "output_type": []
+ },
+ "permutes": [
+ "shape",
+ "type",
+ "output_type"
+ ],
+ "exclude_patterns": [
+ ".*_(i8|i16|i32|b)_out(i8|i16|i32|b)"
+ ],
+ "profile": [
+ "tosa-mi"
+ ]
+ },
"ceil": {
"group": "ew_unary",
"generator_args": [
diff --git a/verif/conformance/tosa_verif_conformance_generator.py b/verif/conformance/tosa_verif_conformance_generator.py
index 817b242..4971fb0 100644
--- a/verif/conformance/tosa_verif_conformance_generator.py
+++ b/verif/conformance/tosa_verif_conformance_generator.py
@@ -34,13 +34,11 @@ PROFILE_OPS_INFO = {
"tosa-bi": {
"operator_test_params": "tosa_base_profile_ops_info.json",
"framework_tests": "tosa_base_profile_framework_ops_info.json",
- "exclude_types": [],
},
"tosa-mi": {
# Note: This is just the extra tests not in the base profile!
"operator_test_params": "tosa_main_profile_ops_info.json",
"framework_tests": "tosa_main_profile_framework_ops_info.json",
- "exclude_types": [],
},
}
PROFILES_ALL = "all"
@@ -164,7 +162,6 @@ def build_op_tests(args, profile, operator, test_params):
def _check_to_include_test(profile, test_name, exclude_negative_tests=False):
"""Check test name for exclusions, return False to indicate excluded."""
excludes = ["ERRORIF"] if exclude_negative_tests else []
- excludes.extend(PROFILE_OPS_INFO[profile]["exclude_types"])
for exclusion in excludes:
if f"_{exclusion}_" in test_name:
@@ -338,7 +335,6 @@ def get_op_tests_selection(
op_build_dir,
op_params,
negative,
- exclude_types=PROFILE_OPS_INFO[profile]["exclude_types"],
)
except KeyError:
logger.error(f"{operator} operator is not supported by test_select")
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index fed91f6..05a7d2b 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1445,19 +1445,40 @@ class TosaArgGen:
if error_name == ErrorIf.WrongOutputType:
dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
elif inDtype == DType.INT8:
- dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.INT16:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.INT32:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.BF16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.FP32:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 40c5d13..93f975d 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -314,12 +314,14 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
- outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
+ if input_dtype in [DType.BOOL, DType.FP32]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP32]
+ elif input_dtype in [DType.FP16, DType.BF16]:
+ outputDType = [DType.BOOL, DType.INT48]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
- assert True, f"input_dtype ({input_dtype}) not supported"
+ assert False, f"input_dtype ({input_dtype}) not supported"
return outputDType
@@ -538,15 +540,24 @@ class TosaErrorValidator:
)
or (
input_dtype == DType.FP16
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
)
or (
input_dtype == DType.BF16
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
)
or (
input_dtype == DType.FP32
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ ]
)
):
error_result = True