diff options
author | James Ward <james.ward@arm.com> | 2022-08-12 20:48:56 +0100 |
---|---|---|
committer | James Ward <james.ward@arm.com> | 2022-10-11 11:56:02 +0100 |
commit | 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch) | |
tree | fea519246b698eb944b9d58537fc90bc30481d11 /verif/generator/tosa_arg_gen.py | |
parent | ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff) | |
download | reference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz |
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 230 |
1 files changed, 163 insertions, 67 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index a65e220..69968d3 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import itertools import math +import warnings import numpy as np from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen +from generator.tosa_utils import get_accum_dtype_from_tgTypes +from generator.tosa_utils import get_wrong_output_type from generator.tosa_utils import MAX_RESIZE_DIMENSION from serializer.tosa_serializer import DTypeNames from tosa.DType import DType @@ -773,7 +776,7 @@ class TosaTensorValuesGen: ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] - if dtypeList[0] == DType.FLOAT: + if dtypeList[0] in (DType.FP16, DType.FLOAT): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] @@ -982,7 +985,7 @@ class TosaArgGen: return axes @staticmethod - def agConv(testGen, opName, shapeList, dtype, error_name=None): + def agConv(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] ifm_shape = shapeList[0] @@ -990,6 +993,8 @@ class TosaArgGen: # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3]) k = [int(x) for x in opName.split("_")[-1].split("x")] + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + # Check the rank rank = 5 if opName.startswith("conv3d") else 4 if error_name != ErrorIf.WrongRank: @@ -1089,12 +1094,13 @@ class TosaArgGen: ): arg_list.append( ( - "st{}_pad{}_dilat{}".format( + "acc{}_st{}_pad{}_dilat{}".format( + testGen.typeStr(accum_dtype), "".join([str(x) for x in s]), "".join([str(x) for x in p]), "".join([str(x) for x in d]), ), - [s, p, d], + [accum_dtype, s, p, d], ) ) n += 1 @@ -1102,12 +1108,55 @@ class TosaArgGen: return arg_list @staticmethod - def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None): + def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None): + + if isinstance(dtypes, list) or isinstance(dtypes, tuple): + input_dtype = dtypes[0] + else: + input_dtype = dtypes + + if error_name == ErrorIf.WrongOutputType: + accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype) + elif error_name == ErrorIf.WrongInputType: + # Pick some potentially correct output dtype if input type is incorrect + accum_dtype = DType.INT32 + else: + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + + return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])] + + @staticmethod + def agMatMul(testGen, opName, shapeList, dtype, error_name=None): + # Get valid accumulate type(s) + if dtype == DType.INT8: + accum_dtypes = [DType.INT32] + elif dtype == DType.INT16: + accum_dtypes = [DType.INT48] + elif dtype == DType.FP16: + accum_dtypes = [DType.FP16, DType.FLOAT] + elif dtype == DType.FLOAT: + accum_dtypes = [DType.FLOAT] + elif error_name is None: + assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" + + if error_name == ErrorIf.WrongOutputType: + # Get incorrect output dtype for ErrorIf case + accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)] + elif error_name == ErrorIf.WrongInputType: + # Pick some potentially correct output dtype if input type is incorrect + accum_dtypes = [DType.INT32] + + return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes] + + @staticmethod + def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] ifm_shape = shapeList[0] filter_shape = shapeList[1] + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + # Must be rank 4 if error_name != ErrorIf.WrongRank: assert len(ifm_shape) == 4 @@ -1169,12 +1218,13 @@ class TosaArgGen: os = [ifm_shape[0], oh, ow, filter_shape[0]] arg_list.append( ( - "st{}_pad{}_os{}".format( + "acc{}_st{}_pad{}_os{}".format( + testGen.typeStr(accum_dtype), "".join([str(x) for x in s]), "".join([str(x) for x in p]), "x".join([str(x) for x in os]), ), - [s, p, os], + [accum_dtype, s, p, os], ) ) n += 1 @@ -1199,18 +1249,38 @@ class TosaArgGen: if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 - elif dtype == DType.FLOAT: + elif dtype in (DType.FP16, DType.FLOAT): pad_const_int = 0 pad_const_fp = testGen.getRandNumberDType(dtype) else: return [] for paddings in shape_pad_values: - name = "pad" - for r in range(rank): - before, after = paddings[r] - name = f"{name}{before}{after}" - arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp])) + paddings = list(paddings) + args_valid = True + + if error_name == ErrorIf.PadSmallerZero: + # Prevent negative output shapes while ensuring still testing for negative padding + for i in range(rank): + dim_after_padding = ( + paddings[i][0] + paddings[i][1] + shapeList[0][i] + ) + if dim_after_padding < 1: + paddings[i] = (0, 0) + if all([p > -1 for p in paddings[i]]): + args_valid = False + + if args_valid: + name = "pad" + for r in range(rank): + before, after = paddings[r] + name = f"{name}{before}{after}" + arg_list.append( + (name, [np.array(paddings), pad_const_int, pad_const_fp]) + ) + + if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0: + warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}") return arg_list @@ -1232,6 +1302,21 @@ class TosaArgGen: k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)] kernels = {x for x in itertools.product(*([k_vals] * 2))} + if opName == "max_pool2d": + accum_dtypes = [None] # max_pool has no accumulate dtype + elif dtype == DType.INT8 or dtype == DType.INT16: + accum_dtypes = [DType.INT32] + elif dtype == DType.FP16: + accum_dtypes = [DType.FP16, DType.FLOAT] + elif dtype == DType.FLOAT: + accum_dtypes = [DType.FLOAT] + elif error_name is None: + assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" + else: + # Set to something for the ErrorIf case which has + # incorrect input data-type + accum_dtypes = [DType.INT32] + if testGen.args.oversize: # add some oversize argument values bigStride = 7 @@ -1252,63 +1337,70 @@ class TosaArgGen: sparsity_factor = 2 if error_name else 500 sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1 + arg_str = ( + "acc{}_st{}_kern{}_pad{}" + if accum_dtypes[0] is not None + else "st{}_kern{}_pad{}" + ) + + def get_arg_list_element(accum, stride, pad, kern): + # Return tuple containing the formatted argument string and + # the corresponding argument values + arg_str_elems = [ + "".join([str(x) for x in stride]), + "".join([str(x) for x in kern]), + "".join([str(x) for x in pad]), + ] + # Note: different order to string + arg_val_elems = [stride, pad, kern] + + if accum is not None: + arg_str_elems.insert(0, testGen.typeStr(accum)) + arg_val_elems.insert(0, accum) + return (arg_str.format(*arg_str_elems), arg_val_elems) + n = 0 - for s in sorted(list(strides)): - for p in sorted(list(paddings)): - for k in sorted(list(kernels)): - if error_name in [ - ErrorIf.StrideSmallerOne, - ErrorIf.KernelSmallerOne, - ErrorIf.PadSmallerZero, - ErrorIf.PadLargerEqualKernel, - ]: - sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf( - testGen, error_name, s, p, k - ) - if None not in [sNew, pNew, kNew] and n % sparsity == 0: - arg_list.append( - ( - "st{}_kern{}_pad{}".format( - "".join([str(x) for x in sNew]), - "".join([str(x) for x in kNew]), - "".join([str(x) for x in pNew]), - ), - [sNew, pNew, kNew], - ) + for a in accum_dtypes: + for s in sorted(list(strides)): + for p in sorted(list(paddings)): + for k in sorted(list(kernels)): + if error_name in [ + ErrorIf.StrideSmallerOne, + ErrorIf.KernelSmallerOne, + ErrorIf.PadSmallerZero, + ErrorIf.PadLargerEqualKernel, + ]: + sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf( + testGen, error_name, s, p, k ) - elif ( - n % sparsity == 0 - # padding must not exceed the kernel size - and p[0] < k[0] - and p[1] < k[0] - and p[2] < k[1] - and p[3] < k[1] - # the padded shape must exceed the kernel size - and (shape[1] + p[0] + p[1]) > k[0] - and (shape[2] + p[2] + p[3]) > k[1] - ): - remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0] - remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1] - if ( - # the parameters must produce integer exact output - error_name != ErrorIf.PoolingOutputShapeNonInteger - and remainder_h == 0 - and remainder_w == 0 - ) or ( - error_name == ErrorIf.PoolingOutputShapeNonInteger - and (remainder_h != 0 or remainder_w != 0) + if None not in [sNew, pNew, kNew] and n % sparsity == 0: + arg_vals = [a, sNew, pNew, kNew] + arg_list.append(get_arg_list_element(*arg_vals)) + elif ( + n % sparsity == 0 + # padding must not exceed the kernel size + and p[0] < k[0] + and p[1] < k[0] + and p[2] < k[1] + and p[3] < k[1] + # the padded shape must exceed the kernel size + and (shape[1] + p[0] + p[1]) > k[0] + and (shape[2] + p[2] + p[3]) > k[1] ): - arg_list.append( - ( - "st{}_kern{}_pad{}".format( - "".join([str(x) for x in s]), - "".join([str(x) for x in k]), - "".join([str(x) for x in p]), - ), - [s, p, k], - ) - ) - n += 1 + remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0] + remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1] + if ( + # the parameters must produce integer exact output + error_name != ErrorIf.PoolingOutputShapeNonInteger + and remainder_h == 0 + and remainder_w == 0 + ) or ( + error_name == ErrorIf.PoolingOutputShapeNonInteger + and (remainder_h != 0 or remainder_w != 0) + ): + arg_vals = [a, s, p, k] + arg_list.append(get_arg_list_element(*arg_vals)) + n += 1 return arg_list @@ -1327,6 +1419,8 @@ class TosaArgGen: dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] + elif inDtype == DType.FP16: + dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FLOAT: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: @@ -1734,6 +1828,8 @@ class TosaArgGen: outputDTypeList = [DType.INT32] elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: outputDTypeList = [DType.INT48] + elif dtype == DType.FP16: + outputDTypeList = [DType.FP16] elif dtype == DType.FLOAT: outputDTypeList = [DType.FLOAT] elif error_name == ErrorIf.WrongInputType: |