aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
Diffstat (limited to 'verif')
-rw-r--r--verif/checker/tosa_result_checker.py22
-rw-r--r--verif/generator/tosa_arg_gen.py10
-rw-r--r--verif/generator/tosa_error_if.py35
-rw-r--r--verif/generator/tosa_test_gen.py80
-rw-r--r--verif/generator/tosa_utils.py45
-rw-r--r--verif/generator/tosa_verif_build_tests.py4
-rw-r--r--verif/tests/test_tosa_refmodel.py16
7 files changed, 191 insertions, 21 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 8ae3218..b7a76b6 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -9,6 +9,7 @@ from enum import unique
from pathlib import Path
import numpy as np
+from generator.tosa_utils import float32_is_valid_bfloat16
##################################
color_printing = True
@@ -63,7 +64,12 @@ TestResultErrorStr = [
def test_check(
- reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
+ reference,
+ result,
+ test_name="test",
+ quantize_tolerance=0,
+ float_tolerance=1e-3,
+ misc_checks=[],
):
"""Check if the result is the same as the expected reference."""
if not os.path.isfile(reference):
@@ -111,6 +117,20 @@ def test_check(
)
return (TestResult.MISMATCH, 0.0, msg)
+ # Perform miscellaneous checks
+ if "bf16" in misc_checks:
+ # Ensure floats are valid bfloat16 values
+ test_res_is_bf16 = all([float32_is_valid_bfloat16(f) for f in test_result.flat])
+ ref_res_is_bf16 = all(
+ [float32_is_valid_bfloat16(f) for f in reference_result.flat]
+ )
+ if not (test_res_is_bf16 and ref_res_is_bf16):
+ msg = (
+ "All output values must be valid bfloat16. "
+ "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
+ )
+ return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
# for quantized test, allow +-(quantize_tolerance) error
if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 0203513..932ad55 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@ class TosaTensorValuesGen:
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] in (DType.FP16, DType.FP32):
+ if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -1130,6 +1130,8 @@ class TosaArgGen:
accum_dtypes = [DType.INT48]
elif dtype == DType.FP16:
accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.BF16:
+ accum_dtypes = [DType.FP32]
elif dtype == DType.FP32:
accum_dtypes = [DType.FP32]
elif error_name is None:
@@ -1304,7 +1306,7 @@ class TosaArgGen:
accum_dtypes = [DType.INT32]
elif dtype == DType.FP16:
accum_dtypes = [DType.FP16, DType.FP32]
- elif dtype == DType.FP32:
+ elif dtype == DType.BF16 or dtype == DType.FP32:
accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
@@ -1417,6 +1419,8 @@ class TosaArgGen:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ elif inDtype == DType.BF16:
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP32:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
@@ -1826,6 +1830,8 @@ class TosaArgGen:
outputDTypeList = [DType.INT48]
elif dtype == DType.FP16:
outputDTypeList = [DType.FP16]
+ elif dtype == DType.BF16:
+ outputDTypeList = [DType.BF16]
elif dtype == DType.FP32:
outputDTypeList = [DType.FP32]
elif error_name == ErrorIf.WrongInputType:
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index abe1a97..a850699 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -158,6 +158,15 @@ class TosaErrorIfArgGen:
DType.INT48,
DType.FP32,
)
+ elif dtype == DType.BF16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ )
elif dtype == DType.FP32:
incorrect_types = (
DType.INT4,
@@ -299,8 +308,8 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]:
- outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32]
+ if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
@@ -425,6 +434,7 @@ class TosaErrorValidator:
and output_dtype != DType.INT48
)
or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
+ or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
@@ -442,25 +452,29 @@ class TosaErrorValidator:
input_dtype == DType.FP16
and output_dtype not in (DType.FP16, DType.FP32)
)
+ or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
- input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ input_dtype
+ in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
and output_dtype != DType.INT32
):
error_result = True
elif op["op"] == Op.MUL:
if (
- input_dtype not in (DType.FP16, DType.FP32)
+ input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
and output_dtype != DType.INT32
):
error_result = True
elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
error_result = True
+ elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
+ error_result = True
elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
error_result = True
@@ -489,6 +503,7 @@ class TosaErrorValidator:
DType.INT32,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -500,6 +515,7 @@ class TosaErrorValidator:
DType.INT32,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -511,6 +527,7 @@ class TosaErrorValidator:
DType.INT16,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -518,6 +535,10 @@ class TosaErrorValidator:
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
or (
+ input_dtype == DType.BF16
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ )
+ or (
input_dtype == DType.FP32
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
@@ -537,6 +558,8 @@ class TosaErrorValidator:
and output_dtype != DType.INT48
or input_dtype == DType.FP16
and output_dtype not in (DType.FP16, DType.FP32)
+ or input_dtype == DType.BF16
+ and output_dtype != DType.FP32
or input_dtype == DType.FP32
and output_dtype != DType.FP32
):
@@ -2316,12 +2339,14 @@ class TosaInvalidValidator:
not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
+ and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
)
elif mode == ResizeMode.NEAREST:
# Invalid output data type / Invalid input datatype
return (input_dtype != output_dtype) or (
- input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ input_dtype
+ not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
)
else:
# Invalid resize mode
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 78d86cd..95e06ed 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -16,6 +16,7 @@ from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
+from generator.tosa_utils import vect_f32_to_bf16
from tosa.DType import DType
from tosa.Op import Op
@@ -84,6 +85,10 @@ class TosaTestGen:
)
elif dtype == DType.FP16:
return np.float16(self.rng.random(size=shape))
+ elif dtype == DType.BF16:
+ f32_tensor = np.float32(self.rng.random(size=shape))
+ # Floor the last 16 bits of each f32 value
+ return np.float32(vect_f32_to_bf16(f32_tensor))
elif dtype == DType.FP32:
return np.float32(self.rng.random(size=shape))
else:
@@ -134,6 +139,9 @@ class TosaTestGen:
elif dtype == DType.FP16:
rand_f32 = self.rng.random()
return np.float16(rand_f32)
+ elif dtype == DType.BF16:
+ rand_f32 = self.rng.random()
+ return vect_f32_to_bf16(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
# TOSA specific INT4 weight range from -7 to 7
@@ -324,7 +332,7 @@ class TosaTestGen:
# Special for multiply:
# Force the result to INT32 for INT types
- if a.dtype not in (DType.FP16, DType.FP32):
+ if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
result_tens.setDtype(DType.INT32)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -1043,7 +1051,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- if a.dtype in (DType.FP16, DType.FP32):
+ if a.dtype in (DType.FP16, DType.BF16, DType.FP32):
attr.ClampAttribute(0, 0, min_val, max_val)
else:
attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1859,7 +1867,7 @@ class TosaTestGen:
op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
- if a.dtype in (DType.FP32, DType.FP16, DType.INT32):
+ if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
then_op, else_op = Op.ADD, Op.SUB
elif a.dtype in (DType.INT8, DType.INT16):
then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2398,7 +2406,7 @@ class TosaTestGen:
# if not specified, defaults to (1, 4)
# 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
# 'types': array of datatypes to be tested
- TYPE_FP = [DType.FP32, DType.FP16]
+ TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
TYPE_INT_FP = [
@@ -2406,13 +2414,20 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.FP16,
+ DType.BF16,
DType.FP32,
] # Excludes INT4
TYPE_BOOL = [DType.BOOL]
- TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32] # floating-types and INT32
+ TYPE_FI32 = [
+ DType.FP32,
+ DType.FP16,
+ DType.BF16,
+ DType.INT32,
+ ] # floating-types and INT32
TYPE_FIB = [
DType.FP16,
+ DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
@@ -2421,7 +2436,7 @@ class TosaTestGen:
]
TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
@@ -2430,6 +2445,7 @@ class TosaTestGen:
[DType.INT16, DType.INT8, DType.INT48],
[DType.FP16, DType.FP16, DType.FP16],
[DType.FP16, DType.FP16, DType.FP32],
+ [DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
]
@@ -3448,7 +3464,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReduceSum,
TosaArgGen.agAxis,
),
- "types": (DType.FP16, DType.FP32, DType.INT32),
+ "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -3635,7 +3651,14 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
None,
),
- "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32),
+ "types": (
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -3676,7 +3699,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agResize,
),
- "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32),
+ "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
),
@@ -3712,6 +3735,7 @@ class TosaTestGen:
),
"types": (
DType.FP16,
+ DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
@@ -3842,6 +3866,8 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP16,
+ DType.BF16,
DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
@@ -3872,6 +3898,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3900,6 +3928,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3929,6 +3959,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
outputDType = rng.choice(wrong_dtypes)
else:
@@ -3955,6 +3987,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3987,6 +4021,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -4189,6 +4225,7 @@ class OutputShaper:
DType.INT48,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4226,6 +4263,8 @@ class OutputShaper:
DType.INT16,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -4234,8 +4273,12 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
)
- elif a.dtype == DType.FP32 or a.dtype == DType.FP16:
+ elif (
+ a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
+ ):
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -4278,6 +4321,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
}
wrong_dtypes = list(all_dtypes - set([input1.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4306,6 +4351,7 @@ class OutputShaper:
DType.INT48,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4329,6 +4375,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4347,6 +4395,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4383,6 +4433,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4411,6 +4463,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4435,6 +4489,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4462,6 +4518,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4483,6 +4541,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes.remove(output_dtype)
output_dtype = rng.choice(wrong_dtypes)
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 104d9bb..d79ab3c 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -1,5 +1,9 @@
# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import struct
+import sys
+
+import numpy as np
from tosa.DType import DType
# Maximum dimension size for output and inputs for RESIZE
@@ -15,6 +19,7 @@ DTYPE_ATTRIBUTES = {
DType.INT32: {"str": "i32", "width": 32},
DType.INT48: {"str": "i48", "width": 48},
DType.FP16: {"str": "f16", "width": 16},
+ DType.BF16: {"str": "bf16", "width": 16},
DType.FP32: {"str": "f32", "width": 32},
}
@@ -125,7 +130,11 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.FP32,
DType.FP16,
)
- elif input_dtype == DType.FP32 or input_dtype == DType.FP16:
+ elif (
+ input_dtype == DType.FP32
+ or input_dtype == DType.FP16
+ or input_dtype == DType.BF16
+ ):
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -134,3 +143,37 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT48,
)
return rng.choice(a=incorrect_types)
+
+
+def float32_is_valid_bfloat16(f):
+ """Return True if float value is valid bfloat16."""
+ f32_bits = get_float32_bitstring(f)
+ return f32_bits[16:] == "0" * 16
+
+
+def get_float32_bitstring(f):
+ """Return a big-endian string of bits representing a 32 bit float."""
+ f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
+ return f"{f32_bits_as_int:032b}"
+
+
+def float32_to_bfloat16(f):
+ """Turns fp32 value into bfloat16 by flooring.
+
+ Floors the least significant 16 bits of the input
+ fp32 value and returns this valid bfloat16 representation as fp32.
+ For simplicity during bit-wrangling, ignores underlying system
+ endianness and interprets as big-endian.
+ Returns a bf16-valid float following system's native byte order.
+ """
+ f32_bits = get_float32_bitstring(f)
+ f32_floored_bits = f32_bits[:16] + "0" * 16
+
+ # Assume sys.byteorder matches system's underlying float byteorder
+ fp_bytes = int(f32_floored_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0] # native byteorder
+
+
+vect_f32_to_bf16 = np.vectorize(
+ float32_to_bfloat16, otypes=(np.float32,)
+) # NumPy vectorize: applies function to vector faster than looping
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 2fafacb..ab78b1a 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -5,6 +5,7 @@ import re
from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
+from serializer.tosa_serializer import DTypeNames
# Used for parsing a comma-separated list of integers in a string
@@ -150,13 +151,14 @@ def parseArgs(argv):
help="Create tests with a particular input tensor rank",
)
+ # Used for parsing a comma-separated list of integers in a string
parser.add_argument(
"--target-dtype",
dest="target_dtypes",
action="append",
default=None,
type=lambda x: dtype_str_to_val(x),
- help="Create test with a particular DType (may be repeated)",
+ help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)",
)
parser.add_argument(
diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py
index b608fd8..50ff1ab 100644
--- a/verif/tests/test_tosa_refmodel.py
+++ b/verif/tests/test_tosa_refmodel.py
@@ -47,6 +47,7 @@ REF_MODEL_TYPE_TO_OUT = {
"int32": "i32",
"fp32": "f32",
"fp16": "f16",
+ "bf16": "bf16",
}
@@ -127,11 +128,13 @@ TEST_PARAMS = [
("abs", "int32", 1),
("abs", "fp32", 1),
("abs", "fp16", 1),
+ ("abs", "bf16", 1),
("negate", "int8", 1),
("negate", "int16", 1),
("negate", "int32", 1),
("negate", "fp32", 1),
("negate", "fp16", 1),
+ ("negate", "bf16", 1),
# One test per axis (shape dimensions)
("concat", "bool", SHAPE_DIMS),
("concat", "int8", SHAPE_DIMS),
@@ -139,6 +142,7 @@ TEST_PARAMS = [
("concat", "int32", SHAPE_DIMS),
("concat", "fp32", SHAPE_DIMS),
("concat", "fp16", SHAPE_DIMS),
+ ("concat", "bf16", SHAPE_DIMS),
]
@@ -165,6 +169,9 @@ def test_refmodel_simple_op(tosaTest):
# Generate TOSA test(s) (mostly should be single test)
test_dirs = tosaTest.create_test()
+ # Indicate miscellaneous checks to run in tosa_check
+ misc_checks = []
+
for test_dir in test_dirs:
# Run ref model
desc_file = test_dir / TEST_DESC_FILENAME
@@ -227,8 +234,15 @@ def test_refmodel_simple_op(tosaTest):
np.save(str(result_file), result)
assert result_file.is_file()
+ # Ensure valid bf16
+ if tosaTest.ref_model_type == "bf16":
+ misc_checks.append("bf16")
+
# Check Numpy result versus refmodel
check_result, tolerance, msg = tosa_check(
- str(result_file), str(ofm_file), test_name=test_dir.name
+ str(result_file),
+ str(ofm_file),
+ test_name=test_dir.name,
+ misc_checks=misc_checks,
)
assert check_result == TosaResult.PASS