aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorJames Ward <james.ward@arm.com>2022-08-12 20:48:56 +0100
committerJames Ward <james.ward@arm.com>2022-10-11 11:56:02 +0100
commit8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch)
treefea519246b698eb944b9d58537fc90bc30481d11 /verif
parentba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff)
downloadreference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6 Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'verif')
-rw-r--r--verif/checker/tosa_result_checker.py4
-rw-r--r--verif/generator/tosa_arg_gen.py230
-rw-r--r--verif/generator/tosa_error_if.py74
-rw-r--r--verif/generator/tosa_test_gen.py298
-rw-r--r--verif/generator/tosa_utils.py39
-rw-r--r--verif/tests/test_tosa_result_checker.py2
6 files changed, 444 insertions, 203 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 66864c2..8ae3218 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -147,14 +147,14 @@ def test_check(
tolerance = 0.0
# Fall-through to below to add failure values
- elif reference_result.dtype == np.float32:
+ # TODO: update for fp16 tolerance
+ elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
tolerance = float_tolerance
if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
return (TestResult.PASS, tolerance, "")
msg = "Float result does not match within tolerance of {}".format(tolerance)
# Fall-through to below to add failure values
-
else:
print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
msg = "Unsupported results type: {}".format(reference_result.dtype)
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a65e220..69968d3 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2,10 +2,13 @@
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
+import warnings
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
+from generator.tosa_utils import get_accum_dtype_from_tgTypes
+from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
@@ -773,7 +776,7 @@ class TosaTensorValuesGen:
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] == DType.FLOAT:
+ if dtypeList[0] in (DType.FP16, DType.FLOAT):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -982,7 +985,7 @@ class TosaArgGen:
return axes
@staticmethod
- def agConv(testGen, opName, shapeList, dtype, error_name=None):
+ def agConv(testGen, opName, shapeList, dtypes, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
@@ -990,6 +993,8 @@ class TosaArgGen:
# determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
k = [int(x) for x in opName.split("_")[-1].split("x")]
+ accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+
# Check the rank
rank = 5 if opName.startswith("conv3d") else 4
if error_name != ErrorIf.WrongRank:
@@ -1089,12 +1094,13 @@ class TosaArgGen:
):
arg_list.append(
(
- "st{}_pad{}_dilat{}".format(
+ "acc{}_st{}_pad{}_dilat{}".format(
+ testGen.typeStr(accum_dtype),
"".join([str(x) for x in s]),
"".join([str(x) for x in p]),
"".join([str(x) for x in d]),
),
- [s, p, d],
+ [accum_dtype, s, p, d],
)
)
n += 1
@@ -1102,12 +1108,55 @@ class TosaArgGen:
return arg_list
@staticmethod
- def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
+ def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
+
+ if isinstance(dtypes, list) or isinstance(dtypes, tuple):
+ input_dtype = dtypes[0]
+ else:
+ input_dtype = dtypes
+
+ if error_name == ErrorIf.WrongOutputType:
+ accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ accum_dtype = DType.INT32
+ else:
+ accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+
+ return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]
+
+ @staticmethod
+ def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
+ # Get valid accumulate type(s)
+ if dtype == DType.INT8:
+ accum_dtypes = [DType.INT32]
+ elif dtype == DType.INT16:
+ accum_dtypes = [DType.INT48]
+ elif dtype == DType.FP16:
+ accum_dtypes = [DType.FP16, DType.FLOAT]
+ elif dtype == DType.FLOAT:
+ accum_dtypes = [DType.FLOAT]
+ elif error_name is None:
+ assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
+
+ if error_name == ErrorIf.WrongOutputType:
+ # Get incorrect output dtype for ErrorIf case
+ accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ accum_dtypes = [DType.INT32]
+
+ return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]
+
+ @staticmethod
+ def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
filter_shape = shapeList[1]
+ accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+
# Must be rank 4
if error_name != ErrorIf.WrongRank:
assert len(ifm_shape) == 4
@@ -1169,12 +1218,13 @@ class TosaArgGen:
os = [ifm_shape[0], oh, ow, filter_shape[0]]
arg_list.append(
(
- "st{}_pad{}_os{}".format(
+ "acc{}_st{}_pad{}_os{}".format(
+ testGen.typeStr(accum_dtype),
"".join([str(x) for x in s]),
"".join([str(x) for x in p]),
"x".join([str(x) for x in os]),
),
- [s, p, os],
+ [accum_dtype, s, p, os],
)
)
n += 1
@@ -1199,18 +1249,38 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype == DType.FLOAT:
+ elif dtype in (DType.FP16, DType.FLOAT):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
return []
for paddings in shape_pad_values:
- name = "pad"
- for r in range(rank):
- before, after = paddings[r]
- name = f"{name}{before}{after}"
- arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
+ paddings = list(paddings)
+ args_valid = True
+
+ if error_name == ErrorIf.PadSmallerZero:
+ # Prevent negative output shapes while ensuring still testing for negative padding
+ for i in range(rank):
+ dim_after_padding = (
+ paddings[i][0] + paddings[i][1] + shapeList[0][i]
+ )
+ if dim_after_padding < 1:
+ paddings[i] = (0, 0)
+ if all([p > -1 for p in paddings[i]]):
+ args_valid = False
+
+ if args_valid:
+ name = "pad"
+ for r in range(rank):
+ before, after = paddings[r]
+ name = f"{name}{before}{after}"
+ arg_list.append(
+ (name, [np.array(paddings), pad_const_int, pad_const_fp])
+ )
+
+ if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
+ warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
return arg_list
@@ -1232,6 +1302,21 @@ class TosaArgGen:
k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
kernels = {x for x in itertools.product(*([k_vals] * 2))}
+ if opName == "max_pool2d":
+ accum_dtypes = [None] # max_pool has no accumulate dtype
+ elif dtype == DType.INT8 or dtype == DType.INT16:
+ accum_dtypes = [DType.INT32]
+ elif dtype == DType.FP16:
+ accum_dtypes = [DType.FP16, DType.FLOAT]
+ elif dtype == DType.FLOAT:
+ accum_dtypes = [DType.FLOAT]
+ elif error_name is None:
+ assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
+ else:
+ # Set to something for the ErrorIf case which has
+ # incorrect input data-type
+ accum_dtypes = [DType.INT32]
+
if testGen.args.oversize:
# add some oversize argument values
bigStride = 7
@@ -1252,63 +1337,70 @@ class TosaArgGen:
sparsity_factor = 2 if error_name else 500
sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
+ arg_str = (
+ "acc{}_st{}_kern{}_pad{}"
+ if accum_dtypes[0] is not None
+ else "st{}_kern{}_pad{}"
+ )
+
+ def get_arg_list_element(accum, stride, pad, kern):
+ # Return tuple containing the formatted argument string and
+ # the corresponding argument values
+ arg_str_elems = [
+ "".join([str(x) for x in stride]),
+ "".join([str(x) for x in kern]),
+ "".join([str(x) for x in pad]),
+ ]
+ # Note: different order to string
+ arg_val_elems = [stride, pad, kern]
+
+ if accum is not None:
+ arg_str_elems.insert(0, testGen.typeStr(accum))
+ arg_val_elems.insert(0, accum)
+ return (arg_str.format(*arg_str_elems), arg_val_elems)
+
n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- for k in sorted(list(kernels)):
- if error_name in [
- ErrorIf.StrideSmallerOne,
- ErrorIf.KernelSmallerOne,
- ErrorIf.PadSmallerZero,
- ErrorIf.PadLargerEqualKernel,
- ]:
- sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
- testGen, error_name, s, p, k
- )
- if None not in [sNew, pNew, kNew] and n % sparsity == 0:
- arg_list.append(
- (
- "st{}_kern{}_pad{}".format(
- "".join([str(x) for x in sNew]),
- "".join([str(x) for x in kNew]),
- "".join([str(x) for x in pNew]),
- ),
- [sNew, pNew, kNew],
- )
+ for a in accum_dtypes:
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ for k in sorted(list(kernels)):
+ if error_name in [
+ ErrorIf.StrideSmallerOne,
+ ErrorIf.KernelSmallerOne,
+ ErrorIf.PadSmallerZero,
+ ErrorIf.PadLargerEqualKernel,
+ ]:
+ sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
+ testGen, error_name, s, p, k
)
- elif (
- n % sparsity == 0
- # padding must not exceed the kernel size
- and p[0] < k[0]
- and p[1] < k[0]
- and p[2] < k[1]
- and p[3] < k[1]
- # the padded shape must exceed the kernel size
- and (shape[1] + p[0] + p[1]) > k[0]
- and (shape[2] + p[2] + p[3]) > k[1]
- ):
- remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
- remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
- if (
- # the parameters must produce integer exact output
- error_name != ErrorIf.PoolingOutputShapeNonInteger
- and remainder_h == 0
- and remainder_w == 0
- ) or (
- error_name == ErrorIf.PoolingOutputShapeNonInteger
- and (remainder_h != 0 or remainder_w != 0)
+ if None not in [sNew, pNew, kNew] and n % sparsity == 0:
+ arg_vals = [a, sNew, pNew, kNew]
+ arg_list.append(get_arg_list_element(*arg_vals))
+ elif (
+ n % sparsity == 0
+ # padding must not exceed the kernel size
+ and p[0] < k[0]
+ and p[1] < k[0]
+ and p[2] < k[1]
+ and p[3] < k[1]
+ # the padded shape must exceed the kernel size
+ and (shape[1] + p[0] + p[1]) > k[0]
+ and (shape[2] + p[2] + p[3]) > k[1]
):
- arg_list.append(
- (
- "st{}_kern{}_pad{}".format(
- "".join([str(x) for x in s]),
- "".join([str(x) for x in k]),
- "".join([str(x) for x in p]),
- ),
- [s, p, k],
- )
- )
- n += 1
+ remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
+ remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
+ if (
+ # the parameters must produce integer exact output
+ error_name != ErrorIf.PoolingOutputShapeNonInteger
+ and remainder_h == 0
+ and remainder_w == 0
+ ) or (
+ error_name == ErrorIf.PoolingOutputShapeNonInteger
+ and (remainder_h != 0 or remainder_w != 0)
+ ):
+ arg_vals = [a, s, p, k]
+ arg_list.append(get_arg_list_element(*arg_vals))
+ n += 1
return arg_list
@@ -1327,6 +1419,8 @@ class TosaArgGen:
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ elif inDtype == DType.FP16:
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FLOAT:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
@@ -1734,6 +1828,8 @@ class TosaArgGen:
outputDTypeList = [DType.INT32]
elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
outputDTypeList = [DType.INT48]
+ elif dtype == DType.FP16:
+ outputDTypeList = [DType.FP16]
elif dtype == DType.FLOAT:
outputDTypeList = [DType.FLOAT]
elif error_name == ErrorIf.WrongInputType:
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index f9a00f9..a766803 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -120,6 +120,7 @@ class TosaErrorIfArgGen:
DType.INT32,
DType.INT48,
DType.FLOAT,
+ DType.FP16,
)
elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
incorrect_types = (
@@ -128,6 +129,7 @@ class TosaErrorIfArgGen:
DType.INT32,
DType.INT48,
DType.FLOAT,
+ DType.FP16,
)
elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
incorrect_types = (
@@ -136,6 +138,7 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT48,
DType.FLOAT,
+ DType.FP16,
)
elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
incorrect_types = (
@@ -144,6 +147,16 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT32,
DType.FLOAT,
+ DType.FP16,
+ )
+ elif dtype == DType.FP16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
)
elif dtype == DType.FLOAT:
incorrect_types = (
@@ -152,6 +165,7 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP16,
)
outputDType = testGen.rng.choice(a=incorrect_types)
@@ -285,8 +299,8 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FLOAT]:
- outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
+ if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
@@ -400,6 +414,7 @@ class TosaErrorValidator:
and input_dtype == DType.INT16
and output_dtype != DType.INT48
)
+ or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
@@ -413,19 +428,28 @@ class TosaErrorValidator:
if (
(input_dtype == DType.INT8 and output_dtype != DType.INT32)
or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
+ or (
+ input_dtype == DType.FP16
+ and output_dtype not in (DType.FP16, DType.FLOAT)
+ )
or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
- input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
+ input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
and output_dtype != DType.INT32
):
error_result = True
elif op["op"] == Op.MUL:
- if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
+ if (
+ input_dtype not in (DType.FP16, DType.FLOAT)
+ and output_dtype != DType.INT32
+ ):
+ error_result = True
+ elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
error_result = True
elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
error_result = True
@@ -449,17 +473,39 @@ class TosaErrorValidator:
or (
input_dtype == DType.INT8
and output_dtype
- not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ not in [
+ DType.BOOL,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ DType.FP16,
+ ]
)
or (
input_dtype == DType.INT16
and output_dtype
- not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ not in [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT32,
+ DType.FLOAT,
+ DType.FP16,
+ ]
)
or (
input_dtype == DType.INT32
and output_dtype
- not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ not in [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.FLOAT,
+ DType.FP16,
+ ]
+ )
+ or (
+ input_dtype == DType.FP16
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
or (
input_dtype == DType.FLOAT
@@ -479,6 +525,8 @@ class TosaErrorValidator:
and output_dtype != DType.INT32
or input_dtype == DType.INT16
and output_dtype != DType.INT48
+ or input_dtype == DType.FP16
+ and output_dtype not in (DType.FP16, DType.FLOAT)
or input_dtype == DType.FLOAT
and output_dtype != DType.FLOAT
):
@@ -2257,12 +2305,13 @@ class TosaInvalidValidator:
return (
not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
+ and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
)
elif mode == ResizeMode.NEAREST:
# Invalid output data type / Invalid input datatype
return (input_dtype != output_dtype) or (
- input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
+ input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
)
else:
# Invalid resize mode
@@ -2276,8 +2325,11 @@ class TosaInvalidValidator:
input_shape = inputShapes[0]
args = kwargs["args"]
- strides = args[0]
- padding = args[1]
+
+ # MaxPool2D has no accum_dtype arg
+ stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
+ strides = args[stride_idx]
+ padding = args[pad_idx]
if opName.endswith("pool2d"):
# avg_pool2d, max_pool2d
@@ -2365,7 +2417,7 @@ class TosaInvalidValidator:
@staticmethod
def ivNonPositiveOutputShape(**kwargs):
args = kwargs["args"]
- output_shape = args[2]
+ output_shape = args[3]
if output_shape[1] <= 0 or output_shape[2] <= 0:
# Negative output shape
return True
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b76b656..9ff6ec5 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -81,6 +81,8 @@ class TosaTestGen:
return np.int64(
self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
)
+ elif dtype == DType.FP16:
+ return np.float16(self.rng.random(size=shape))
elif dtype == DType.FLOAT:
return np.float32(self.rng.random(size=shape))
else:
@@ -128,6 +130,9 @@ class TosaTestGen:
def getRandNumberDType(self, dtype):
if dtype == DType.FLOAT:
return self.rng.random()
+ elif dtype == DType.FP16:
+ rand_f32 = self.rng.random()
+ return np.float16(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
# TOSA specific INT4 weight range from -7 to 7
@@ -178,13 +183,15 @@ class TosaTestGen:
return "i32"
elif t == DType.INT48:
return "i48"
+ elif t == DType.FP16:
+ return "f16"
elif t == DType.FLOAT:
return "float"
else:
raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
def typeWidth(self, t):
- """Get the datatype width for integer types"""
+ """Get the datatype width for data types"""
if t == DType.INT4:
return 4
elif t == DType.INT8:
@@ -199,6 +206,8 @@ class TosaTestGen:
return 32
elif t == DType.INT48:
return 48
+ elif t == DType.FP16:
+ return 16
elif t == DType.FLOAT:
return 32
elif t == DType.BOOL:
@@ -346,7 +355,7 @@ class TosaTestGen:
# Special for multiply:
# Force the result to INT32 for INT types
- if a.dtype != DType.FLOAT:
+ if a.dtype not in (DType.FP16, DType.FLOAT):
result_tens.setDtype(DType.INT32)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -533,6 +542,7 @@ class TosaTestGen:
self,
op,
input,
+ accum_dtype,
stride,
pad,
kernel,
@@ -585,17 +595,43 @@ class TosaTestGen:
qinfo = [0, 0]
attr = ts.TosaSerializerAttribute()
- attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
+ attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
+ def build_maxpool2d(
+ self,
+ op,
+ input,
+ stride,
+ pad,
+ kernel,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
+ ):
+ # Same as build_pool2d but manually sets accum_dtype value
+ # (maxpool has no accum_dtype)
+ return self.build_pool2d(
+ op,
+ input,
+ DType.UNKNOWN,
+ stride,
+ pad,
+ kernel,
+ validator_fcns,
+ error_name,
+ qinfo,
+ )
+
def build_conv2d(
self,
op,
ifm,
filter,
bias,
+ accum_dtype,
strides,
padding,
dilations,
@@ -605,7 +641,15 @@ class TosaTestGen:
):
assert len(padding) == 4
result_tens = OutputShaper.conv2dOp(
- self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+ self.ser,
+ self.rng,
+ ifm,
+ filter,
+ accum_dtype,
+ strides,
+ padding,
+ dilations,
+ error_name,
)
# Ensure new output type has correct qinfo
@@ -648,7 +692,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -659,6 +703,7 @@ class TosaTestGen:
ifm,
filter,
bias,
+ accum_dtype,
strides,
padding,
dilations,
@@ -668,7 +713,15 @@ class TosaTestGen:
):
assert len(padding) == 6
result_tens = OutputShaper.conv3dOp(
- self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+ self.ser,
+ self.rng,
+ ifm,
+ filter,
+ accum_dtype,
+ strides,
+ padding,
+ dilations,
+ error_name,
)
# Ensure new output type has correct qinfo
@@ -711,7 +764,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -722,6 +775,7 @@ class TosaTestGen:
ifm,
filter,
bias,
+ accum_dtype,
stride,
out_pad,
output_shape,
@@ -731,7 +785,7 @@ class TosaTestGen:
):
assert len(out_pad) == 4
result_tens = OutputShaper.transposeConv2DOp(
- self.ser, self.rng, ifm, output_shape, error_name
+ self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
)
# Ensure new output type has correct qinfo
@@ -773,7 +827,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
+ attr.TransposeConvAttribute(
+ out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype
+ )
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -784,6 +840,7 @@ class TosaTestGen:
ifm,
filter,
bias,
+ accum_dtype,
strides,
padding,
dilations,
@@ -792,7 +849,15 @@ class TosaTestGen:
qinfo=None,
):
result_tens = OutputShaper.depthwiseConv2dOp(
- self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+ self.ser,
+ self.rng,
+ ifm,
+ filter,
+ accum_dtype,
+ strides,
+ padding,
+ dilations,
+ error_name,
)
# Ensure new output type has correct qinfo
@@ -835,16 +900,24 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_fully_connected(
- self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
+ self,
+ op,
+ ifm,
+ filter,
+ bias,
+ accum_dtype,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
):
result_tens = OutputShaper.fullyConnectedOp(
- self.ser, self.rng, ifm, filter, error_name
+ self.ser, self.rng, ifm, filter, accum_dtype, error_name
)
# Invalidate Input/Output list for error if checks.
@@ -871,17 +944,22 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ accum_dtype=accum_dtype,
):
return None
attr = ts.TosaSerializerAttribute()
- attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
+ attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
- def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
- result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
+ def build_matmul(
+ self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ result_tens = OutputShaper.matmulOp(
+ self.ser, self.rng, a, b, accum_dtype, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
@@ -908,11 +986,12 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ accum_dtype=accum_dtype,
):
return None
attr = ts.TosaSerializerAttribute()
- attr.MatMulAttribute(qinfo[0], qinfo[1])
+ attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -995,7 +1074,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- if a.dtype == DType.FLOAT:
+ if a.dtype in (DType.FP16, DType.FLOAT):
attr.ClampAttribute(0, 0, min_val, max_val)
else:
attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1811,7 +1890,7 @@ class TosaTestGen:
op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
- if a.dtype in (DType.FLOAT, DType.INT32):
+ if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32):
then_op, else_op = Op.ADD, Op.SUB
elif a.dtype in (DType.INT8, DType.INT16):
then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2350,22 +2429,37 @@ class TosaTestGen:
# if not specified, defaults to (1, 4)
# 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
# 'types': array of datatypes to be tested
- TYPE_FP = [DType.FLOAT]
+ TYPE_FP = [DType.FLOAT, DType.FP16]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
- TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4
+ TYPE_INT_FP = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.FLOAT,
+ ] # Excludes INT4
TYPE_BOOL = [DType.BOOL]
- TYPE_FI32 = [DType.FLOAT, DType.INT32]
- TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
+ TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32] # floating-types and INT32
+ TYPE_FIB = [
+ DType.FP16,
+ DType.FLOAT,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.BOOL,
+ ]
TYPE_FI16 = [DType.FLOAT, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
+ TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
TYPE_CONV = [
[DType.INT8, DType.INT4, DType.INT32],
[DType.INT8, DType.INT8, DType.INT32],
[DType.INT16, DType.INT8, DType.INT48],
+ [DType.FP16, DType.FP16, DType.FP16],
+ [DType.FP16, DType.FP16, DType.FLOAT],
DType.FLOAT,
]
@@ -2524,7 +2618,7 @@ class TosaTestGen:
build_fully_connected,
TosaTensorGen.tgFullyConnected,
TosaTensorValuesGen.tvgDefault,
- None,
+ TosaArgGen.agFullyConnected,
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
@@ -2546,7 +2640,7 @@ class TosaTestGen:
build_matmul,
TosaTensorGen.tgMatmul,
TosaTensorValuesGen.tvgDefault,
- None,
+ TosaArgGen.agMatMul,
),
"qgen": TosaQuantGen.qgMatmul,
"types": TYPE_NARROW_INT_FP,
@@ -2564,7 +2658,7 @@ class TosaTestGen:
"operands": (1, 0),
"rank": (4, 4),
"build_fcn": (
- build_pool2d,
+ build_maxpool2d,
TosaTensorGen.tgNHWC,
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agPooling,
@@ -3384,7 +3478,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReduceSum,
TosaArgGen.agAxis,
),
- "types": TYPE_FI32,
+ "types": (DType.FP16, DType.FLOAT, DType.INT32),
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -3571,7 +3665,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
None,
),
- "types": TYPE_INT_FP,
+ "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -3612,7 +3706,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agResize,
),
- "types": [DType.INT8, DType.INT16, DType.FLOAT],
+ "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT),
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
),
@@ -3646,7 +3740,14 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agCast,
),
- "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
+ "types": (
+ DType.FP16,
+ DType.FLOAT,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.BOOL,
+ ),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -3925,7 +4026,9 @@ class OutputShaper:
return ser.addOutput(shape, outputDType)
@staticmethod
- def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
+ def conv2dOp(
+ ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
+ ):
# IFM: NHWC
# Filter: OHWI
@@ -3958,26 +4061,26 @@ class OutputShaper:
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
- if ifm.dtype == DType.INT8:
- out_dtype = DType.INT32
- elif ifm.dtype == DType.INT16:
- out_dtype = DType.INT48
- elif ifm.dtype == DType.FLOAT:
- out_dtype = DType.FLOAT
- elif error_name == ErrorIf.WrongInputType:
+ if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+ out_dtype = accum_dtype
if error_name == ErrorIf.WrongOutputType:
- wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+ if ifm.dtype == DType.FP16:
+ excludes = [DType.FP16, DType.FLOAT]
+ else:
+ excludes = [out_dtype]
+ wrong_dtypes = list(usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
- def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
+ def conv3dOp(
+ ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
+ ):
# IFM: NDHWC
# Filter: ODHWI
@@ -4020,27 +4123,25 @@ class OutputShaper:
ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
- if ifm.dtype == DType.INT8:
- out_dtype = DType.INT32
- elif ifm.dtype == DType.INT16:
- out_dtype = DType.INT48
- elif ifm.dtype == DType.FLOAT:
- out_dtype = DType.FLOAT
- elif error_name == ErrorIf.WrongInputType:
+ if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+ out_dtype = accum_dtype
if error_name == ErrorIf.WrongOutputType:
- wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+ if ifm.dtype == DType.FP16:
+ excludes = [DType.FP16, DType.FLOAT]
+ else:
+ excludes = [out_dtype]
+ wrong_dtypes = list(usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
def depthwiseConv2dOp(
- ser, rng, ifm, filter, strides, padding, dilations, error_name=None
+ ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
# IFM: NHWC
# Filter: HWCM
@@ -4073,20 +4174,18 @@ class OutputShaper:
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
- if ifm.dtype == DType.INT8:
- out_dtype = DType.INT32
- elif ifm.dtype == DType.INT16:
- out_dtype = DType.INT48
- elif ifm.dtype == DType.FLOAT:
- out_dtype = DType.FLOAT
- elif error_name == ErrorIf.WrongInputType:
+ if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+ out_dtype = accum_dtype
if error_name == ErrorIf.WrongOutputType:
- wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+ if ifm.dtype == DType.FP16:
+ excludes = [DType.FP16, DType.FLOAT]
+ else:
+ excludes = [out_dtype]
+ wrong_dtypes = list(usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@@ -4119,6 +4218,7 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FLOAT,
+ DType.FP16,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4128,55 +4228,20 @@ class OutputShaper:
return ser.addOutput(ofm_shape, outputDType)
@staticmethod
- def fullyConnectedOp(ser, rng, input, filter, error_name=None):
+ def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
# input: N, IC
# filter: OC, IC
# output: N, OC
output_shape = [input.shape[0], filter.shape[0]]
- if error_name == ErrorIf.WrongOutputType:
- if input.dtype == DType.INT8:
- incorrect_types = (
- DType.INT4,
- DType.INT8,
- DType.INT16,
- DType.INT48,
- DType.FLOAT,
- )
- elif input.dtype == DType.INT16:
- incorrect_types = (
- DType.INT4,
- DType.INT8,
- DType.INT16,
- DType.INT32,
- DType.FLOAT,
- )
- elif input.dtype == DType.FLOAT:
- incorrect_types = (
- DType.INT4,
- DType.INT8,
- DType.INT16,
- DType.INT32,
- DType.INT48,
- )
- out_dtype = rng.choice(a=incorrect_types)
- elif input.dtype == DType.INT8:
- out_dtype = DType.INT32
- elif input.dtype == DType.INT16:
- out_dtype = DType.INT48
- elif input.dtype == DType.FLOAT:
- out_dtype = DType.FLOAT
- elif error_name == ErrorIf.WrongInputType:
- # Pick some potentially correct output dtype if input type is incorrect
- out_dtype = DType.INT32
- else:
- raise Exception("Unsupported input dtype: {}".format(input.dtype))
+ # Validated in arg_gen (also invalidated for ErrorIf)
+ out_dtype = accum_dtype
return ser.addOutput(output_shape, out_dtype)
@staticmethod
- def matmulOp(ser, rng, a, b, error_name=None):
+ def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
# a: N, H, C
# b: N, C, W
# out: N, H, W
@@ -4200,7 +4265,7 @@ class OutputShaper:
DType.INT32,
DType.FLOAT,
)
- elif a.dtype == DType.FLOAT:
+ elif a.dtype == DType.FLOAT or a.dtype == DType.FP16:
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -4209,17 +4274,11 @@ class OutputShaper:
DType.INT48,
)
out_dtype = rng.choice(a=incorrect_types)
- elif a.dtype == DType.INT8:
- out_dtype = DType.INT32
- elif a.dtype == DType.INT16:
- out_dtype = DType.INT48
- elif a.dtype == DType.FLOAT:
- out_dtype = DType.FLOAT
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
+ out_dtype = accum_dtype # Validated in arg_gen
return ser.addOutput(output_shape, out_dtype)
@@ -4269,10 +4328,6 @@ class OutputShaper:
bad_dim = rng.choice(range(len(output_shape)))
output_shape[bad_dim] -= rng.choice([1, 2])
- # Fix negative output shape if error_if test causes it
- if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
- output_shape = [i if i >= 1 else 1 for i in output_shape]
-
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4280,6 +4335,7 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FLOAT,
+ DType.FP16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4546,7 +4602,7 @@ class OutputShaper:
return ser.addOutput(val.shape, out_dtype)
@staticmethod
- def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
+ def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
if error_name == ErrorIf.ConvOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
@@ -4555,20 +4611,18 @@ class OutputShaper:
if change in [2, 3]:
output_shape[2] = output_shape[2] + rng.choice(choices)
- if ifm.dtype == DType.INT8:
- out_dtype = DType.INT32
- elif ifm.dtype == DType.INT16:
- out_dtype = DType.INT48
- elif ifm.dtype == DType.FLOAT:
- out_dtype = DType.FLOAT
- elif error_name == ErrorIf.WrongInputType:
+ if error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+ out_dtype = accum_dtype
if error_name == ErrorIf.WrongOutputType:
- wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+ if ifm.dtype == DType.FP16:
+ excludes = [DType.FP16, DType.FLOAT]
+ else:
+ excludes = [out_dtype]
+ wrong_dtypes = list(usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(output_shape, out_dtype)
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 6a689d0..7fa31e7 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -84,3 +84,42 @@ def product(shape):
for n in shape:
value *= n
return value
+
+
+def get_accum_dtype_from_tgTypes(dtypes):
+ # Get accumulate data-type from the test generator's defined types
+ if isinstance(dtypes, list) or isinstance(dtypes, tuple):
+ return dtypes[-1]
+ else:
+ return dtypes
+
+
+def get_wrong_output_type(op_name, rng, input_dtype):
+ if op_name == "fully_connected" or op_name == "matmul":
+ if input_dtype == DType.INT8:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT48,
+ DType.FLOAT,
+ DType.FP16,
+ )
+ elif input_dtype == DType.INT16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ DType.FP16,
+ )
+ elif input_dtype == DType.FLOAT or input_dtype == DType.FP16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ )
+ return rng.choice(a=incorrect_types)
diff --git a/verif/tests/test_tosa_result_checker.py b/verif/tests/test_tosa_result_checker.py
index efee23b..d78d158 100644
--- a/verif/tests/test_tosa_result_checker.py
+++ b/verif/tests/test_tosa_result_checker.py
@@ -40,7 +40,7 @@ def _delete_data_file(file: Path):
(np.uint16, trc.TestResult.MISMATCH),
(np.uint32, trc.TestResult.MISMATCH),
(np.uint64, trc.TestResult.MISMATCH),
- (np.float16, trc.TestResult.MISMATCH),
+ (np.float16, trc.TestResult.PASS),
(np.float32, trc.TestResult.PASS),
(np.float64, trc.TestResult.MISMATCH),
(bool, trc.TestResult.PASS),