aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2023-01-23 17:13:37 +0000
committerJames Ward <james.ward@arm.com>2023-01-31 11:45:21 +0000
commit736fd1a7e4083153ccc4cf360b44dd07b6788494 (patch)
tree42f34388dddb504650be0e17dbda8c9073223313
parent2138a19ae830ec7d9ce5b15f15cbd7a22864bb8f (diff)
downloadreference_model-736fd1a7e4083153ccc4cf360b44dd07b6788494.tar.gz
Create MI tests for Type Conversion: CAST
* Add exclusion regex's to conformance generation Signed-off-by: James Ward <james.ward@arm.com> Change-Id: I15bef7451efd5662065060242d35bd7fa3381487
-rw-r--r--reference_model/src/ops/op_factory.cc7
-rw-r--r--reference_model/src/ops/type_conversion.cc79
-rw-r--r--reference_model/src/ops/type_conversion.h104
-rw-r--r--verif/conformance/test_select.py61
-rw-r--r--verif/conformance/tosa_main_profile_ops_info.json76
-rw-r--r--verif/conformance/tosa_verif_conformance_generator.py4
-rw-r--r--verif/generator/tosa_arg_gen.py33
-rw-r--r--verif/generator/tosa_error_if.py23
8 files changed, 345 insertions, 42 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 0d56161..76cf666 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -469,26 +469,33 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
break;
case Op_RESCALE:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index f266675..dbedbad 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -15,6 +15,7 @@
#include "type_conversion.h"
#include "quant_util.h"
+#include "arith_util.h"
#include "template_types.h"
#include <cmath>
#include "half.hpp"
@@ -307,30 +308,88 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper()
template <DType InDtype>
CastHelper<InDtype, DType_FP16>::CastHelper()
{
+ // Integer data converted to fp16 (stored as fp32)
fcn = [](InEigenType in) -> float {
- half_float::half out = half_float::half_cast<half_float::half, InEigenType>(in); // Cast to half_float
- return half_float::half_cast<float, half_float::half>(out); // Cast to float (underlying FP16 EigenType)
+ half_float::half h = half_float::half(in);
+ float out = half_float::half_cast<float, half_float::half>(h);
+ return out;
+ };
+}
+
+CastHelper<DType_FP32, DType_FP16>::CastHelper()
+{
+ // fp32 data converted to fp16 (stored as fp32)
+ fcn = [](float in) -> float {
+ float out = fpTrunc<DType_FP16>(in); // truncate required for conversion from higher precision
+ return out;
+ };
+}
+
+template <DType InDtype>
+CastHelper<InDtype, DType_BF16>::CastHelper()
+{
+ // Integer data converted to bf16 (stored as fp32)
+ fcn = [](InEigenType in) -> float {
+ float out = (float)in; // default cast to float is round_to_nearest_float()
+ return out;
+ };
+}
+
+CastHelper<DType_FP32, DType_BF16>::CastHelper()
+{
+ // fp32 data converted to bf16 (stored as fp32)
+ fcn = [](float in) -> float {
+ return fpTrunc<DType_BF16>(in); // truncate required for conversions from higher precision
};
}
template <DType OutDtype>
CastHelper<DType_FP16, OutDtype>::CastHelper()
{
- // Assuming InEigenType = float.
+ // fp16 data (stored as fp32) converted to integer
fcn = [](float in) -> OutEigenType {
- // Perform initial rounding in half-precision then cast back to float
- half_float::half h = half_float::half_cast<half_float::half, float>(in);
+ // Cast from float representation back to half_float before rounding
+ half_float::half h = half_float::half(in);
h = std::round(h);
- OutEigenType out = half_float::half_cast<float, half_float::half>(h);
+ OutEigenType out = half_float::half_cast<OutEigenType, half_float::half>(h);
out = std::max<OutEigenType>(out, OutMin);
out = std::min<OutEigenType>(out, OutMax);
return out;
};
}
+CastHelper<DType_FP16, DType_FP32>::CastHelper()
+{
+ // No-op since fp16 values treated internally as their fp32 representation
+ fcn = [](float in) -> OutEigenType {
+ return in;
+ };
+}
+
+template <DType OutDtype>
+CastHelper<DType_BF16, OutDtype>::CastHelper()
+{
+ // bf16 data (stored as fp32) converted to integer
+ fcn = [](float in) -> OutEigenType {
+ OutEigenType out = std::round(in);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
+}
+
+CastHelper<DType_BF16, DType_FP32>::CastHelper()
+{
+ // No-op since bf16 values treated as truncated fp32 internally
+ fcn = [](InEigenType in) -> OutEigenType {
+ return in;
+ };
+}
+
template <DType InDtype>
CastHelper<InDtype, DType_FP32>::CastHelper()
{
+ // Integer data converted to fp32
fcn = [](InEigenType in) -> float {
float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
return out;
@@ -340,6 +399,7 @@ CastHelper<InDtype, DType_FP32>::CastHelper()
template <DType OutDtype>
CastHelper<DType_FP32, OutDtype>::CastHelper()
{
+ // fp32 data converted to integer
fcn = [](float in) -> OutEigenType {
OutEigenType out = std::round(in);
out = std::max<OutEigenType>(out, OutMin);
@@ -356,26 +416,33 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index b0de30c..e2fc6e2 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -136,6 +136,76 @@ private:
FcnType fcn;
};
+template <>
+class CastHelper<DType_FP32, DType_FP16>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_FP32>::type;
+ using OutEigenType = typename GetEigenType<DType_FP16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_BF16>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BF16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_BF16, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_BF16>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<DType_FP32, DType_BF16>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_FP32>::type;
+ using OutEigenType = typename GetEigenType<DType_BF16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
template <DType InDtype>
class CastHelper<InDtype, DType_FP32>
{
@@ -153,6 +223,40 @@ private:
FcnType fcn;
};
+template <>
+class CastHelper<DType_FP16, DType_FP32>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_FP16>::type;
+ using OutEigenType = typename GetEigenType<DType_FP32>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<DType_BF16, DType_FP32>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_BF16>::type;
+ using OutEigenType = typename GetEigenType<DType_FP32>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
template <DType OutDtype>
class CastHelper<DType_FP32, OutDtype>
{
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index 05f6db8..817d0b6 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -5,6 +5,7 @@ import argparse
import itertools
import json
import logging
+import re
from pathlib import Path
from typing import Any
from typing import Dict
@@ -129,20 +130,28 @@ class Operator:
test_dir: Path,
config: Dict[str, Dict[str, List[Any]]],
negative=False,
- exclude_types=None,
):
"""Initialise the selection parameters for an operator.
- test_dir: the directory where the tests for all operators can be found
+ test_dir: the directory where the tests for all operators can
+ be found
config: a dictionary with:
- "params" - a dictionary with mappings of parameter names to the values
- to select (a sub-set of expected values for instance)
+ "params" - a dictionary with mappings of parameter
+ names to the values to select (a sub-set of
+ expected values for instance)
"permutes" - a list of parameter names to be permuted
- "preselected" - a list of dictionaries containing parameter names and
- pre-chosen values
- "sparsity" - a dictionary of parameter names with a sparsity value
- "errorifs" - list of ERRORIF case names to be selected (negative test)
- negative: bool indicating if negative testing is being selected (ERRORIF tests)
+ "preselected" - a list of dictionaries containing
+ parameter names and pre-chosen values
+ "sparsity" - a dictionary of parameter names with a
+ sparsity value
+ "exclude_patterns" - a list of regex's whereby each
+ match will not be considered for selection.
+ Exclusion happens BEFORE test selection (i.e.
+ before permutes are applied).
+ "errorifs" - list of ERRORIF case names to be selected
+ (negative test)
+ negative: bool indicating if negative testing is being selected
+ (ERRORIF tests)
EXAMPLE CONFIG:
"params": {
@@ -165,6 +174,9 @@ class Operator:
"pad": "pad00"
}
],
+ "exclude_patterns": [
+ ".*_(i8|i16|i32|b)_out(i8|i16|i32|b)"
+ ],
"errorifs": [
"InputZeroPointNotZero"
]
@@ -187,23 +199,34 @@ class Operator:
)
config["permutes"] = []
config["preselected"] = {}
+ config["exclude_patterns"] = []
self.params = config["params"] if "params" in config else {}
self.permutes = config["permutes"] if "permutes" in config else []
self.sparsity = config["sparsity"] if "sparsity" in config else {}
self.preselected = config["preselected"] if "preselected" in config else {}
+ self.exclude_patterns = (
+ config["exclude_patterns"] if "exclude_patterns" in config else []
+ )
self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes]
logger.info(f"{self.name}: permutes={self.permutes}")
logger.info(f"{self.name}: non_permutes={self.non_permutes}")
+ logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}")
+
+ self.test_paths = []
+ excluded_paths = []
+ for path in self.get_test_paths(test_dir, self.negative):
+ pattern_match = False
+ for pattern in self.exclude_patterns:
+ if re.fullmatch(pattern, path.name):
+ excluded_paths.append(path)
+ pattern_match = True
+ break
+ if not pattern_match:
+ self.test_paths.append(path)
+
+ logger.debug(f"{self.name}: regex excluded paths={excluded_paths}")
- if exclude_types is None:
- exclude_types = []
- self.test_paths = [
- p
- for p in self.get_test_paths(test_dir, self.negative)
- # exclusion of types if requested
- if self.path_params(p)["type"] not in exclude_types
- ]
if not self.test_paths:
logger.error(f"no tests found for {self.name} in {test_dir}")
logger.debug(f"{self.name}: paths={self.test_paths}")
@@ -861,9 +884,7 @@ def main():
for op_name in Operator.registry:
if not args.operators or op_name in args.operators:
op_params = config[op_name] if op_name in config else {}
- op = Operator.registry[op_name](
- args.test_dir, op_params, negative, exclude_types=["float"]
- )
+ op = Operator.registry[op_name](args.test_dir, op_params, negative)
for test_path in op.select_tests():
print(test_path.resolve() if args.full_path else test_path.name)
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 4cf2b57..f31fa71 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -216,6 +216,82 @@
"tosa-mi"
]
},
+ "cast": {
+ "group": "type_conversion",
+ "generator_negative_dim_range": "1,10",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp32",
+ "--target-dtype",
+ "fp16",
+ "--target-dtype",
+ "bf16",
+ "--target-dtype",
+ "int8",
+ "--target-dtype",
+ "int16",
+ "--target-dtype",
+ "int32",
+ "--fp-values-range",
+ "-2.0,2.0",
+ "--tensor-dim-range",
+ "16,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3"
+ ],
+ [
+ "--target-dtype",
+ "fp32",
+ "--target-dtype",
+ "fp16",
+ "--target-dtype",
+ "bf16",
+ "--target-dtype",
+ "int8",
+ "--target-dtype",
+ "int16",
+ "--target-dtype",
+ "int32",
+ "--fp-values-range",
+ "-2.0,2.0",
+ "--tensor-dim-range",
+ "1,16",
+ "--target-rank",
+ "4",
+ "--target-rank",
+ "5"
+ ],
+ [
+ "--target-dtype",
+ "fp16",
+ "--target-shape",
+ "1,1,1,65533,1",
+ "--target-shape",
+ "2,65538,1,1"
+ ]
+ ],
+ "params": {
+ "shape": [],
+ "type": [],
+ "output_type": []
+ },
+ "permutes": [
+ "shape",
+ "type",
+ "output_type"
+ ],
+ "exclude_patterns": [
+ ".*_(i8|i16|i32|b)_out(i8|i16|i32|b)"
+ ],
+ "profile": [
+ "tosa-mi"
+ ]
+ },
"ceil": {
"group": "ew_unary",
"generator_args": [
diff --git a/verif/conformance/tosa_verif_conformance_generator.py b/verif/conformance/tosa_verif_conformance_generator.py
index 817b242..4971fb0 100644
--- a/verif/conformance/tosa_verif_conformance_generator.py
+++ b/verif/conformance/tosa_verif_conformance_generator.py
@@ -34,13 +34,11 @@ PROFILE_OPS_INFO = {
"tosa-bi": {
"operator_test_params": "tosa_base_profile_ops_info.json",
"framework_tests": "tosa_base_profile_framework_ops_info.json",
- "exclude_types": [],
},
"tosa-mi": {
# Note: This is just the extra tests not in the base profile!
"operator_test_params": "tosa_main_profile_ops_info.json",
"framework_tests": "tosa_main_profile_framework_ops_info.json",
- "exclude_types": [],
},
}
PROFILES_ALL = "all"
@@ -164,7 +162,6 @@ def build_op_tests(args, profile, operator, test_params):
def _check_to_include_test(profile, test_name, exclude_negative_tests=False):
"""Check test name for exclusions, return False to indicate excluded."""
excludes = ["ERRORIF"] if exclude_negative_tests else []
- excludes.extend(PROFILE_OPS_INFO[profile]["exclude_types"])
for exclusion in excludes:
if f"_{exclusion}_" in test_name:
@@ -338,7 +335,6 @@ def get_op_tests_selection(
op_build_dir,
op_params,
negative,
- exclude_types=PROFILE_OPS_INFO[profile]["exclude_types"],
)
except KeyError:
logger.error(f"{operator} operator is not supported by test_select")
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index fed91f6..05a7d2b 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1445,19 +1445,40 @@ class TosaArgGen:
if error_name == ErrorIf.WrongOutputType:
dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
elif inDtype == DType.INT8:
- dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.INT16:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.INT32:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
+ dtypeList = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.BF16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.FP32:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 40c5d13..93f975d 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -314,12 +314,14 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
- outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
+ if input_dtype in [DType.BOOL, DType.FP32]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP32]
+ elif input_dtype in [DType.FP16, DType.BF16]:
+ outputDType = [DType.BOOL, DType.INT48]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
- assert True, f"input_dtype ({input_dtype}) not supported"
+ assert False, f"input_dtype ({input_dtype}) not supported"
return outputDType
@@ -538,15 +540,24 @@ class TosaErrorValidator:
)
or (
input_dtype == DType.FP16
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
)
or (
input_dtype == DType.BF16
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
)
or (
input_dtype == DType.FP32
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ and output_dtype
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ ]
)
):
error_result = True