From 736fd1a7e4083153ccc4cf360b44dd07b6788494 Mon Sep 17 00:00:00 2001 From: James Ward Date: Mon, 23 Jan 2023 17:13:37 +0000 Subject: Create MI tests for Type Conversion: CAST * Add exclusion regex's to conformance generation Signed-off-by: James Ward Change-Id: I15bef7451efd5662065060242d35bd7fa3381487 --- reference_model/src/ops/op_factory.cc | 7 ++ reference_model/src/ops/type_conversion.cc | 79 ++++++++++++++-- reference_model/src/ops/type_conversion.h | 104 +++++++++++++++++++++ verif/conformance/test_select.py | 61 ++++++++---- verif/conformance/tosa_main_profile_ops_info.json | 76 +++++++++++++++ .../tosa_verif_conformance_generator.py | 4 - verif/generator/tosa_arg_gen.py | 33 +++++-- verif/generator/tosa_error_if.py | 23 +++-- 8 files changed, 345 insertions(+), 42 deletions(-) diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 0d56161..76cf666 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -469,26 +469,33 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); break; case Op_RESCALE: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index f266675..dbedbad 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -15,6 +15,7 @@ #include "type_conversion.h" #include "quant_util.h" +#include "arith_util.h" #include "template_types.h" #include #include "half.hpp" @@ -307,30 +308,88 @@ CastHelper::CastHelper() template CastHelper::CastHelper() { + // Integer data converted to fp16 (stored as fp32) fcn = [](InEigenType in) -> float { - half_float::half out = half_float::half_cast(in); // Cast to half_float - return half_float::half_cast(out); // Cast to float (underlying FP16 EigenType) + half_float::half h = half_float::half(in); + float out = half_float::half_cast(h); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp32 data converted to fp16 (stored as fp32) + fcn = [](float in) -> float { + float out = fpTrunc(in); // truncate required for conversion from higher precision + return out; + }; +} + +template +CastHelper::CastHelper() +{ + // Integer data converted to bf16 (stored as fp32) + fcn = [](InEigenType in) -> float { + float out = (float)in; // default cast to float is round_to_nearest_float() + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp32 data converted to bf16 (stored as fp32) + fcn = [](float in) -> float { + return fpTrunc(in); // truncate required for conversions from higher precision }; } template CastHelper::CastHelper() { - // Assuming InEigenType = float. + // fp16 data (stored as fp32) converted to integer fcn = [](float in) -> OutEigenType { - // Perform initial rounding in half-precision then cast back to float - half_float::half h = half_float::half_cast(in); + // Cast from float representation back to half_float before rounding + half_float::half h = half_float::half(in); h = std::round(h); - OutEigenType out = half_float::half_cast(h); + OutEigenType out = half_float::half_cast(h); out = std::max(out, OutMin); out = std::min(out, OutMax); return out; }; } +CastHelper::CastHelper() +{ + // No-op since fp16 values treated internally as their fp32 representation + fcn = [](float in) -> OutEigenType { + return in; + }; +} + +template +CastHelper::CastHelper() +{ + // bf16 data (stored as fp32) converted to integer + fcn = [](float in) -> OutEigenType { + OutEigenType out = std::round(in); + out = std::max(out, OutMin); + out = std::min(out, OutMax); + return out; + }; +} + +CastHelper::CastHelper() +{ + // No-op since bf16 values treated as truncated fp32 internally + fcn = [](InEigenType in) -> OutEigenType { + return in; + }; +} + template CastHelper::CastHelper() { + // Integer data converted to fp32 fcn = [](InEigenType in) -> float { float out = (OutEigenType)in; // default cast to float is round_to_nearest_float() return out; @@ -340,6 +399,7 @@ CastHelper::CastHelper() template CastHelper::CastHelper() { + // fp32 data converted to integer fcn = [](float in) -> OutEigenType { OutEigenType out = std::round(in); out = std::max(out, OutMin); @@ -356,26 +416,33 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index b0de30c..e2fc6e2 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -136,6 +136,76 @@ private: FcnType fcn; }; +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + static constexpr int32_t OutMin = GetQMin::value; + static constexpr int32_t OutMax = GetQMax::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template class CastHelper { @@ -153,6 +223,40 @@ private: FcnType fcn; }; +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template class CastHelper { diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py index 05f6db8..817d0b6 100644 --- a/verif/conformance/test_select.py +++ b/verif/conformance/test_select.py @@ -5,6 +5,7 @@ import argparse import itertools import json import logging +import re from pathlib import Path from typing import Any from typing import Dict @@ -129,20 +130,28 @@ class Operator: test_dir: Path, config: Dict[str, Dict[str, List[Any]]], negative=False, - exclude_types=None, ): """Initialise the selection parameters for an operator. - test_dir: the directory where the tests for all operators can be found + test_dir: the directory where the tests for all operators can + be found config: a dictionary with: - "params" - a dictionary with mappings of parameter names to the values - to select (a sub-set of expected values for instance) + "params" - a dictionary with mappings of parameter + names to the values to select (a sub-set of + expected values for instance) "permutes" - a list of parameter names to be permuted - "preselected" - a list of dictionaries containing parameter names and - pre-chosen values - "sparsity" - a dictionary of parameter names with a sparsity value - "errorifs" - list of ERRORIF case names to be selected (negative test) - negative: bool indicating if negative testing is being selected (ERRORIF tests) + "preselected" - a list of dictionaries containing + parameter names and pre-chosen values + "sparsity" - a dictionary of parameter names with a + sparsity value + "exclude_patterns" - a list of regex's whereby each + match will not be considered for selection. + Exclusion happens BEFORE test selection (i.e. + before permutes are applied). + "errorifs" - list of ERRORIF case names to be selected + (negative test) + negative: bool indicating if negative testing is being selected + (ERRORIF tests) EXAMPLE CONFIG: "params": { @@ -165,6 +174,9 @@ class Operator: "pad": "pad00" } ], + "exclude_patterns": [ + ".*_(i8|i16|i32|b)_out(i8|i16|i32|b)" + ], "errorifs": [ "InputZeroPointNotZero" ] @@ -187,23 +199,34 @@ class Operator: ) config["permutes"] = [] config["preselected"] = {} + config["exclude_patterns"] = [] self.params = config["params"] if "params" in config else {} self.permutes = config["permutes"] if "permutes" in config else [] self.sparsity = config["sparsity"] if "sparsity" in config else {} self.preselected = config["preselected"] if "preselected" in config else {} + self.exclude_patterns = ( + config["exclude_patterns"] if "exclude_patterns" in config else [] + ) self.non_permutes = [x for x in self.wks_param_names if x not in self.permutes] logger.info(f"{self.name}: permutes={self.permutes}") logger.info(f"{self.name}: non_permutes={self.non_permutes}") + logger.info(f"{self.name}: exclude_patterns={self.exclude_patterns}") + + self.test_paths = [] + excluded_paths = [] + for path in self.get_test_paths(test_dir, self.negative): + pattern_match = False + for pattern in self.exclude_patterns: + if re.fullmatch(pattern, path.name): + excluded_paths.append(path) + pattern_match = True + break + if not pattern_match: + self.test_paths.append(path) + + logger.debug(f"{self.name}: regex excluded paths={excluded_paths}") - if exclude_types is None: - exclude_types = [] - self.test_paths = [ - p - for p in self.get_test_paths(test_dir, self.negative) - # exclusion of types if requested - if self.path_params(p)["type"] not in exclude_types - ] if not self.test_paths: logger.error(f"no tests found for {self.name} in {test_dir}") logger.debug(f"{self.name}: paths={self.test_paths}") @@ -861,9 +884,7 @@ def main(): for op_name in Operator.registry: if not args.operators or op_name in args.operators: op_params = config[op_name] if op_name in config else {} - op = Operator.registry[op_name]( - args.test_dir, op_params, negative, exclude_types=["float"] - ) + op = Operator.registry[op_name](args.test_dir, op_params, negative) for test_path in op.select_tests(): print(test_path.resolve() if args.full_path else test_path.name) diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index 4cf2b57..f31fa71 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -216,6 +216,82 @@ "tosa-mi" ] }, + "cast": { + "group": "type_conversion", + "generator_negative_dim_range": "1,10", + "generator_args": [ + [ + "--target-dtype", + "fp32", + "--target-dtype", + "fp16", + "--target-dtype", + "bf16", + "--target-dtype", + "int8", + "--target-dtype", + "int16", + "--target-dtype", + "int32", + "--fp-values-range", + "-2.0,2.0", + "--tensor-dim-range", + "16,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3" + ], + [ + "--target-dtype", + "fp32", + "--target-dtype", + "fp16", + "--target-dtype", + "bf16", + "--target-dtype", + "int8", + "--target-dtype", + "int16", + "--target-dtype", + "int32", + "--fp-values-range", + "-2.0,2.0", + "--tensor-dim-range", + "1,16", + "--target-rank", + "4", + "--target-rank", + "5" + ], + [ + "--target-dtype", + "fp16", + "--target-shape", + "1,1,1,65533,1", + "--target-shape", + "2,65538,1,1" + ] + ], + "params": { + "shape": [], + "type": [], + "output_type": [] + }, + "permutes": [ + "shape", + "type", + "output_type" + ], + "exclude_patterns": [ + ".*_(i8|i16|i32|b)_out(i8|i16|i32|b)" + ], + "profile": [ + "tosa-mi" + ] + }, "ceil": { "group": "ew_unary", "generator_args": [ diff --git a/verif/conformance/tosa_verif_conformance_generator.py b/verif/conformance/tosa_verif_conformance_generator.py index 817b242..4971fb0 100644 --- a/verif/conformance/tosa_verif_conformance_generator.py +++ b/verif/conformance/tosa_verif_conformance_generator.py @@ -34,13 +34,11 @@ PROFILE_OPS_INFO = { "tosa-bi": { "operator_test_params": "tosa_base_profile_ops_info.json", "framework_tests": "tosa_base_profile_framework_ops_info.json", - "exclude_types": [], }, "tosa-mi": { # Note: This is just the extra tests not in the base profile! "operator_test_params": "tosa_main_profile_ops_info.json", "framework_tests": "tosa_main_profile_framework_ops_info.json", - "exclude_types": [], }, } PROFILES_ALL = "all" @@ -164,7 +162,6 @@ def build_op_tests(args, profile, operator, test_params): def _check_to_include_test(profile, test_name, exclude_negative_tests=False): """Check test name for exclusions, return False to indicate excluded.""" excludes = ["ERRORIF"] if exclude_negative_tests else [] - excludes.extend(PROFILE_OPS_INFO[profile]["exclude_types"]) for exclusion in excludes: if f"_{exclusion}_" in test_name: @@ -338,7 +335,6 @@ def get_op_tests_selection( op_build_dir, op_params, negative, - exclude_types=PROFILE_OPS_INFO[profile]["exclude_types"], ) except KeyError: logger.error(f"{operator} operator is not supported by test_select") diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index fed91f6..05a7d2b 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1445,19 +1445,40 @@ class TosaArgGen: if error_name == ErrorIf.WrongOutputType: dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype) elif inDtype == DType.INT8: - dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32] + dtypeList = [ + DType.BOOL, + DType.INT16, + DType.INT32, + DType.FP16, + DType.BF16, + DType.FP32, + ] elif inDtype == DType.INT16: - dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32] + dtypeList = [ + DType.BOOL, + DType.INT8, + DType.INT32, + DType.FP16, + DType.BF16, + DType.FP32, + ] elif inDtype == DType.INT32: - dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32] + dtypeList = [ + DType.BOOL, + DType.INT8, + DType.INT16, + DType.FP16, + DType.BF16, + DType.FP32, + ] elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: - dtypeList = [DType.INT8, DType.INT16, DType.INT32] + dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32] elif inDtype == DType.BF16: - dtypeList = [DType.INT8, DType.INT16, DType.INT32] + dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32] elif inDtype == DType.FP32: - dtypeList = [DType.INT8, DType.INT16, DType.INT32] + dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output type for incorrect input type dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32] diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 40c5d13..93f975d 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -314,12 +314,14 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]: - outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32] + if input_dtype in [DType.BOOL, DType.FP32]: + outputDType = [DType.BOOL, DType.INT48, DType.FP32] + elif input_dtype in [DType.FP16, DType.BF16]: + outputDType = [DType.BOOL, DType.INT48] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: - assert True, f"input_dtype ({input_dtype}) not supported" + assert False, f"input_dtype ({input_dtype}) not supported" return outputDType @@ -538,15 +540,24 @@ class TosaErrorValidator: ) or ( input_dtype == DType.FP16 - and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + and output_dtype + not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32] ) or ( input_dtype == DType.BF16 - and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + and output_dtype + not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32] ) or ( input_dtype == DType.FP32 - and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + and output_dtype + not in [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP16, + DType.BF16, + ] ) ): error_result = True -- cgit v1.2.1