From 5c1364c8a547127bc90e4d4a78dd876070eb1026 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 13 Jan 2022 15:04:21 +0000 Subject: Add python pre-commit script checkers Fix up issues in existing python scripts. Signed-off-by: Jeremy Johnson Change-Id: Id4adab404560c3129c66f31c21ff0ce148283c73 --- .pre-commit-config.yaml | 20 + scripts/json2fbbin/json2fbbin.py | 8 +- setup.cfg | 6 + verif/generator/tosa_error_if.py | 17 +- verif/generator/tosa_test_gen.py | 3063 +++++++++++++++++++--------- verif/generator/tosa_verif_build_tests.py | 45 +- verif/tests/test_json2numpy.py | 1 - verif/tests/test_tosa_result_checker.py | 3 +- verif/tests/test_tosa_run_tests_mocksut.py | 1 - 9 files changed, 2101 insertions(+), 1063 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4a73a48 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +# Copyright (c) 2021-2022 Arm Limited. +# SPDX-License-Identifier: Apache-2.0 + +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/asottile/reorder_python_imports + rev: v2.2.0 + hooks: + - id: reorder-python-imports + +- repo: https://github.com/psf/black + rev: 20.8b1 + hooks: + - id: black + +- repo: https://gitlab.com/pycqa/flake8 + rev: 3.7.9 + hooks: + - id: flake8 diff --git a/scripts/json2fbbin/json2fbbin.py b/scripts/json2fbbin/json2fbbin.py index 957acb1..8f9f274 100644 --- a/scripts/json2fbbin/json2fbbin.py +++ b/scripts/json2fbbin/json2fbbin.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import Optional -from runner.run_command import run_sh_command, RunShCommandError +from runner.run_command import run_sh_command +from runner.run_command import RunShCommandError def fbbin_to_json(flatc: Path, fbs: Path, t_path: Path, o_path: Optional[Path] = None): @@ -63,7 +64,10 @@ def main(argv=None): parser.add_argument( "--flatc", type=Path, - default="reference_model/build/thirdparty/serialization_lib/third_party/flatbuffers/flatc", + default=( + "reference_model/build/thirdparty/serialization_lib/" + "third_party/flatbuffers/flatc" + ), help="the path to the flatc compiler program", ) parser.add_argument( diff --git a/setup.cfg b/setup.cfg index f9e5331..4e3dc10 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,3 +50,9 @@ console_scripts = [tool:pytest] testpaths=verif/tests + +[flake8] +ignore = D213, E203, E266, E501, W503 +max-line-length = 88 +select = B,E,F,W,T4 +exclude = .eggs diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 7c162be..7070205 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1,16 +1,6 @@ -# Copyright (c) 2021, ARM Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright (c) 2021-2022, ARM Limited. +# SPDX-License-Identifier: Apache-2.0 + class ErrorIf(object): MaxDimExceeded = "MaxDimExceeded" @@ -68,4 +58,3 @@ class ErrorIf(object): InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch" InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch" CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool" - diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 0d5a881..239a64e 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -1,48 +1,21 @@ -#!/usr/bin/env python3 - # Copyright (c) 2020-2022, ARM Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import argparse -import sys -import re -import os -import subprocess -import shlex -import json -import glob -import math -import queue -import threading -import traceback -import math +# SPDX-License-Identifier: Apache-2.0 import itertools +import math +import os from copy import deepcopy -from enum import IntEnum, Enum, unique - +import numpy as np import serializer.tosa_serializer as ts -from serializer.tosa_serializer import * -import tosa from generator.tosa_error_if import ErrorIf - -# Convenience variables to the flatc-generated types that should be enums, but aren't +from serializer.tosa_serializer import DTypeNames from tosa.DType import DType from tosa.Op import Op from tosa.ResizeMode import ResizeMode +# DTypeNames, DType, Op and ResizeMode are convenience variables to the +# flatc-generated types that should be enums, but aren't + def valueToName(item, value): """Get the name of an attribute with the given value. @@ -70,7 +43,8 @@ def valueToName(item, value): for attr in dir(item): if getattr(item, attr) == value: return attr - raise ValueError(f'value ({value}) not found') + raise ValueError(f"value ({value}) not found") + def allDTypes(*, excludes=None): """Get a set of all DType values, optionally excluding some values. @@ -87,9 +61,14 @@ def allDTypes(*, excludes=None): A set of DType values """ excludes = () if not excludes else excludes - return {getattr(DType, t) for t in dir(DType) - if not callable(getattr(DType, t)) and not t.startswith('__') - and getattr(DType, t) not in excludes} + return { + getattr(DType, t) + for t in dir(DType) + if not callable(getattr(DType, t)) + and not t.startswith("__") + and getattr(DType, t) not in excludes + } + def usableDTypes(*, excludes=None): """Get a set of usable DType values, optionally excluding some values. @@ -108,6 +87,7 @@ def usableDTypes(*, excludes=None): omit.update(excludes if excludes else ()) return allDTypes(excludes=omit) + def product(shape): value = 1 for n in shape: @@ -116,7 +96,10 @@ def product(shape): class TosaQuantGen: - """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator defintion""" + """QuantizedInfo random generator helper functions. + + Specify with 'qgen': in the operator defintion. + """ def __init__(self): pass @@ -128,7 +111,11 @@ class TosaQuantGen: return testGen.randInt(-128, 128) elif dtype == DType.UINT8: return testGen.randInt(0, 256) - elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.WeightZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]: + elif error_name in [ + ErrorIf.InputZeroPointNotZero, + ErrorIf.WeightZeroPointNotZero, + ErrorIf.OutputZeroPointNotZero, + ]: zero_point = testGen.randInt(-128, 128) if zero_point == 0: zero_point = 1 @@ -140,15 +127,18 @@ class TosaQuantGen: qinfo = ts.TosaSerializerQuantInfo() if error_name == ErrorIf.InputZeroPointNotZero: qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype) + TosaQuantGen.getQinfo(testGen, dtype, error_name), + TosaQuantGen.getQinfo(testGen, dtype), ) elif error_name == ErrorIf.OutputZeroPointNotZero: qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype, error_name) + TosaQuantGen.getQinfo(testGen, dtype), + TosaQuantGen.getQinfo(testGen, dtype, error_name), ) else: qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype) + TosaQuantGen.getQinfo(testGen, dtype), + TosaQuantGen.getQinfo(testGen, dtype), ) return qinfo @@ -180,11 +170,13 @@ class TosaQuantGen: qinfo = ts.TosaSerializerQuantInfo() if error_name == ErrorIf.InputZeroPointNotZero: qinfo.MatMulQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype, error_name) - ) + TosaQuantGen.getQinfo(testGen, dtype, error_name), + TosaQuantGen.getQinfo(testGen, dtype, error_name), + ) else: qinfo.MatMulQuantInfo( - TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype) + TosaQuantGen.getQinfo(testGen, dtype), + TosaQuantGen.getQinfo(testGen, dtype), ) return qinfo @@ -221,7 +213,8 @@ class TosaQuantGen: shift = shift + 1 shift = (-shift) + scaleBits - #print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift)) + # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format( + # scaleFp, scaleBits, m, multiplier, shift)) # Adjust multiplier such that shift is in allowed value range. if shift == 0: @@ -242,7 +235,10 @@ class TosaQuantGen: class TosaTensorGen: """Tensor generators create a shape list for the placeholder and const tensor - data operands for the operator. The actual random data is generated separately for each test.""" + data operands for the operator. + + The actual random data is generated separately for each test. + """ def __init__(self): pass @@ -331,7 +327,7 @@ class TosaTensorGen: # Choose one of the inputs to broadcast # Note: Simplifies OutputShaper code if we don't change first shape for errors - bcast_idx = testGen.randInt(0 if error_name == None else 1, pl + const) + bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const) for i in range(pl + const): shape_bcast = shape.copy() @@ -343,7 +339,9 @@ class TosaTensorGen: elif error_name == ErrorIf.RankMismatch: # Add one rank to the shape (or more for rank of 1) extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1 - shape_bcast = np.concatenate((shape_bcast, testGen.makeShape(extra_ranks))) + shape_bcast = np.concatenate( + (shape_bcast, testGen.makeShape(extra_ranks)) + ) if rank != 1: # Either keep the extra rank, or remove it new_len = testGen.rng.choice([-2, len(shape_bcast)]) @@ -371,7 +369,9 @@ class TosaTensorGen: # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: - ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000) + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions( + ifm_shape, max_dim=24, max_items=10000 + ) # Get the filter height/width from the operator parameters filter_hw = op["filter"] @@ -403,7 +403,9 @@ class TosaTensorGen: # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: - ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000) + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions( + ifm_shape, max_dim=24, max_items=10000 + ) # Get the filter depth/height/width from the operator parameters filter_dhw = op["filter"] @@ -437,7 +439,9 @@ class TosaTensorGen: # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: - ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000) + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions( + ifm_shape, max_dim=24, max_items=10000 + ) # Get the filter height/width from the operator parameters filter_hw = op["filter"] @@ -470,7 +474,9 @@ class TosaTensorGen: # Constrict the overall size of the shape when creating ERROR_IF tests if error_name: - ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000) + ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions( + ifm_shape, max_dim=24, max_items=10000 + ) # Get the filter height/width from the operator parameters # Filter is KH, HW, C, M @@ -571,7 +577,11 @@ class TosaTensorGen: @staticmethod def tgConcatConstInput(testGen, shapeList, axis, error_name=None): - if error_name in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ConcatInputRankMismatch]: + if error_name in [ + ErrorIf.AxisSmallerZero, + ErrorIf.AxisLargerRank, + ErrorIf.ConcatInputRankMismatch, + ]: return shapeList # Split concat shape along axis to allow for multiple const inputs @@ -613,10 +623,13 @@ class TosaTensorGen: class TosaArgGen: - """Argument generators create exhaustive or random lists of attributes for operators that take - attributes or other parameters. The return value is a list of (descriptive_name, [arglist]) - tuples where the descriptive_name is appended to the test name and the arglist is expanded - as arguments to the operator build function.""" + """Argument generators create exhaustive or random lists of attributes for + operators that take attributes or other parameters. + + The return value is a list of (descriptive_name, [arglist]) tuples where + the descriptive_name is appended to the test name and the arglist is expanded + as arguments to the operator build function. + """ def __init__(self): pass @@ -651,7 +664,7 @@ class TosaArgGen: ifm_shape = shapeList[0] filter_shape = shapeList[1] - # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3]) + # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3]) k = [int(x) for x in opName.split("_")[-1].split("x")] # Check the rank @@ -687,11 +700,15 @@ class TosaArgGen: # add some oversize argument values if max(ifm_shape) < 64: bigPadding = 9 - paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}) + paddings.update( + {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))} + ) bigStride = 8 strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))}) bigDilation = 7 - dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))}) + dilations.update( + {x for x in itertools.product(*([[1, bigDilation]] * k_rank))} + ) # There are too many parameter combinations, so generate them sparsely, # very sparse for negative tests @@ -700,7 +717,8 @@ class TosaArgGen: # If there are only a small number of tests, just select them all if sparsity < 13: sparsity = 1 - # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5 + # To get a variety of parameter combinations sparsity should not be a + # multiple of 2, 3 or 5 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0: sparsity += 1 @@ -708,15 +726,19 @@ class TosaArgGen: for s in sorted(list(strides)): for p in sorted(list(paddings)): for d in sorted(list(dilations)): - if (n % sparsity == 0 + if ( + n % sparsity == 0 # padding must not exceed the kernel size ? - # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1] + # and p[0] < k[0] and p[1] < k[0] + # and p[2] < k[1] and p[3] < k[1] # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2])) # the padded shape must exceed the kernel size - and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1] + and (ifm_shape[1] + p[0] + p[1]) > k[0] + and (ifm_shape[2] + p[2] + p[3]) > k[1] and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2])) # the padded shape must exceed the dilation - and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1] + and (ifm_shape[1] + p[0] + p[1]) > d[0] + and (ifm_shape[2] + p[2] + p[3]) > d[1] and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2])) ): arg_list.append( @@ -768,7 +790,9 @@ class TosaArgGen: # add some oversize argument values if max(ifm_shape) < 64: bigPadding = 9 - paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))}) + paddings.update( + {x for x in itertools.product(*([[0, bigPadding]] * 2))} + ) bigStride = 8 strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))}) bigDilation = 7 @@ -781,7 +805,8 @@ class TosaArgGen: # If there are only a small number of tests, just select them all if sparsity < 13: sparsity = 1 - # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5 + # To get a variety of parameter combinations sparsity should not be a + # multiple of 2, 3 or 5 while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0: sparsity += 1 @@ -887,8 +912,15 @@ class TosaArgGen: for s in sorted(list(strides)): for p in sorted(list(paddings)): for k in sorted(list(kernels)): - if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]: - sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k) + if error_name in [ + ErrorIf.StrideSmallerOne, + ErrorIf.KernelSmallerOne, + ErrorIf.PadSmallerZero, + ErrorIf.PadLargerEqualKernel, + ]: + sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf( + testGen, error_name, s, p, k + ) if None not in [sNew, pNew, kNew] and n % sparsity == 0: arg_list.append( ( @@ -900,11 +932,16 @@ class TosaArgGen: [sNew, pNew, kNew], ) ) - elif (n % sparsity == 0 + elif ( + n % sparsity == 0 # padding must not exceed the kernel size - and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1] + and p[0] < k[0] + and p[1] < k[0] + and p[2] < k[1] + and p[3] < k[1] # the padded shape must exceed the kernel size - and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1] + and (shape[1] + p[0] + p[1]) > k[0] + and (shape[2] + p[2] + p[3]) > k[1] ): arg_list.append( ( @@ -954,31 +991,53 @@ class TosaArgGen: # Enumerate the output types here for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]: - if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero: + if ( + dtype in [DType.UINT8, DType.INT8] + and error_name == ErrorIf.OutputZeroPointNotZero + ): continue - if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType: + if ( + inDtype == DType.UINT8 + and dtype != DType.INT8 + and error_name != ErrorIf.WrongOutputType + ): # The only output dtype for UINT8 is INT8, skip all other combinations continue - if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType: + if ( + inDtype != DType.INT8 + and dtype == DType.UINT8 + and error_name != ErrorIf.WrongOutputType + ): # The only input dtype for UINT8 is INT8, skip all other combinations continue - if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype): + if ( + error_name == ErrorIf.WrongOutputType + and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype) + ): continue for scale32 in [False, True]: - if error_name == ErrorIf.ScaleTrue and scale32 == False: + if error_name == ErrorIf.ScaleTrue and not scale32: continue - elif error_name == ErrorIf.ScaleNotTrue and scale32 == True: + elif error_name == ErrorIf.ScaleNotTrue and scale32: continue for double_round in [False, True]: - if error_name == ErrorIf.ScaleNotTrue and double_round == False: + if error_name == ErrorIf.ScaleNotTrue and not double_round: continue for per_channel in [False, True]: - if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue: + if ( + inDtype == DType.INT48 + and scale32 + and error_name != ErrorIf.ScaleTrue + ): # Illegal condition. Must be scale32=False continue - if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue: + if ( + double_round + and not scale32 + and error_name != ErrorIf.ScaleNotTrue + ): # Illegal condition. ERROR_IF(!scale32 && double_round) continue @@ -1093,12 +1152,13 @@ class TosaArgGen: ifm_shape = shapeList[0] - if error_name == ErrorIf.IndexOutsideBounds: - incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1) + incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1) incorrect_small_index = range(-len(ifm_shape), 0) permutations = [p for p in itertools.permutations(incorrect_large_index)] - permutations.extend([p for p in itertools.permutations(incorrect_small_index)]) + permutations.extend( + [p for p in itertools.permutations(incorrect_small_index)] + ) elif error_name == ErrorIf.IndexUsedTwice: # Create list with a duplicated index perm_range = list(range(len(ifm_shape))) @@ -1106,7 +1166,6 @@ class TosaArgGen: perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice] permutations = [p for p in itertools.permutations(perm_range)] - else: # Get all permutations permutations = [p for p in itertools.permutations(range(len(ifm_shape)))] @@ -1151,7 +1210,9 @@ class TosaArgGen: if valid: # If ERROR_IF test required then incorrect start, size will be returned - start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size) + start, size = TosaErrorIfArgGen.eiSliceErrorIf( + testGen, error_name, ifm_shape, start, size + ) arg_list.append(("perm{}".format(p), [start, size])) return arg_list @@ -1170,7 +1231,8 @@ class TosaArgGen: multiples = [] for i in range(rank): if ifm_shape[i] > 1000: - # Multiple of 1 if ifm_shape dimension is large to reduce tensor size + # Multiple of 1 if ifm_shape dimension is large to reduce + # tensor size multiples.append(1) elif max(ifm_shape) > 1000: multiples.append(2) @@ -1212,9 +1274,9 @@ class TosaArgGen: # A output_dim of 1 will cause offset to exceed allowed range # so minimum value 2 produced below output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1] - while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16): + while (float(ifm_shape[1]) / float(output_dims[0])) >= 16: output_dims[0] += 1 - while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16): + while (float(ifm_shape[2]) / float(output_dims[1])) >= 16: output_dims[1] += 1 in_center_h = (ifm_shape[1] - 1) / 2.0 @@ -1229,7 +1291,10 @@ class TosaArgGen: if outputDType == DType.FLOAT: float_op = True - arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}" + arg_str = ( + "mode{}_shift{}_odim{}x{}_out{}" + "_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}" + ) shift = 0 stride = [0, 0] offset = [0, 0] @@ -1239,11 +1304,11 @@ class TosaArgGen: else: float_op = False arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}" - shift = testGen.randInt(1,12) + shift = testGen.randInt(1, 12) # Now search for a shift value (1 to 11) that will produce # a valid and predictable resize operation count = 0 - while (count < 12): + while count < 12: unit = float(1 << shift) stride_y = int(round(fp_stride_y * unit)) stride_x = int(round(fp_stride_x * unit)) @@ -1265,20 +1330,26 @@ class TosaArgGen: shift = (shift % 11) + 1 continue - def RESIZE_REQUIRE_CALC(length_in, length_out, stride, offset, shift): + def RESIZE_REQUIRE_CALC( + length_in, length_out, stride, offset, shift + ): # Perform the pseudo loop to look for out of bounds - for pos in range(0,length_out): + for pos in range(0, length_out): a = pos * stride + offset ia = a >> shift ia0 = max(ia, 0) - ia1 = min(ia+1, length_in-1) + ia1 = min(ia + 1, length_in - 1) if ia0 > ia1: # Found a problem value break return ia0, ia1 - iy0, iy1 = RESIZE_REQUIRE_CALC(ifm_shape[1], output_dims[0], stride_y, offset_y, shift) - ix0, ix1 = RESIZE_REQUIRE_CALC(ifm_shape[2], output_dims[1], stride_x, offset_x, shift) + iy0, iy1 = RESIZE_REQUIRE_CALC( + ifm_shape[1], output_dims[0], stride_y, offset_y, shift + ) + ix0, ix1 = RESIZE_REQUIRE_CALC( + ifm_shape[2], output_dims[1], stride_x, offset_x, shift + ) if ix0 > ix1 or iy0 > iy1: # Change the shift value and check again count += 1 @@ -1298,7 +1369,14 @@ class TosaArgGen: # Common for all data types if error_name is not None: - shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf( + ( + shift, + stride, + stride_fp, + offset, + offset_fp, + outputDTypeNew, + ) = TosaErrorIfArgGen.eiResizeErrorIf( testGen, error_name, mode, @@ -1309,7 +1387,7 @@ class TosaArgGen: stride, stride_fp, offset, - offset_fp + offset_fp, ) else: outputDTypeNew = outputDType @@ -1325,7 +1403,7 @@ class TosaArgGen: stride_fp[0] if float_op else stride[0], stride_fp[1] if float_op else stride[1], offset_fp[0] if float_op else offset[0], - offset_fp[1] if float_op else offset[1] + offset_fp[1] if float_op else offset[1], ), [ mode, @@ -1384,14 +1462,26 @@ class TosaArgGen: return arg_list -class TosaErrorIfArgGen: +class TosaErrorIfArgGen: @staticmethod - def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp): + def eiResizeErrorIf( + testGen, + error_name, + mode, + dtype, + shapeList, + outputDType, + shift, + stride, + stride_fp, + offset, + offset_fp, + ): if outputDType == DType.FLOAT: if error_name == ErrorIf.StrideSmallerEqualZero: - stride_fp = testGen.rng.random(size=[2]) - 2 + stride_fp = testGen.rng.random(size=[2]) - 2 elif error_name == ErrorIf.ShiftNotZero: shift = testGen.rng.integers(1, 5) elif error_name == ErrorIf.StrideLargerDimension: @@ -1407,11 +1497,23 @@ class TosaErrorIfArgGen: elif error_name == ErrorIf.ShiftSmallerOne: shift = testGen.rng.integers(-3, 1) if shift <= 0: - stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks - offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks + stride = [ + (16 >> -shift) - 1, + (16 >> -shift) - 1, + ] # avoids other ERROR_IF checks + offset = [ + (16 >> -shift) - 1, + (16 >> -shift) - 1, + ] # avoids other ERROR_IF checks else: - stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks - offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks + stride = [ + (16 << shift) - 1, + (16 << shift) - 1, + ] # avoids other ERROR_IF checks + offset = [ + (16 << shift) - 1, + (16 << shift) - 1, + ] # avoids other ERROR_IF checks elif error_name == ErrorIf.ShiftLargerEleven: shift = np.int16(testGen.rng.integers(12, 15)) elif error_name == ErrorIf.StrideLargerDimension: @@ -1428,49 +1530,91 @@ class TosaErrorIfArgGen: elif error_name == ErrorIf.OffsetSmallerEqualMin: offset = [(-16 << shift) - 1, (-16 << shift) - 1] - if error_name == ErrorIf.WrongOutputType: if mode == ResizeMode.NEAREST and dtype == DType.INT8: - incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT) + incorrect_types = ( + DType.INT4, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ) elif mode == ResizeMode.NEAREST and dtype == DType.INT16: - incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT32, + DType.INT48, + DType.FLOAT, + ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT8: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT48, + DType.FLOAT, + ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.FLOAT, + ) elif dtype == DType.FLOAT: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ) outputDType = testGen.rng.choice(a=incorrect_types) return shift, stride, stride_fp, offset, offset_fp, outputDType - @staticmethod def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel): - if (error_name == ErrorIf.StrideSmallerOne + if ( + error_name == ErrorIf.StrideSmallerOne # padding must not exceed the kernel size - and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]): - wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3])) + and pad[0] < kernel[0] + and pad[1] < kernel[0] + and pad[2] < kernel[1] + and pad[3] < kernel[1] + ): + wrongStride = ( + testGen.rng.choice([0, -1, -2, -3]), + testGen.rng.choice([0, -1, -2, -3]), + ) return wrongStride, pad, kernel elif error_name == ErrorIf.PadSmallerZero: - wrongPad = (testGen.rng.choice([-1, -2, -3]), - testGen.rng.choice([-1, -2, -3]), - testGen.rng.choice([-1, -2, -3]), - testGen.rng.choice([-1, -2, -3])) + wrongPad = ( + testGen.rng.choice([-1, -2, -3]), + testGen.rng.choice([-1, -2, -3]), + testGen.rng.choice([-1, -2, -3]), + testGen.rng.choice([-1, -2, -3]), + ) return stride, wrongPad, kernel elif error_name == ErrorIf.KernelSmallerOne: - wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3])) + wrongKernel = ( + testGen.rng.choice([0, -1, -2, -3]), + testGen.rng.choice([0, -1, -2, -3]), + ) return stride, pad, wrongKernel elif error_name == ErrorIf.PadLargerEqualKernel: - wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]), - testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]), - testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]), - testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2])) + wrongPad = ( + testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]), + testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]), + testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]), + testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]), + ) return stride, wrongPad, kernel else: return None, None, None - @staticmethod def eiRescaleWrongOutputType(input_dtype, output_dtype): if input_dtype == DType.INT8: @@ -1487,27 +1631,28 @@ class TosaErrorIfArgGen: return True return False - @staticmethod def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list): # Mess up input/output tensors for ERROR_IF checks if error_name == "WrongInputList": add_input = testGen.rng.choice([True, False]) if add_input: - input_list.append('eiDummyInput') + input_list.append("eiDummyInput") else: input_list = input_list[:-1] elif error_name == "WrongOutputList": add_output = testGen.rng.choice([True, False]) if add_output: - output_list.append('eiDummyOutput') + output_list.append("eiDummyOutput") else: output_list = [] return input_list, output_list @staticmethod def eiRestrictDimensions(shape, max_dim=32, max_items=100000): - """Restrict the dimensions and overall size of a shape to max_dim and max_items.""" + """Restrict the dimensions and overall size of a shape to + max_dim and max_items. + """ new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape while product(new_shape) > max_items: new_shape = [max(d - 1, 1) for d in new_shape] @@ -1527,7 +1672,7 @@ class TosaErrorIfArgGen: elif error_name == ErrorIf.StartSizeOutsideBounds: newStart, newSize = [], [] for i in range(len(input_shape)): - newStart.append(input_shape[i]-1) + newStart.append(input_shape[i] - 1) newSize.append(testGen.rng.choice([2, 3, 4])) return newStart, newSize elif error_name == ErrorIf.InputSizeStartLengthMismatch: @@ -1556,7 +1701,6 @@ class TosaErrorIfArgGen: class TosaErrorValidator: - @staticmethod def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs): """Check ERROR_IF statements are caught and set the expected result. @@ -1572,9 +1716,9 @@ class TosaErrorValidator: overall_result = True for val_fcn in validator_fcns: val_result = val_fcn(True, **kwargs) - validator_name = val_result['error_name'] - error_result = val_result['error_result'] - error_reason = val_result['error_reason'] + validator_name = val_result["error_name"] + error_result = val_result["error_result"] + error_reason = val_result["error_reason"] # expect an error IFF the error_name and validator_name match expected_result = error_result == (error_name == validator_name) @@ -1583,18 +1727,22 @@ class TosaErrorValidator: if expected_result and error_result: serializer.setExpectedReturnCode(2, True, desc=error_reason) elif error_result: # and not expected_result - print(f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}" - f" Expected: {error_name}, Got: {validator_name}") - elif not expected_result: # and not error_result - print(f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}" - f" Expected: {error_name}") + print( + f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}" + f" Expected: {error_name}, Got: {validator_name}" + ) + elif not expected_result: # and not error_result + print( + f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}" + f" Expected: {error_name}" + ) if not expected_result: for k, v in sorted(kwargs.items()): - if k != 'op': - if k.endswith('dtype'): + if k != "op": + if k.endswith("dtype"): v = valueToName(DType, v) - print(f' {k} = {v}') + print(f" {k} = {v}") return overall_result @@ -1603,24 +1751,26 @@ class TosaErrorValidator: error_result = False # Find the unsupported input data types - op = kwargs['op'] - input_dtypes = op['types'] - allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes} + op = kwargs["op"] + input_dtypes = op["types"] + allowed_input_dtypes = { + t[0] if isinstance(t, list) else t for t in input_dtypes + } wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes)) - if op['op'] == Op.CLAMP: + if op["op"] == Op.CLAMP: wrong_input_dtypes.remove(DType.INT48) if check: - input_dtype = kwargs['input_dtype'] + input_dtype = kwargs["input_dtype"] if input_dtype not in allowed_input_dtypes: error_result = True info_dict = { "error_name": ErrorIf.WrongInputType, "error_result": error_result, - "error_reason": f"Input data type not supported for this operator", - "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None} + "error_reason": "Input data type not supported for this operator", + "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None}, } return info_dict @@ -1629,24 +1779,45 @@ class TosaErrorValidator: error_result = False if check: - input_dtype = kwargs['input_dtype'] - output_dtype = kwargs['output_dtype'] - op = kwargs['op'] + input_dtype = kwargs["input_dtype"] + output_dtype = kwargs["output_dtype"] + op = kwargs["op"] - if op['op'] == Op.RESIZE: - mode = kwargs['mode'] + if op["op"] == Op.RESIZE: + mode = kwargs["mode"] if ( - (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or - (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or - (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or - (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or - (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) + ( + mode == ResizeMode.NEAREST + and input_dtype == DType.INT8 + and output_dtype != DType.INT8 + ) + or ( + mode == ResizeMode.NEAREST + and input_dtype == DType.INT16 + and output_dtype != DType.INT16 + ) + or ( + mode == ResizeMode.BILINEAR + and input_dtype == DType.INT8 + and output_dtype != DType.INT32 + ) + or ( + mode == ResizeMode.BILINEAR + and input_dtype == DType.INT16 + and output_dtype != DType.INT48 + ) + or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) ): error_result = True - elif op['op'] == Op.RESCALE: + elif op["op"] == Op.RESCALE: if input_dtype == DType.INT8: - if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]: + if output_dtype not in [ + DType.UINT8, + DType.INT8, + DType.INT16, + DType.INT32, + ]: error_result = True if input_dtype in [DType.INT16, DType.INT32]: if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]: @@ -1658,49 +1829,78 @@ class TosaErrorValidator: if output_dtype != DType.INT8: error_result = True - elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]: + elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]: if ( - (input_dtype == DType.INT8 and output_dtype != DType.INT32) or - (input_dtype == DType.INT16 and output_dtype != DType.INT48) or - (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) + (input_dtype == DType.INT8 and output_dtype != DType.INT32) + or (input_dtype == DType.INT16 and output_dtype != DType.INT48) + or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) ): error_result = True - elif op['op'] == Op.ARGMAX: - if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32: + elif op["op"] == Op.ARGMAX: + if ( + input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] + and output_dtype != DType.INT32 + ): error_result = True - elif op['op'] == Op.MUL: + elif op["op"] == Op.MUL: if input_dtype != DType.FLOAT and output_dtype != DType.INT32: error_result = True elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT: error_result = True - elif op['op'] == Op.TABLE: + elif op["op"] == Op.TABLE: if input_dtype == DType.INT8 and output_dtype != DType.INT8: error_result = True elif input_dtype == DType.INT16 and output_dtype != DType.INT32: error_result = True - elif op['op'] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]: + elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]: if output_dtype != DType.BOOL: error_result = True - elif op['op'] == Op.CAST: + elif op["op"] == Op.CAST: if ( - (input_dtype == DType.BOOL and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]) - or (input_dtype == DType.INT8 and output_dtype not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]) - or (input_dtype == DType.INT16 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]) - or (input_dtype == DType.INT32 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]) - or (input_dtype == DType.FLOAT and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]) + ( + input_dtype == DType.BOOL + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + ) + or ( + input_dtype == DType.INT8 + and output_dtype + not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT] + ) + or ( + input_dtype == DType.INT16 + and output_dtype + not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT] + ) + or ( + input_dtype == DType.INT32 + and output_dtype + not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] + ) + or ( + input_dtype == DType.FLOAT + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] + ) ): error_result = True - elif op['op'] in {Op.CONV2D, Op.CONV3D, Op.DEPTHWISE_CONV2D, Op.TRANSPOSE_CONV2D}: + elif op["op"] in { + Op.CONV2D, + Op.CONV3D, + Op.DEPTHWISE_CONV2D, + Op.TRANSPOSE_CONV2D, + }: if ( - input_dtype == DType.INT8 and output_dtype != DType.INT32 - or input_dtype == DType.INT16 and output_dtype != DType.INT48 - or input_dtype == DType.FLOAT and output_dtype != DType.FLOAT + input_dtype == DType.INT8 + and output_dtype != DType.INT32 + or input_dtype == DType.INT16 + and output_dtype != DType.INT48 + or input_dtype == DType.FLOAT + and output_dtype != DType.FLOAT ): error_result = True # invalid input types are ignored, to avoid reporting multiple errors @@ -1712,8 +1912,10 @@ class TosaErrorValidator: info_dict = { "error_name": ErrorIf.WrongOutputType, "error_result": error_result, - "error_reason": "Output data type not supported for this configuration of operator", - "param_reqs": {"rank": None, "dtype": None, "shape": None} + "error_reason": ( + "Output data type not supported for this configuration of operator" + ), + "param_reqs": {"rank": None, "dtype": None, "shape": None}, } return info_dict @@ -1722,19 +1924,19 @@ class TosaErrorValidator: all_ranks = (1, 2, 3, 4, 5) # Make a list of incorrect ranks - assert 'op' in kwargs - op = kwargs['op'] - rmin, rmax = op['rank'] + assert "op" in kwargs + op = kwargs["op"] + rmin, rmax = op["rank"] rank_range = range(rmin, rmax + 1) incorrect_ranks = list(set(all_ranks) - set(rank_range)) # Remove small incorrect ranks to avoid index errors incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin] # Set minimum incorrect rank to 3 to avoid index error - if op['op'] in [Op.RESIZE]: + if op["op"] in [Op.RESIZE]: incorrect_ranks = [3, 5] - elif op['op'] in [Op.TRANSPOSE]: + elif op["op"] in [Op.TRANSPOSE]: incorrect_ranks = [7, 8] - elif op['op'] in [Op.CONV3D]: + elif op["op"] in [Op.CONV3D]: incorrect_ranks = [6, 7] error_name = ErrorIf.WrongRank @@ -1743,13 +1945,16 @@ class TosaErrorValidator: error_reason = "Rank not supported for this operator" if check: - input_shape = kwargs['input_shape'] + input_shape = kwargs["input_shape"] - if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4: + if ( + op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] + and len(input_shape) != 4 + ): error_result = True - elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2: + elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2: error_result = True - elif op['op'] == Op.MATMUL and len(input_shape) != 3: + elif op["op"] == Op.MATMUL and len(input_shape) != 3: error_result = True else: if len(input_shape) not in rank_range: @@ -1759,7 +1964,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -1771,10 +1976,10 @@ class TosaErrorValidator: error_reason = "Op input list does not match expected input" if check: - op = kwargs['op'] - input_list = kwargs['input_list'] - num_operands = kwargs['num_operands'] - if op['op'] in [Op.SCATTER, Op.GATHER]: + op = kwargs["op"] + input_list = kwargs["input_list"] + num_operands = kwargs["num_operands"] + if op["op"] in [Op.SCATTER, Op.GATHER]: # SCATTER/GATHER add an indices input tensor in their build functions num_operands += 1 if len(input_list) != num_operands: @@ -1784,7 +1989,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -1796,7 +2001,7 @@ class TosaErrorValidator: error_reason = "Op output list does not match expected output" if check: - output_list = kwargs['output_list'] + output_list = kwargs["output_list"] # Note this will be incorrect if an operator returns more than one output if len(output_list) != 1: error_result = True @@ -1805,7 +2010,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -1813,45 +2018,51 @@ class TosaErrorValidator: def evMaxDimExceeded(check=False, **kwargs): error_name = ErrorIf.MaxDimExceeded param_reqs = { - "rank": [4,4], + "rank": [4, 4], "dtype": [DType.INT8], - "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]] - } + "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]], + } error_result = False - error_reason = "At least one maximum dimension is greater than or equal to 16384" + error_reason = ( + "At least one maximum dimension is greater than or equal to 16384" + ) if check: - input_shape = kwargs['input_shape'] - output_shape = kwargs['output_shape'] # Note this is just (OH, OW) - if ((input_shape[1] >= 16384) or - (input_shape[2] >= 16384) or - (output_shape[0] >= 16384) or - (output_shape[1] >= 16384)): + input_shape = kwargs["input_shape"] + output_shape = kwargs["output_shape"] # Note this is just (OH, OW) + if ( + (input_shape[1] >= 16384) + or (input_shape[2] >= 16384) + or (output_shape[0] >= 16384) + or (output_shape[1] >= 16384) + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evBatchMismatch(check=False, **kwargs): error_name = ErrorIf.BatchMismatch - param_reqs = {"rank": [4,4], "dtype": None, "shape": None} + param_reqs = {"rank": [4, 4], "dtype": None, "shape": None} error_result = False error_reason = "Input batch size not equal to output batch size" - assert 'op' in kwargs - op = kwargs['op'] - rmin, rmax = op['rank'] + assert "op" in kwargs + op = kwargs["op"] + rmin, rmax = op["rank"] rank_range = range(rmin, rmax + 1) if check: - input_shape = kwargs['input_shape'] - output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C) + input_shape = kwargs["input_shape"] + output_shape = kwargs[ + "result_tensor" + ].shape # Note this is just (N, OH, OW, C) if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]): error_result = True @@ -1860,25 +2071,27 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evChannelMismatch(check=False, **kwargs): error_name = ErrorIf.ChannelMismatch - param_reqs = {"rank": [4,4], "dtype": None, "shape": None} + param_reqs = {"rank": [4, 4], "dtype": None, "shape": None} error_result = False error_reason = "Input channel size not equal to output channel size" - assert 'op' in kwargs - op = kwargs['op'] - rmin, rmax = op['rank'] + assert "op" in kwargs + op = kwargs["op"] + rmin, rmax = op["rank"] rank_range = range(rmin, rmax + 1) if check: - input_shape = kwargs['input_shape'] - output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C) + input_shape = kwargs["input_shape"] + output_shape = kwargs[ + "result_tensor" + ].shape # Note this is just (N, OH, OW, C) if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]): error_result = True @@ -1886,7 +2099,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -1898,16 +2111,18 @@ class TosaErrorValidator: error_reason = "Stride value smaller than or equal zero" if check: - input_dtype = kwargs['input_dtype'] - output_dtype = kwargs['output_dtype'] + input_dtype = kwargs["input_dtype"] + output_dtype = kwargs["output_dtype"] if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT: - stride = kwargs['stride'] # Work around wrong input/output type tests + stride = kwargs["stride"] # Work around wrong input/output type tests elif output_dtype == DType.FLOAT: - stride = kwargs['stride_fp'] + stride = kwargs["stride_fp"] elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT: - stride = kwargs['stride_fp'] # Work around wrong input/output type tests + stride = kwargs[ + "stride_fp" + ] # Work around wrong input/output type tests else: - stride = kwargs['stride'] + stride = kwargs["stride"] if min(stride) <= 0: error_result = True @@ -1916,7 +2131,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -1928,24 +2143,27 @@ class TosaErrorValidator: error_reason = "Stride value larger than or equal to maximum value" if check: - shift = kwargs['shift'] - input_dtype = kwargs['input_dtype'] - stride = kwargs['stride'] + shift = kwargs["shift"] + input_dtype = kwargs["input_dtype"] + stride = kwargs["stride"] if input_dtype in [DType.INT8, DType.INT16]: - if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)): + if shift >= 0 and ( + stride[0] >= (16 << shift) or stride[1] >= (16 << shift) + ): error_result = True - elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)): + elif shift < 0 and ( + stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift) + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evStrideLargerDimension(check=False, **kwargs): error_name = ErrorIf.StrideLargerDimension @@ -1954,22 +2172,25 @@ class TosaErrorValidator: error_reason = "Stride value larger than or equal to H/W dimension" if check: - shape = kwargs['input_shape'] - input_dtype = kwargs['input_dtype'] - stride = kwargs['stride_fp'] + shape = kwargs["input_shape"] + input_dtype = kwargs["input_dtype"] + stride = kwargs["stride_fp"] - if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]): + if ( + input_dtype == DType.FLOAT + and (stride[0] > shape[1]) + or (stride[1] > shape[2]) + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evOffsetSmallerEqualMin(check=False, **kwargs): error_name = ErrorIf.OffsetSmallerEqualMin @@ -1978,23 +2199,27 @@ class TosaErrorValidator: error_reason = "Offset value smaller than or equal to minimum value" if check: - shift = kwargs['shift'] - output_dtype = kwargs['output_dtype'] + shift = kwargs["shift"] + output_dtype = kwargs["output_dtype"] if output_dtype == DType.FLOAT: - offset = kwargs['offset_fp'] + offset = kwargs["offset_fp"] else: - offset = kwargs['offset'] + offset = kwargs["offset"] - if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)): + if shift >= 0 and ( + offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift) + ): error_result = True - elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)): + elif shift < 0 and ( + offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift) + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2006,27 +2231,31 @@ class TosaErrorValidator: error_reason = "Offset value larger than or equal to maximum value" if check: - shift = kwargs['shift'] - output_dtype = kwargs['output_dtype'] + shift = kwargs["shift"] + output_dtype = kwargs["output_dtype"] if output_dtype == DType.FLOAT: - offset = kwargs['offset_fp'] + offset = kwargs["offset_fp"] else: - offset = kwargs['offset'] + offset = kwargs["offset"] if shift >= 0: if offset[0] >= (16 << shift) or offset[1] >= (16 << shift): error_result = True - if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)): + if shift >= 0 and ( + offset[0] >= (16 << shift) or offset[1] >= (16 << shift) + ): error_result = True - elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)): + elif shift < 0 and ( + offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift) + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2038,21 +2267,24 @@ class TosaErrorValidator: error_reason = "Shift value must be zero for float input" if check: - shift = kwargs['shift'] - input_dtype = kwargs['input_dtype'] - output_dtype = kwargs['output_dtype'] - if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0: + shift = kwargs["shift"] + input_dtype = kwargs["input_dtype"] + output_dtype = kwargs["output_dtype"] + if ( + input_dtype == DType.FLOAT + and output_dtype == DType.FLOAT + and shift != 0 + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evShiftSmallerOne(check=False, **kwargs): error_name = ErrorIf.ShiftSmallerOne @@ -2061,9 +2293,9 @@ class TosaErrorValidator: error_reason = "Shift value smaller than one" if check: - shift = kwargs['shift'] - input_dtype = kwargs['input_dtype'] - output_dtype = kwargs['output_dtype'] + shift = kwargs["shift"] + input_dtype = kwargs["input_dtype"] + output_dtype = kwargs["output_dtype"] if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT: error_result = True @@ -2071,7 +2303,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2083,7 +2315,7 @@ class TosaErrorValidator: error_reason = "Shift value larger than eleven" if check: - shift = kwargs['shift'] + shift = kwargs["shift"] if shift > 11: error_result = True @@ -2091,11 +2323,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evRankMismatch(check=False, **kwargs): error_name = ErrorIf.RankMismatch @@ -2104,23 +2335,25 @@ class TosaErrorValidator: error_reason = "Input Rank does not match output rank" if check: - input1_shape = kwargs['input1'].shape - input2_shape = kwargs['input2'].shape + input1_shape = kwargs["input1"].shape + input2_shape = kwargs["input2"].shape # In case of SELECT op - input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape - output_shape = kwargs['result_tensor'].shape + input3_shape = ( + kwargs["input3"].shape if "input3" in kwargs else input2_shape + ) + output_shape = kwargs["result_tensor"].shape if ( - (len(input1_shape) != len(output_shape)) or - (len(input2_shape) != len(output_shape)) or - (len(input3_shape) != len(output_shape)) - ): + (len(input1_shape) != len(output_shape)) + or (len(input2_shape) != len(output_shape)) + or (len(input3_shape) != len(output_shape)) + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2132,30 +2365,34 @@ class TosaErrorValidator: error_reason = "Input Dimensions do not match output" if check: - input1_shape = kwargs['input1'].shape - input2_shape = kwargs['input2'].shape + input1_shape = kwargs["input1"].shape + input2_shape = kwargs["input2"].shape # In case of SELECT op - input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape - output_shape = kwargs['result_tensor'].shape - for i in range(min(len(input1_shape), len(input2_shape), len(input3_shape))): + input3_shape = ( + kwargs["input3"].shape if "input3" in kwargs else input2_shape + ) + output_shape = kwargs["result_tensor"].shape + for i in range( + min(len(input1_shape), len(input2_shape), len(input3_shape)) + ): if ( - (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) or - (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) or - (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) - ): + (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) + or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) + or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evInputZeroPointNotZero(check=False, **kwargs): - op = kwargs['op'] + op = kwargs["op"] error_result = False # Quantizable types @@ -2163,26 +2400,27 @@ class TosaErrorValidator: # This does not apply to quantizable types inputDtypes = [ - dtype for dtype in op['types'] - if (isinstance(dtype, list) and dtype[0] not in qTypes) or - (not isinstance(dtype, list) and dtype not in qTypes) + dtype + for dtype in op["types"] + if (isinstance(dtype, list) and dtype[0] not in qTypes) + or (not isinstance(dtype, list) and dtype not in qTypes) ] if check: - input_dtype = kwargs['input_dtype'] - if isinstance(kwargs['qinfo'], tuple): - qinfo = kwargs['qinfo'] + input_dtype = kwargs["input_dtype"] + if isinstance(kwargs["qinfo"], tuple): + qinfo = kwargs["qinfo"] input_zero_point = qinfo[0] else: # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - qinfo = kwargs['qinfo'].ints + qinfo = kwargs["qinfo"].ints input_zero_point = qinfo[0][1] - if op['op'] == Op.MATMUL: - qinfo = kwargs['qinfo'].ints + if op["op"] == Op.MATMUL: + qinfo = kwargs["qinfo"].ints for dtype, zp in ( - (kwargs['input_dtype'], qinfo[0][1]), - (kwargs['input2_dtype'], qinfo[1][1]), + (kwargs["input_dtype"], qinfo[0][1]), + (kwargs["input2_dtype"], qinfo[1][1]), ): if dtype not in qTypes and zp != 0: error_result = True @@ -2194,32 +2432,28 @@ class TosaErrorValidator: "error_name": ErrorIf.InputZeroPointNotZero, "error_result": error_result, "error_reason": "Input DType not INT8 and zero point not 0", - "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None} + "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None}, } return info_dict - @staticmethod def evWeightZeroPointNotZero(check=False, **kwargs): - op = kwargs['op'] + op = kwargs["op"] # exclude inputs with INT8 weights - inputDtypes = [t for t in op['types'] - if not isinstance(t, list) or t[1] != DType.INT8] + inputDtypes = [ + t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8 + ] error_name = ErrorIf.WeightZeroPointNotZero - param_reqs = { - "rank": None, - "dtype": inputDtypes, - "shape": None - } + param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None} error_result = False error_reason = "Weight DType not INT8 and zero point not 0" if check: - weight_dtype = kwargs['weight_dtype'] + weight_dtype = kwargs["weight_dtype"] # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp - qinfo = kwargs['qinfo'].ints + qinfo = kwargs["qinfo"].ints weight_zero_point = qinfo[1][1] if weight_dtype != DType.INT8 and weight_zero_point != 0: error_result = True @@ -2228,50 +2462,47 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evOutputZeroPointNotZero(check=False, **kwargs): - op = kwargs['op'] - inputDtypes = op['types'].copy() + op = kwargs["op"] + inputDtypes = op["types"].copy() if DType.INT8 in inputDtypes: inputDtypes.remove(DType.INT8) if DType.UINT8 in inputDtypes: inputDtypes.remove(DType.UINT8) error_name = ErrorIf.OutputZeroPointNotZero - param_reqs = { - "rank": None, - "dtype": inputDtypes, - "shape": None - } + param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None} error_result = False error_reason = "Output DType not INT8 and zero point not 0" if check: - input_dtype = kwargs['input_dtype'] - output_dtype = kwargs['output_dtype'] - if isinstance(kwargs['qinfo'], tuple): - qinfo = kwargs['qinfo'] + input_dtype = kwargs["input_dtype"] + output_dtype = kwargs["output_dtype"] + if isinstance(kwargs["qinfo"], tuple): + qinfo = kwargs["qinfo"] output_zero_point = qinfo[1] else: # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp - qinfo = kwargs['qinfo'].ints + qinfo = kwargs["qinfo"].ints output_zero_point = qinfo[1][1] - if op['op'] == Op.AVG_POOL2D: + if op["op"] == Op.AVG_POOL2D: if input_dtype != DType.INT8 and output_zero_point != 0: error_result = True - elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0: + elif ( + output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0 + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2283,7 +2514,7 @@ class TosaErrorValidator: error_reason = "Axis smaller than zero" if check: - axis = kwargs['axis'] + axis = kwargs["axis"] if axis < 0: error_result = True @@ -2291,11 +2522,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evAxisLargerRank(check=False, **kwargs): error_name = ErrorIf.AxisLargerRank @@ -2304,8 +2534,8 @@ class TosaErrorValidator: error_reason = "Axis larger than rank" if check: - axis = kwargs['axis'] - shape = kwargs['input_shape'] + axis = kwargs["axis"] + shape = kwargs["input_shape"] if axis > len(shape): error_result = True @@ -2313,11 +2543,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evShapeOfAxisNotOne(check=False, **kwargs): error_name = ErrorIf.ShapeOfAxisNotOne @@ -2326,8 +2555,8 @@ class TosaErrorValidator: error_reason = "shape[axis] is not equal to 1" if check: - axis = kwargs['axis'] - shape = kwargs['output_shape'] + axis = kwargs["axis"] + shape = kwargs["output_shape"] if (0 <= axis < len(shape)) and shape[axis] != 1: error_result = True @@ -2335,11 +2564,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evPadSmallerZero(check=False, **kwargs): error_name = ErrorIf.PadSmallerZero @@ -2348,9 +2576,9 @@ class TosaErrorValidator: error_reason = "At least one pad is smaller than zero" if check: - op = kwargs['op'] - pad = kwargs['pad'] - if op['op'] == Op.PAD: + op = kwargs["op"] + pad = kwargs["pad"] + if op["op"] == Op.PAD: for padding in pad: if min(padding) < 0: error_result = True @@ -2362,11 +2590,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evPadLargerEqualKernel(check=False, **kwargs): error_name = ErrorIf.PadLargerEqualKernel @@ -2375,17 +2602,22 @@ class TosaErrorValidator: error_reason = "At least one pad is larger than kernel dimension" if check: - pad = kwargs['pad'] - kernel = kwargs['kernel'] + pad = kwargs["pad"] + kernel = kwargs["kernel"] if min(pad) > 0 and min(kernel) > 1: - if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]: + if ( + pad[0] >= kernel[0] + or pad[1] >= kernel[0] + or pad[2] >= kernel[1] + or pad[3] >= kernel[1] + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2394,32 +2626,47 @@ class TosaErrorValidator: error_name = ErrorIf.PoolingOutputShapeMismatch param_reqs = {"rank": None, "dtype": None, "shape": None} error_result = False - error_reason = "Mismatch between output shape provided and expected output shape" + error_reason = ( + "Mismatch between output shape provided and expected output shape" + ) if check: - pad = kwargs['pad'] + pad = kwargs["pad"] pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3] - kernel = kwargs['kernel'] + kernel = kwargs["kernel"] kernel_y, kernel_x = kernel[0], kernel[1] - input_shape = kwargs['input_shape'] + input_shape = kwargs["input_shape"] IH, IW = input_shape[1], input_shape[2] - output_shape = kwargs['output_shape'] + output_shape = kwargs["output_shape"] OH, OW = output_shape[1], output_shape[2] - stride = kwargs['stride'] + stride = kwargs["stride"] stride_y, stride_x = stride[0], stride[1] # calculate correct height, width dimensions if stride_x != 0 and stride_y != 0: - y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y - x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x + y_correct = ( + IH + pad_top + pad_bottom + stride_y - kernel_y + ) // stride_y + x_correct = ( + IW + pad_left + pad_right + stride_x - kernel_x + ) // stride_x # ensure parameters are valid - params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0 - and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1])) + params_valid = ( + min(kernel) >= 1 + and min(stride) >= 1 + and min(pad) >= 0 + and not ( + pad[0] >= kernel[0] + or pad[1] >= kernel[0] + or pad[2] >= kernel[1] + or pad[3] >= kernel[1] + ) + ) if params_valid and (OH != y_correct or OW != x_correct): error_result = True @@ -2428,21 +2675,23 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evArgmaxOutputShapeMismatch(check=False, **kwargs): error_name = ErrorIf.ArgmaxOutputShapeMismatch - param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + param_reqs = {"rank": [2, 4], "dtype": None, "shape": None} error_result = False - error_reason = "Mismatch between output shape provided and expected output shape" + error_reason = ( + "Mismatch between output shape provided and expected output shape" + ) if check: - output_shape = kwargs['output_shape'] - input_shape = kwargs['input_shape'] - axis = kwargs['axis'] + output_shape = kwargs["output_shape"] + input_shape = kwargs["input_shape"] + axis = kwargs["axis"] dimension_match = True axis_shift = 0 @@ -2463,7 +2712,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2472,12 +2721,14 @@ class TosaErrorValidator: error_name = ErrorIf.ArgmaxOutputRankMismatch param_reqs = {"rank": None, "dtype": None, "shape": None} error_result = False - error_reason = "Mismatch between output shape provided and expected output shape" + error_reason = ( + "Mismatch between output shape provided and expected output shape" + ) if check: - output_shape = kwargs['output_shape'] - input_shape = kwargs['input_shape'] - axis = kwargs['axis'] + output_shape = kwargs["output_shape"] + input_shape = kwargs["input_shape"] + axis = kwargs["axis"] valid_params = axis >= 0 and axis < len(input_shape) if valid_params and (len(input_shape) - 1) != len(output_shape): @@ -2487,11 +2738,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evKernelSmallerOne(check=False, **kwargs): error_name = ErrorIf.KernelSmallerOne @@ -2500,7 +2750,7 @@ class TosaErrorValidator: error_reason = "At least one kernel dimension is smaller than zero" if check: - kernel = kwargs['kernel'] + kernel = kwargs["kernel"] if min(kernel) < 1: error_result = True @@ -2508,7 +2758,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2520,7 +2770,7 @@ class TosaErrorValidator: error_reason = "At least one stride dimension is smaller than zero" if check: - stride = kwargs['stride'] + stride = kwargs["stride"] if min(stride) < 1: error_result = True @@ -2528,18 +2778,18 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evDilationSmallerOne(check=False, **kwargs): - error_result = check and min(kwargs['dilation']) < 1 + error_result = check and min(kwargs["dilation"]) < 1 return { "error_name": ErrorIf.DilationSmallerOne, "error_reason": "At least one dilation is smaller than one", "param_reqs": {"rank": None, "dtype": None, "shape": None}, - "error_result": error_result + "error_result": error_result, } @staticmethod @@ -2550,8 +2800,8 @@ class TosaErrorValidator: error_reason = "Scale set to true but input type is INT48" if check: - input_dtype = kwargs['input_dtype'] - scale32 = kwargs['scale32'] + input_dtype = kwargs["input_dtype"] + scale32 = kwargs["scale32"] if scale32 and input_dtype == DType.INT48: error_result = True @@ -2559,7 +2809,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2571,8 +2821,8 @@ class TosaErrorValidator: error_reason = "Scale set to false but double round set to true" if check: - scale32 = kwargs['scale32'] - double_round = kwargs['double_round'] + scale32 = kwargs["scale32"] + double_round = kwargs["double_round"] if not scale32 and double_round: error_result = True @@ -2580,7 +2830,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2592,8 +2842,8 @@ class TosaErrorValidator: error_reason = "Input tensor size does not match output tensor size" if check: - input_shape = kwargs['input_shape'] - output_shape = kwargs['output_shape'] + input_shape = kwargs["input_shape"] + output_shape = kwargs["output_shape"] input_size = np.prod(input_shape) output_size = np.prod(output_shape) if input_size != output_size: @@ -2603,7 +2853,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2615,8 +2865,8 @@ class TosaErrorValidator: error_reason = "Starting point smaller than zero" if check: - input_shape = kwargs['input_shape'] - start = kwargs['start'] + input_shape = kwargs["input_shape"] + start = kwargs["start"] rank = len(input_shape) if len(start) == rank: for index in range(rank): @@ -2627,11 +2877,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evSizeSmallerEqualZero(check=False, **kwargs): error_name = ErrorIf.SizeSmallerEqualZero @@ -2640,8 +2889,8 @@ class TosaErrorValidator: error_reason = "Size smaller than or equal to zero" if check: - input_shape = kwargs['input_shape'] - size = kwargs['size'] + input_shape = kwargs["input_shape"] + size = kwargs["size"] rank = len(input_shape) if len(size) == rank: for index in range(rank): @@ -2652,11 +2901,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evStartSizeOutsideBounds(check=False, **kwargs): error_name = ErrorIf.StartSizeOutsideBounds @@ -2665,9 +2913,9 @@ class TosaErrorValidator: error_reason = "starting point plus size larger than input dimension" if check: - input_shape = kwargs['input_shape'] - start = kwargs['start'] - size = kwargs['size'] + input_shape = kwargs["input_shape"] + start = kwargs["start"] + size = kwargs["size"] rank = len(input_shape) if len(start) == rank and len(size) == rank: for index in range(rank): @@ -2678,11 +2926,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evSizeOutputShapeMismatch(check=False, **kwargs): error_name = ErrorIf.SizeOutputShapeMismatch @@ -2691,9 +2938,9 @@ class TosaErrorValidator: error_reason = "Size does not match output dimension" if check: - input_shape = kwargs['input_shape'] - output_shape = kwargs['output_shape'] - size = kwargs['size'] + input_shape = kwargs["input_shape"] + output_shape = kwargs["output_shape"] + size = kwargs["size"] rank = len(input_shape) if len(size) == rank: for index in range(rank): @@ -2704,7 +2951,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2716,9 +2963,9 @@ class TosaErrorValidator: error_reason = "rank of input not equal to length of start or size" if check: - input_shape = kwargs['input_shape'] - start = kwargs['start'] - size = kwargs['size'] + input_shape = kwargs["input_shape"] + start = kwargs["start"] + size = kwargs["size"] rank = len(input_shape) if rank != len(start) or rank != len(size): error_result = True @@ -2727,7 +2974,7 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2739,8 +2986,8 @@ class TosaErrorValidator: error_reason = "Index outside of allowed bounds" if check: - input_shape = kwargs['input_shape'] - perms = kwargs['perms'] + input_shape = kwargs["input_shape"] + perms = kwargs["perms"] rank = len(input_shape) for index in perms: @@ -2751,21 +2998,19 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evIndexUsedTwice(check=False, **kwargs): error_name = ErrorIf.IndexUsedTwice - param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + param_reqs = {"rank": [2, 4], "dtype": None, "shape": None} error_result = False error_reason = "Index used multiple times" if check: - input_shape = kwargs['input_shape'] - perms = kwargs['perms'] - rank = len(input_shape) + perms = kwargs["perms"] unique_indices = [] for index in perms: @@ -2778,42 +3023,41 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evMaxSmallerMin(check=False, **kwargs): error_name = ErrorIf.MaxSmallerMin - param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + param_reqs = {"rank": [2, 4], "dtype": None, "shape": None} error_result = False error_reason = "Max value smaller than min value" if check: - max_val = kwargs['max_val'] - min_val = kwargs['min_val'] + max_val = kwargs["max_val"] + min_val = kwargs["min_val"] if max_val < min_val: error_result = True - info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evConcatInputRankMismatch(check=False, **kwargs): error_name = ErrorIf.ConcatInputRankMismatch - param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + param_reqs = {"rank": [2, 4], "dtype": None, "shape": None} error_result = False error_reason = "Input ranks are not identical" if check: - inputs = kwargs['inputs'] - input_shape = kwargs['input_shape'] + inputs = kwargs["inputs"] + input_shape = kwargs["input_shape"] for input in inputs: if len(input.shape) != len(input_shape): error_result = True @@ -2822,21 +3066,21 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evConcatInputDimMismatch(check=False, **kwargs): error_name = ErrorIf.ConcatInputDimMismatch - param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + param_reqs = {"rank": [2, 4], "dtype": None, "shape": None} error_result = False error_reason = "Input dimensions differ on too many axes" if check: - inputs = kwargs['inputs'] - input_shape = kwargs['input_shape'] - axis = kwargs['axis'] + inputs = kwargs["inputs"] + input_shape = kwargs["input_shape"] + axis = kwargs["axis"] # Ensure rank is valid before checking dims. valid_rank = True @@ -2854,22 +3098,22 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @staticmethod def evConcatShapeSumMismatch(check=False, **kwargs): error_name = ErrorIf.ConcatShapeSumMismatch - param_reqs = {"rank": [2,4], "dtype": None, "shape": None} + param_reqs = {"rank": [2, 4], "dtype": None, "shape": None} error_result = False error_reason = "Sum of dimensions on axis not equal to output dimension" if check: - inputs = kwargs['inputs'] - input_shape = kwargs['input_shape'] - output_shape = kwargs['output_shape'] - axis = kwargs['axis'] + inputs = kwargs["inputs"] + input_shape = kwargs["input_shape"] + output_shape = kwargs["output_shape"] + axis = kwargs["axis"] # Ensure rank is valid before checking dims. valid_params = True @@ -2887,12 +3131,11 @@ class TosaErrorValidator: if axis_dim_sum != output_shape[axis]: error_result = True - info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict @@ -2904,24 +3147,25 @@ class TosaErrorValidator: error_reason = "Input list shape does not match then-graph shape" if check: - a = kwargs['a'] - b = kwargs['b'] - basicBlocks = kwargs['basicBlocks'] + a = kwargs["a"] + b = kwargs["b"] + basicBlocks = kwargs["basicBlocks"] then_block = basicBlocks[1] then_inputs = then_block.inputs then_tens = then_block.tensors - if (a.shape != then_tens[then_inputs[0]].shape) or (b.shape != then_tens[then_inputs[1]].shape): + if (a.shape != then_tens[then_inputs[0]].shape) or ( + b.shape != then_tens[then_inputs[1]].shape + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evInputListElseGraphMismatch(check=False, **kwargs): error_name = ErrorIf.CondIfInputListElseGraphMismatch @@ -2930,24 +3174,25 @@ class TosaErrorValidator: error_reason = "Input list shape does not match else-graph shape" if check: - a = kwargs['a'] - b = kwargs['b'] - basicBlocks = kwargs['basicBlocks'] + a = kwargs["a"] + b = kwargs["b"] + basicBlocks = kwargs["basicBlocks"] else_block = basicBlocks[2] else_inputs = else_block.inputs else_tens = else_block.tensors - if (a.shape != else_tens[else_inputs[0]].shape) or (b.shape != else_tens[else_inputs[1]].shape): + if (a.shape != else_tens[else_inputs[0]].shape) or ( + b.shape != else_tens[else_inputs[1]].shape + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evOutputListThenGraphMismatch(check=False, **kwargs): error_name = ErrorIf.CondIfOutputListThenGraphMismatch @@ -2956,7 +3201,7 @@ class TosaErrorValidator: error_reason = "Output list shape does not match then-graph shape" if check: - basicBlocks = kwargs['basicBlocks'] + basicBlocks = kwargs["basicBlocks"] cond_block = basicBlocks[0] cond_outputs = cond_block.outputs cond_tens = cond_block.tensors @@ -2970,11 +3215,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evOutputListElseGraphMismatch(check=False, **kwargs): error_name = ErrorIf.CondIfOutputListElseGraphMismatch @@ -2983,7 +3227,7 @@ class TosaErrorValidator: error_reason = "Output list shape does not match else-graph shape" if check: - basicBlocks = kwargs['basicBlocks'] + basicBlocks = kwargs["basicBlocks"] cond_block = basicBlocks[0] cond_outputs = cond_block.outputs cond_tens = cond_block.tensors @@ -2997,11 +3241,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evInputListOutputListMismatch(check=False, **kwargs): error_name = ErrorIf.InputListOutputListMismatch @@ -3010,7 +3253,7 @@ class TosaErrorValidator: error_reason = "Input list does not match output list" if check: - basicBlocks = kwargs['basicBlocks'] + basicBlocks = kwargs["basicBlocks"] while_block = basicBlocks[0] while_inputs = while_block.inputs while_outputs = while_block.outputs @@ -3022,11 +3265,10 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evInputListCondGraphMismatch(check=False, **kwargs): error_name = ErrorIf.InputListCondGraphMismatch @@ -3035,26 +3277,26 @@ class TosaErrorValidator: error_reason = "Input list does not match cond graph" if check: - basicBlocks = kwargs['basicBlocks'] + basicBlocks = kwargs["basicBlocks"] while_block = basicBlocks[0] while_inputs = while_block.inputs while_tens = while_block.tensors cond_block = basicBlocks[1] cond_inputs = cond_block.inputs cond_tens = cond_block.tensors - if ((while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape) or - (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape)): + if ( + while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape + ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evInputListBodyGraphInputMismatch(check=False, **kwargs): error_name = ErrorIf.InputListBodyGraphInputMismatch @@ -3063,26 +3305,28 @@ class TosaErrorValidator: error_reason = "Input list does not match body graph input" if check: - basicBlocks = kwargs['basicBlocks'] + basicBlocks = kwargs["basicBlocks"] while_block = basicBlocks[0] while_inputs = while_block.inputs while_tens = while_block.tensors body_block = basicBlocks[2] body_outputs = body_block.inputs body_tens = body_block.tensors - if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or - (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)): + if ( + while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape + ) or ( + while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evInputListBodyGraphOutputMismatch(check=False, **kwargs): error_name = ErrorIf.InputListBodyGraphOutputMismatch @@ -3091,25 +3335,27 @@ class TosaErrorValidator: error_reason = "Input list does not match body graph output" if check: - basicBlocks = kwargs['basicBlocks'] + basicBlocks = kwargs["basicBlocks"] while_block = basicBlocks[0] while_inputs = while_block.inputs while_tens = while_block.tensors body_block = basicBlocks[2] body_outputs = body_block.outputs body_tens = body_block.tensors - if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or - (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)): + if ( + while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape + ) or ( + while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape + ): error_result = True info_dict = { "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict - @staticmethod def evCondGraphOutputNotMatchingBool(check=False, **kwargs): error_name = ErrorIf.CondGraphOutputNotMatchingBool @@ -3118,7 +3364,7 @@ class TosaErrorValidator: error_reason = "Cond graph output is not a match list of booleans" if check: - basicBlocks = kwargs['basicBlocks'] + basicBlocks = kwargs["basicBlocks"] cond_block = basicBlocks[1] cond_outputs = cond_block.outputs cond_tens = cond_block.tensors @@ -3129,35 +3375,31 @@ class TosaErrorValidator: "error_name": error_name, "error_result": error_result, "error_reason": error_reason, - "param_reqs": param_reqs + "param_reqs": param_reqs, } return info_dict class TosaInvalidValidator: - @staticmethod def ivWrongDataTypeOrModeResize(**kwargs): input_dtype = kwargs["input_dtype"] args = kwargs["args"] mode = args[0] - stride = args[1] - stride_fp = args[4] output_dtype = args[8] if mode == ResizeMode.BILINEAR: # Invalid output data type / Invalid input datatype return ( - not (input_dtype == DType.INT8 and output_dtype == DType.INT32) or - not (input_dtype == DType.INT16 and output_dtype == DType.INT48) or - not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) or - (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]) + not (input_dtype == DType.INT8 and output_dtype == DType.INT32) + or not (input_dtype == DType.INT16 and output_dtype == DType.INT48) + or not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) + or (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype - return ( - (input_dtype != output_dtype) or - (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]) + return (input_dtype != output_dtype) or ( + input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT] ) else: # Invalid resize mode @@ -3184,20 +3426,24 @@ class TosaInvalidValidator: @staticmethod def ivHeightWidthInvalid(**kwargs): - opName = kwargs['opName'] + opName = kwargs["opName"] - inputShapes = kwargs['shapeList'] + inputShapes = kwargs["shapeList"] input_shape = inputShapes[0] - args = kwargs['args'] + args = kwargs["args"] strides = args[0] padding = args[1] if opName.endswith("pool2d"): # avg_pool2d, max_pool2d kernel_shape = args[2] - h = (input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]) // strides[0] - w = (input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]) // strides[1] + h = ( + input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0] + ) // strides[0] + w = ( + input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1] + ) // strides[1] # return True if any dimension is < 1 return h < 1 or w < 1 @@ -3226,17 +3472,31 @@ class TosaInvalidValidator: the output size """ dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - return (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad + return ( + (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad + ) for pad_h, pad_w in ( - (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding - (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding - (0, 0) # VALID padding + (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding + (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding + (0, 0), # VALID padding ): - h = get_out_size(input_shape[1], strides[0], kernel_shape[0], dilations[0], - padding[0], pad_h) - w = get_out_size(input_shape[2], strides[1], kernel_shape[1], dilations[1], - padding[1], pad_w) + h = get_out_size( + input_shape[1], + strides[0], + kernel_shape[0], + dilations[0], + padding[0], + pad_h, + ) + w = get_out_size( + input_shape[2], + strides[1], + kernel_shape[1], + dilations[1], + padding[1], + pad_w, + ) if output_shape[1] == h and output_shape[2] == w: return False @@ -3247,7 +3507,11 @@ class TosaInvalidValidator: # conv2d, conv3d, depthwise_conv2d dilations = args[2] filter_shape = inputShapes[1] - kernel_shape = filter_shape[0:2] if opName.startswith("depthwise_conv2d") else filter_shape[1:-1] + kernel_shape = ( + filter_shape[0:2] + if opName.startswith("depthwise_conv2d") + else filter_shape[1:-1] + ) for i in range(len(kernel_shape)): dim = ( @@ -3266,7 +3530,7 @@ class TosaInvalidValidator: @staticmethod def ivNonPositiveOutputShape(**kwargs): - args = kwargs['args'] + args = kwargs["args"] output_shape = args[3] if output_shape[1] <= 0 or output_shape[2] <= 0: # Negative output shape @@ -3310,13 +3574,12 @@ class TosaTestGen: fd.write(self.ser.writeJson("{}.tosa".format(testName))) def resetRNG(self, seed=None): - if seed == None: + if seed is None: seed = self.random_seed + 1 self.rng = np.random.default_rng(seed) def getRandTensor(self, shape, dtype): if dtype == DType.BOOL: - np_dt = np.bool return np.bool_(self.rng.choice(a=[False, True], size=shape)) # TOSA specific INT4 weight range from -7 to 7 elif dtype == DType.INT4: @@ -3469,8 +3732,8 @@ class TosaTestGen: if isinstance(op, int): self.ser.addOperator(op, a.name, result_tens.name, None, qinfo) return result_tens - elif op['op'] == Op.IDENTITY: - self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo) + elif op["op"] == Op.IDENTITY: + self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo) return result_tens # Ensure new output type has correct qinfo @@ -3478,7 +3741,8 @@ class TosaTestGen: if result_tens.dtype not in [DType.INT8, DType.UINT8]: qinfo = ts.TosaSerializerQuantInfo() qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype) + TosaQuantGen.getQinfo(self, a.dtype), + TosaQuantGen.getQinfo(self, result_tens.dtype), ) # Invalidate Input/Output list for error if checks. @@ -3486,7 +3750,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -3495,72 +3761,81 @@ class TosaTestGen: op=op, input_dtype=a.dtype, output_dtype=result_tens.dtype, - qinfo = qinfo, - result_tensor = result_tens, + qinfo=qinfo, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list, None, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) return result_tens def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None): - result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name) - + result_tens = OutputShaper.binaryBroadcastOp( + self.ser, self.rng, a, b, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input1 = a, - input2 = b, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input1=a, + input2=b, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list) + self.ser.addOperator(op["op"], input_list, output_list) return result_tens def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None): result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b) - self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name]) + self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name]) return result_tens - def build_arithmetic_right_shift(self, op, a, b, round, validator_fcns=None, error_name=None): - result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name) + def build_arithmetic_right_shift( + self, op, a, b, round, validator_fcns=None, error_name=None + ): + result_tens = OutputShaper.binaryBroadcastOp( + self.ser, self.rng, a, b, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input1 = a, - input2 = b, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input1=a, + input2=b, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -3570,11 +3845,13 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.ArithmeticRightShiftAttribute(round) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None): - result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name) + result_tens = OutputShaper.binaryBroadcastOp( + self.ser, self.rng, a, b, error_name + ) # Special for multiply: # Force the result to INT32 for INT types @@ -3590,18 +3867,20 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input1 = a, - input2 = b, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input1=a, + input2=b, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -3611,7 +3890,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.MulAttribute(shift) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_table(self, op, a, table, validator_fcns=None, error_name=None): @@ -3625,24 +3904,26 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = a.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=a.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -3654,58 +3935,72 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input1 = cond, - input2 = a, - input3 = b, - input_shape = a.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input1=cond, + input2=a, + input3=b, + input_shape=a.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list,) + self.ser.addOperator( + op["op"], + input_list, + output_list, + ) return result_tens def build_comparison(self, op, a, b, validator_fcns=None, error_name=None): - result_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name) + result_tens = OutputShaper.binaryComparisonOp( + self.ser, self.rng, a, b, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input1 = a, - input2 = b, - input_shape = a.shape, - input_dtype = a.dtype, - output_shape = result_tens.shape, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input1=a, + input2=b, + input_shape=a.shape, + input_dtype=a.dtype, + output_shape=result_tens.shape, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list,) + self.ser.addOperator( + op["op"], + input_list, + output_list, + ) return result_tens def build_argmax(self, op, a, axis, validator_fcns, error_name): @@ -3716,7 +4011,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -3724,11 +4021,11 @@ class TosaTestGen: error_name, op=op, axis=axis, - input_shape = a.shape, - input_dtype = a.dtype, - output_shape = result_tens.shape, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=a.shape, + input_dtype=a.dtype, + output_shape=result_tens.shape, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -3738,18 +4035,31 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens - def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None): - result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name) + def build_pool2d( + self, + op, + input, + stride, + pad, + kernel, + validator_fcns=None, + error_name=None, + qinfo=None, + ): + result_tens = OutputShaper.pool2dOp( + self.ser, self.rng, input, kernel, stride, pad, error_name + ) # Ensure new output type has correct qinfo if error_name == ErrorIf.WrongInputType: if input.dtype not in [DType.INT8, DType.UINT8]: qinfo = ts.TosaSerializerQuantInfo() qinfo.UnaryQuantInfo( - TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype) + TosaQuantGen.getQinfo(self, input.dtype), + TosaQuantGen.getQinfo(self, result_tens.dtype), ) # Invalidate Input/Output list for error if checks. @@ -3757,7 +4067,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -3771,8 +4083,8 @@ class TosaTestGen: kernel=kernel, stride=stride, pad=pad, - qinfo = qinfo, - result_tensor = result_tens, + qinfo=qinfo, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -3782,27 +4094,45 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.PoolAttribute(kernel, stride, pad) - self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) return result_tens - def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None): + def build_conv2d( + self, + op, + ifm, + filter, + bias, + strides, + padding, + dilations, + validator_fcns=None, + error_name=None, + qinfo=None, + ): assert len(padding) == 4 result_tens = OutputShaper.conv2dOp( self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name ) # Ensure new output type has correct qinfo - if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8): + if error_name == ErrorIf.WrongInputType and ifm.dtype not in ( + DType.INT8, + DType.UINT8, + ): qinfo = ts.TosaSerializerQuantInfo() qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype) + TosaQuantGen.getQinfo(self, ifm.dtype), + TosaQuantGen.getQinfo(self, result_tens.dtype), ) # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] output_list = [result_tens.name] num_operands = sum(op["operands"]) - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -3826,29 +4156,45 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.ConvAttribute(padding, strides, dilations) - self.ser.addOperator( - op['op'], input_list, output_list, attr, qinfo - ) + self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) return result_tens - def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None): + def build_conv3d( + self, + op, + ifm, + filter, + bias, + strides, + padding, + dilations, + validator_fcns=None, + error_name=None, + qinfo=None, + ): assert len(padding) == 6 result_tens = OutputShaper.conv3dOp( self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name ) # Ensure new output type has correct qinfo - if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8): + if error_name == ErrorIf.WrongInputType and ifm.dtype not in ( + DType.INT8, + DType.UINT8, + ): qinfo = ts.TosaSerializerQuantInfo() qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype) + TosaQuantGen.getQinfo(self, ifm.dtype), + TosaQuantGen.getQinfo(self, result_tens.dtype), ) # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] output_list = [result_tens.name] num_operands = sum(op["operands"]) - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -3872,29 +4218,46 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.ConvAttribute(padding, strides, dilations) - self.ser.addOperator( - op['op'], input_list, output_list, attr, qinfo - ) + self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) return result_tens def build_transpose_conv2d( - self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, validator_fcns=None, error_name=None, qinfo=None + self, + op, + ifm, + filter, + bias, + stride, + outpad, + dilation, + output_shape, + validator_fcns=None, + error_name=None, + qinfo=None, ): assert len(outpad) == 2 - result_tens = OutputShaper.transposeConv2DOp(self.ser, self.rng, ifm, output_shape, error_name) + result_tens = OutputShaper.transposeConv2DOp( + self.ser, self.rng, ifm, output_shape, error_name + ) # Ensure new output type has correct qinfo - if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8): + if error_name == ErrorIf.WrongInputType and ifm.dtype not in ( + DType.INT8, + DType.UINT8, + ): qinfo = ts.TosaSerializerQuantInfo() qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype) + TosaQuantGen.getQinfo(self, ifm.dtype), + TosaQuantGen.getQinfo(self, result_tens.dtype), ) # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] output_list = [result_tens.name] num_operands = sum(op["operands"]) - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -3918,30 +4281,44 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.TransposeConvAttribute(outpad, stride, dilation, output_shape) - self.ser.addOperator( - op['op'], input_list, output_list, attr, qinfo - ) + self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) return result_tens def build_depthwise_conv2d( - self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None - ): + self, + op, + ifm, + filter, + bias, + strides, + padding, + dilations, + validator_fcns=None, + error_name=None, + qinfo=None, + ): result_tens = OutputShaper.depthwiseConv2dOp( self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name ) # Ensure new output type has correct qinfo - if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8): + if error_name == ErrorIf.WrongInputType and ifm.dtype not in ( + DType.INT8, + DType.UINT8, + ): qinfo = ts.TosaSerializerQuantInfo() qinfo.ConvQuantInfo( - TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype) + TosaQuantGen.getQinfo(self, ifm.dtype), + TosaQuantGen.getQinfo(self, result_tens.dtype), ) # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] output_list = [result_tens.name] num_operands = sum(op["operands"]) - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -3965,20 +4342,24 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.ConvAttribute(padding, strides, dilations) - self.ser.addOperator( - op['op'], input_list, output_list, attr, qinfo - ) + self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) return result_tens - def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None): - result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name) + def build_fully_connected( + self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None + ): + result_tens = OutputShaper.fullyConnectedOp( + self.ser, self.rng, ifm, filter, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [ifm.name, filter.name, bias.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -3990,17 +4371,15 @@ class TosaTestGen: weight_dtype=filter.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - qinfo = qinfo, - result_tensor = result_tens, + qinfo=qinfo, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator( - op['op'], input_list, output_list, None, qinfo - ) + self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) return result_tens def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None): @@ -4011,7 +4390,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -4024,15 +4405,15 @@ class TosaTestGen: input2_dtype=b.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - qinfo = qinfo, - result_tensor = result_tens, + qinfo=qinfo, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list, None, qinfo) + self.ser.addOperator(op["op"], input_list, output_list, None, qinfo) return result_tens def build_reduce(self, op, a, axis, validator_fcns, error_name=None): @@ -4043,19 +4424,21 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - axis = axis, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + axis=axis, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -4065,7 +4448,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_clamp(self, op, a, validator_fcns=None, error_name=None): @@ -4088,7 +4471,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -4097,11 +4482,11 @@ class TosaTestGen: op=op, max_val=max_val, min_val=min_val, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -4114,7 +4499,7 @@ class TosaTestGen: else: attr.ClampAttribute(min_val, max_val, 0, 0) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None): @@ -4123,14 +4508,14 @@ class TosaTestGen: attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT)) - self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr) + self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr) return result_tens # Needs an additional type/input def build_prelu(self, op, a, validator_fcns=None, error_name=None): result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name) - self.ser.addOperator(op['op'], [a.name], [result_tens.name]) + self.ser.addOperator(op["op"], [a.name], [result_tens.name]) return result_tens def build_sigmoid(self, op, a, validator_fcns=None, error_name=None): @@ -4141,25 +4526,27 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list) + self.ser.addOperator(op["op"], input_list, output_list) return result_tens def build_tanh(self, op, a, validator_fcns=None, error_name=None): @@ -4170,25 +4557,27 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list) + self.ser.addOperator(op["op"], input_list, output_list) return result_tens def build_concat(self, op, *a, validator_fcns=None, error_name=None): @@ -4199,7 +4588,9 @@ class TosaTestGen: axis = a[-1] a = a[:-1] - result_tens = OutputShaper.concatOp(self.ser, self.rng, axis, *a, error_name=error_name) + result_tens = OutputShaper.concatOp( + self.ser, self.rng, axis, *a, error_name=error_name + ) input_tensor_names = [] for tensor in a: @@ -4210,7 +4601,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -4218,12 +4611,12 @@ class TosaTestGen: error_name, op=op, axis=axis, - input_shape = a[0].shape, - output_shape = result_tens.shape, - input_dtype = a[0].dtype, - output_dtype = result_tens.dtype, + input_shape=a[0].shape, + output_shape=result_tens.shape, + input_dtype=a[0].dtype, + output_dtype=result_tens.dtype, inputs=a, - result_tensor = result_tens, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -4233,11 +4626,20 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) - - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens - def build_pad(self, op, a, padding, pad_const_int, pad_const_float, validator_fcns=None, error_name=None, qinfo=None): + def build_pad( + self, + op, + a, + padding, + pad_const_int, + pad_const_float, + validator_fcns=None, + error_name=None, + qinfo=None, + ): result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name) attr = ts.TosaSerializerAttribute() @@ -4248,51 +4650,55 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, pad=padding, qinfo=qinfo, - result_tensor = result_tens, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator( - op['op'], input_list, output_list, attr, qinfo - ) + self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo) return result_tens def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None): - result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name) + result_tens = OutputShaper.reshapeOp( + self.ser, self.rng, a, newShape, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [a.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -4302,7 +4708,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.ReshapeAttribute(newShape) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None): @@ -4313,7 +4719,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -4321,11 +4729,11 @@ class TosaTestGen: error_name, op=op, axis=axis, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -4335,7 +4743,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.AxisAttribute(axis) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None): @@ -4349,51 +4757,56 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = a.shape, - output_shape = result_tens.shape, + input_shape=a.shape, + output_shape=result_tens.shape, perms=perms, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None): - result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name) + result_tens = OutputShaper.sliceOp( + self.ser, self.rng, a, start, size, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [a.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, start=start, size=size, - result_tensor = result_tens, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -4403,7 +4816,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.SliceAttribute(start, size) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None): @@ -4414,18 +4827,20 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = a.shape, - output_shape = result_tens.shape, - input_dtype = a.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=a.shape, + output_shape=result_tens.shape, + input_dtype=a.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -4435,7 +4850,7 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.TileAttribute(multiples) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_gather(self, op, values, validator_fcns=None, error_name=None): @@ -4452,32 +4867,36 @@ class TosaTestGen: ) # (N, W) indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr) - result_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indicies, error_name) + result_tens = OutputShaper.gatherOp( + self.ser, self.rng, values, indicies, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [values.name, indicies.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = values.shape, - output_shape = result_tens.shape, - input_dtype = values.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=values.shape, + output_shape=result_tens.shape, + input_dtype=values.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list) + self.ser.addOperator(op["op"], input_list, output_list) return result_tens @@ -4493,36 +4912,39 @@ class TosaTestGen: ) # (N, W) indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr) - result_tens = OutputShaper.scatterOp(self.ser, self.rng, values_in, indicies, input, error_name) + result_tens = OutputShaper.scatterOp( + self.ser, self.rng, values_in, indicies, input, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [values_in.name, indicies.name, input.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = values_in.shape, - output_shape = result_tens.shape, - input_dtype = values_in.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=values_in.shape, + output_shape=result_tens.shape, + input_dtype=values_in.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list) + self.ser.addOperator(op["op"], input_list, output_list) return result_tens - def build_resize( self, op, @@ -4537,7 +4959,7 @@ class TosaTestGen: input_dtype, output_dtype, validator_fcns, - error_name = None, + error_name=None, ): result_tens = OutputShaper.resizeOp( self.ser, @@ -4552,7 +4974,7 @@ class TosaTestGen: output_dims, input_dtype, output_dtype, - error_name + error_name, ) # Invalidate Input/Output list for error if checks. @@ -4560,7 +4982,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, @@ -4590,7 +5014,7 @@ class TosaTestGen: output_dims, stride, offset, shift, stride_fp, offset_fp, mode ) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None): @@ -4607,36 +5031,52 @@ class TosaTestGen: # Type Conversion def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None): - result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name) + result_tens = OutputShaper.typeConversionOp( + self.ser, self.rng, val, out_dtype, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [val.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) if not TosaErrorValidator.evValidateErrorIfs( self.ser, validator_fcns, error_name, op=op, - input_shape = val.shape, - output_shape = result_tens.shape, - input_dtype = val.dtype, - output_dtype = result_tens.dtype, - result_tensor = result_tens, + input_shape=val.shape, + output_shape=result_tens.shape, + input_dtype=val.dtype, + output_dtype=result_tens.dtype, + result_tensor=result_tens, input_list=input_list, output_list=output_list, num_operands=num_operands, ): return None - self.ser.addOperator(op['op'], input_list, output_list) + self.ser.addOperator(op["op"], input_list, output_list) return result_tens - def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name): - result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name) + def build_rescale( + self, + op, + val, + out_dtype, + scale32, + double_round, + per_channel, + validator_fcns, + error_name, + ): + result_tens = OutputShaper.typeConversionOp( + self.ser, self.rng, val, out_dtype, error_name + ) if per_channel: nc = val.shape[-1] @@ -4705,7 +5145,9 @@ class TosaTestGen: output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount - input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list) + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) qinfo = (input_zp, output_zp) if not TosaErrorValidator.evValidateErrorIfs( @@ -4717,8 +5159,8 @@ class TosaTestGen: output_dtype=out_dtype, input_shape=val.shape, qinfo=qinfo, - scale32 = scale32, - double_round = double_round, + scale32=scale32, + double_round=double_round, input_list=input_list, output_list=output_list, result_tensor=result_tens, @@ -4737,10 +5179,12 @@ class TosaTestGen: per_channel, ) - self.ser.addOperator(op['op'], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens - def build_cond_if_const(self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None): + def build_cond_if_const( + self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None + ): # For cond_if with constants, we're supplied with then/else tensors that we ignore # (except for the generated shap) and the condition. Build Then/Else blocks # and fill them with const nodes for the body. @@ -4752,10 +5196,17 @@ class TosaTestGen: out_shape = then_tens.shape # Create an incorrect output shape for error_if tests - if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]: + if error_name in [ + ErrorIf.CondIfOutputListThenGraphMismatch, + ErrorIf.CondIfOutputListElseGraphMismatch, + ]: incorrect_shape = deepcopy(then_tens.shape) for i in range(len(incorrect_shape)): - incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3]) if incorrect_shape[i] > 3 else self.rng.choice([1, 2, 4]) + incorrect_shape[i] += ( + self.rng.choice([-3, -2, 2, 3]) + if incorrect_shape[i] > 3 + else self.rng.choice([1, 2, 4]) + ) incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape)) then_arr = np.int32(self.rng.integers(0, 256, size=out_shape)) @@ -4771,7 +5222,7 @@ class TosaTestGen: attr.CondIfAttribute(then_block, else_block) # Finally, build the op and the two blocks - self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr) + self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr) self.ser.startBasicBlock(then_block) # Build the actual then/else tensors inside their blocks @@ -4793,13 +5244,15 @@ class TosaTestGen: validator_fcns, error_name, op=op, - basicBlocks=self.ser.basicBlocks + basicBlocks=self.ser.basicBlocks, ): return None return result_tens - def build_cond_if_binary(self, op, a, b, cond, validator_fcns=None, error_name=None): + def build_cond_if_binary( + self, op, a, b, cond, validator_fcns=None, error_name=None + ): # For cond_if with a binary op in the then/else blocks, take a and b and # alternately add or subtract them based on the condition @@ -4814,18 +5267,21 @@ class TosaTestGen: attr = ts.TosaSerializerAttribute() attr.CondIfAttribute(then_block, else_block) - if error_name in [ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch, - ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch]: + if error_name in [ + ErrorIf.CondIfInputListThenGraphMismatch, + ErrorIf.CondIfInputListElseGraphMismatch, + ErrorIf.CondIfOutputListElseGraphMismatch, + ErrorIf.CondIfOutputListThenGraphMismatch, + ]: incorrect_shape = a.shape.copy() for i in range(len(incorrect_shape)): incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3]) incorrect_block_input = deepcopy(a) incorrect_block_input.shape = incorrect_shape - # Finally, build the op and the two blocks self.ser.addOperator( - op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr + op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr ) if a.dtype in (DType.FLOAT, DType.INT32): @@ -4837,13 +5293,23 @@ class TosaTestGen: for block, op in ((then_block, then_op), (else_block, else_op)): self.ser.startBasicBlock(block) - if ((error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block) or - (error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block)): + if ( + error_name == ErrorIf.CondIfInputListThenGraphMismatch + and block == then_block + ) or ( + error_name == ErrorIf.CondIfInputListElseGraphMismatch + and block == else_block + ): self.ser.addInputTensor(incorrect_block_input) self.ser.addInputTensor(b) tens = self.ser.addOutput(a.shape, a.dtype) - elif ((error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block) or - (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block)): + elif ( + error_name == ErrorIf.CondIfOutputListThenGraphMismatch + and block == then_block + ) or ( + error_name == ErrorIf.CondIfOutputListElseGraphMismatch + and block == else_block + ): self.ser.addInputTensor(a) self.ser.addInputTensor(b) tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype) @@ -4860,7 +5326,7 @@ class TosaTestGen: op=op, a=a, b=b, - basicBlocks=self.ser.basicBlocks + basicBlocks=self.ser.basicBlocks, ): return None @@ -4893,14 +5359,18 @@ class TosaTestGen: # While_loop operator self.ser.addOperator( - op['op'], + op["op"], [iter.name, a.name, acc.name], [iter_out.name, a_out.name, acc_out.name], attr, ) self.ser.addOutputTensor(acc_out) - if error_name in [ErrorIf.InputListCondGraphMismatch, ErrorIf.InputListBodyGraphInputMismatch, ErrorIf.InputListBodyGraphOutputMismatch]: + if error_name in [ + ErrorIf.InputListCondGraphMismatch, + ErrorIf.InputListBodyGraphInputMismatch, + ErrorIf.InputListBodyGraphOutputMismatch, + ]: incorrect_iter = deepcopy(iter) for i in range(len(incorrect_iter.shape)): incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3]) @@ -4924,7 +5394,9 @@ class TosaTestGen: zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)]) if error_name == ErrorIf.CondGraphOutputNotMatchingBool: - cond_tens = self.ser.addOutput([], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])) + cond_tens = self.ser.addOutput( + [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]) + ) else: cond_tens = self.ser.addOutput([], DType.BOOL) @@ -4945,8 +5417,12 @@ class TosaTestGen: one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)]) if error_name == ErrorIf.InputListBodyGraphOutputMismatch: - iter_body_out = self.ser.addIntermediate(incorrect_iter.shape, incorrect_iter.dtype) - acc_body_out = self.ser.addIntermediate(incorrect_acc.shape, incorrect_acc.dtype) + iter_body_out = self.ser.addIntermediate( + incorrect_iter.shape, incorrect_iter.dtype + ) + acc_body_out = self.ser.addIntermediate( + incorrect_acc.shape, incorrect_acc.dtype + ) else: iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype) acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype) @@ -4962,13 +5438,15 @@ class TosaTestGen: validator_fcns, error_name, op=op, - basicBlocks=self.ser.basicBlocks + basicBlocks=self.ser.basicBlocks, ): return None return acc_out - def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None): + def create_filter_lists( + self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None + ): # Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small. default_test_rank_range = range(1, 5) if not shapeFilter: @@ -4986,7 +5464,11 @@ class TosaTestGen: # Ensure default behaviour is bounded by default range or by operator, # whichever is the smaller range of ranks. opRankRange = range(rmin, rmax + 1) - cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range + cleanRankFilter = ( + opRankRange + if len(opRankRange) <= len(default_test_rank_range) + else default_test_rank_range + ) else: cleanRankFilter = range(rmin, rmax + 1) @@ -4996,57 +5478,65 @@ class TosaTestGen: cleanDtypeFilter = [] # Create list of operator dtypes filtered by requested dtypes for dtype in dtypes: - if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter): + if dtype in dtypeFilter or ( + isinstance(dtype, list) and dtype[0] in dtypeFilter + ): cleanDtypeFilter.append(dtype) else: cleanDtypeFilter = dtypes - if testType == 'positive': + if testType == "positive": filterDict = { - 'shapeFilter': shapeFilter, - 'rankFilter': cleanRankFilter, - 'dtypeFilter': cleanDtypeFilter + "shapeFilter": shapeFilter, + "rankFilter": cleanRankFilter, + "dtypeFilter": cleanDtypeFilter, } return filterDict - elif testType == 'negative': + elif testType == "negative": if validator is not None: validator_info = validator(check=False, op=op) else: return None - error_arguments = validator_info['param_reqs'] + error_arguments = validator_info["param_reqs"] - #Set parameters as required - if error_arguments['rank'] != None: - rankFilter = error_arguments['rank'] + # Set parameters as required + if error_arguments["rank"] is not None: + rankFilter = error_arguments["rank"] else: rankFilter = cleanRankFilter - if error_arguments['dtype'] != None: - dtypeFilter = error_arguments['dtype'] + if error_arguments["dtype"] is not None: + dtypeFilter = error_arguments["dtype"] else: dtypeFilter = cleanDtypeFilter - if error_arguments['shape'] != None: - shapeFilter = error_arguments['shape'] + if error_arguments["shape"] is not None: + shapeFilter = error_arguments["shape"] else: - shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small + shapeFilter = shapeFilter[ + :2 + ] # Reduce number of shapes to keep test numbers small filterDict = { - 'shapeFilter': shapeFilter, - 'rankFilter': rankFilter, - 'dtypeFilter': dtypeFilter + "shapeFilter": shapeFilter, + "rankFilter": rankFilter, + "dtypeFilter": dtypeFilter, } return filterDict - def genOpTestList( - self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive' + self, + opName, + shapeFilter=[None], + rankFilter=None, + dtypeFilter=None, + testType="positive", ): try: op = self.TOSA_OP_LIST[opName] - except KeyError as e: + except KeyError: raise Exception("Cannot find op with name {}".format(opName)) # Initialize a new random number generator @@ -5057,24 +5547,26 @@ class TosaTestGen: # Test list consists of a tuple of: # (opName, testNameStr, dtype, shapeList, argumentsList) testList = [] - if testType == 'negative' and "error_if_validators" in op: + if testType == "negative" and "error_if_validators" in op: error_if_validators = op["error_if_validators"] else: error_if_validators = [None] for validator in error_if_validators: if validator is not None: - error_name = validator(check=False, op=op)['error_name'] + error_name = validator(check=False, op=op)["error_name"] else: error_name = None - filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator) - if filterDict == None: + filterDict = self.create_filter_lists( + op, shapeFilter, rankFilter, dtypeFilter, testType, validator + ) + if filterDict is None: return [] - cleanRankFilter = filterDict['rankFilter'] - cleanDtypeFilter = filterDict['dtypeFilter'] - cleanShapeFilter = filterDict['shapeFilter'] - #print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}") + cleanRankFilter = filterDict["rankFilter"] + cleanDtypeFilter = filterDict["dtypeFilter"] + cleanShapeFilter = filterDict["shapeFilter"] + # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}") for r in cleanRankFilter: for t in cleanDtypeFilter: @@ -5096,24 +5588,30 @@ class TosaTestGen: argList = [("", [])] for argStr, args in argList: - if testType == 'positive': + if testType == "positive": if argStr: testStr = "{}_{}_{}_{}".format( opName, shapeStr, typeStr, argStr ) else: - testStr = "{}_{}_{}".format(opName, shapeStr, typeStr) - elif testType == 'negative': + testStr = "{}_{}_{}".format( + opName, shapeStr, typeStr + ) + elif testType == "negative": if argStr: testStr = "{}_ERRORIF_{}_{}_{}_{}".format( opName, error_name, shapeStr, typeStr, argStr ) else: - testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr) + testStr = "{}_ERRORIF_{}_{}_{}".format( + opName, error_name, shapeStr, typeStr + ) - testList.append((opName, testStr, t, error_name, shapeList, args)) + testList.append( + (opName, testStr, t, error_name, shapeList, args) + ) - if testType == 'positive': + if testType == "positive": # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement if "invalid_test_validators" in op: invalid_test_validators = op["invalid_test_validators"] @@ -5121,7 +5619,12 @@ class TosaTestGen: for test in testList: for validator_fcn in invalid_test_validators: remove_test = False - if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]): + if validator_fcn( + opName=test[0], + input_dtype=test[2], + shapeList=test[4], + args=test[5], + ): remove_test = True if not remove_test: clean_testList.append(test) @@ -5129,11 +5632,12 @@ class TosaTestGen: return testList - - def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs): + def serializeTest( + self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs + ): try: op = self.TOSA_OP_LIST[opName] - except KeyError as e: + except KeyError: raise Exception("Cannot find op with name {}".format(opName)) # Create a serializer @@ -5190,9 +5694,24 @@ class TosaTestGen: resultName = build_fcn(self, op, *tens, *testArgs) else: if qinfo is not None: - resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name, qinfo=qinfo) + resultName = build_fcn( + self, + op, + *tens, + *testArgs, + validator_fcns=error_if_validators, + error_name=error_name, + qinfo=qinfo, + ) else: - resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name) + resultName = build_fcn( + self, + op, + *tens, + *testArgs, + validator_fcns=error_if_validators, + error_name=error_name, + ) except TypeError as e: print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n") raise e @@ -5204,19 +5723,22 @@ class TosaTestGen: # The test is not valid print(f"Invalid ERROR_IF test created: {opName} {testStr}") - def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None): pCount, cCount = op["operands"] tens = [] - if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None: + if ( + (op["op"] == Op.ADD or op["op"] == Op.SUB) + and dtypeList[0] == DType.INT32 + and error_name is None + ): # Make sure the operation does not cause value saturation - where # the number wraps due to limited number of bits to store the answer assert ( pCount == 2 and cCount == 0 ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts" placeholders = [] - add = (op["op"] == Op.ADD) + add = op["op"] == Op.ADD a_arr = self.getRandTensor(shapeList[0], dtypeList[0]) b_arr = self.getRandTensor(shapeList[1], dtypeList[1]) if add: @@ -5225,7 +5747,7 @@ class TosaTestGen: res_arr = np.subtract(a_arr, b_arr, dtype=np.int64) # Work out the saturation limits - max_i32 = (1 << 31)-1 + max_i32 = (1 << 31) - 1 min_i32 = -(1 << 31) max_arr = np.full(shapeList[1], max_i32) min_arr = np.full(shapeList[1], min_i32) @@ -5246,7 +5768,9 @@ class TosaTestGen: # Reduce axes in unsaturated tensor to match original tensor for axis, dim in enumerate(b_arr.shape): if dim != b_unsat_arr.shape[axis]: - assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable" + assert ( + dim == 1 + ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable" b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True) if (sat_min_arr != 0).any(): @@ -5255,7 +5779,9 @@ class TosaTestGen: # Reduce axes in unsaturated tensor to match original tensor for axis, dim in enumerate(b_arr.shape): if dim != b_unsat_arr.shape[axis]: - assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable" + assert ( + dim == 1 + ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable" b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True) placeholders.append( @@ -5266,15 +5792,19 @@ class TosaTestGen: ) tens.extend(placeholders) - elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32: + elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[ + 0 + ] == DType.INT32: # Limit input tensors with cond_if_binary or while_loop to stop # saturation of add/sub ops pRemain = pCount placeholders = [] - for idx, shape in enumerate(shapeList[:]): + for idx, shape in enumerate(shapeList[:]): arr = self.getRandTensor(shapeList[idx], DType.INT16) if pRemain > 0: - placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr)) + placeholders.append( + self.ser.addPlaceholder(shape, dtypeList[idx], arr) + ) pRemain -= 1 else: placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr)) @@ -5311,7 +5841,7 @@ class TosaTestGen: self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount]) ) tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:])) - elif op["op"] == Op.INTDIV and error_name == None: + elif op["op"] == Op.INTDIV and error_name is None: assert ( pCount == 2 and cCount == 0 ), "Op.INTDIV must have 2 placeholders, 0 consts" @@ -5341,7 +5871,7 @@ class TosaTestGen: ) tens.extend(placeholders) - elif op["op"] == Op.MUL and error_name == None: + elif op["op"] == Op.MUL and error_name is None: assert ( pCount == 2 and cCount == 0 ), "Op.MUL must have 2 placeholders, 0 consts" @@ -5414,7 +5944,9 @@ class TosaTestGen: # Ensure axis is an int testArgs[0] = int(testArgs[0]) - shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0], error_name) + shapeList = TosaTensorGen.tgConcatConstInput( + self, shapeList, testArgs[0], error_name + ) tens.extend( self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count]) @@ -5466,7 +5998,7 @@ class TosaTestGen: keyList = [] for k in self.TOSA_OP_LIST: try: - if self.TOSA_OP_LIST[k]["template"] == True: + if self.TOSA_OP_LIST[k]["template"]: keyList.append(k) continue except KeyError: @@ -5498,22 +6030,22 @@ class TosaTestGen: ) try: - types = self.TOSA_OP_LIST[op]["types"] - except KeyError as e: + _ = self.TOSA_OP_LIST[op]["types"] + except KeyError: raise Exception( "Op {} is missing a valid type list in TOSA_OP_LIST".format(op) ) try: - opcode = self.TOSA_OP_LIST[op]["op"] - except KeyError as e: + _ = self.TOSA_OP_LIST[op]["op"] + except KeyError: raise Exception( "Op {} is missing the Op field in TOSA_OP_LIST".format(op) ) # Put in default rank range, if missing try: - rank = self.TOSA_OP_LIST[op]["rank"] + _ = self.TOSA_OP_LIST[op]["rank"] except KeyError: self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE @@ -5553,9 +6085,17 @@ class TosaTestGen: "rank": (1, 4), "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_NARROW_INT_FP, - "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch, - TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, - TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evArgmaxOutputRankMismatch, + TosaErrorValidator.evArgmaxOutputShapeMismatch, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "avg_pool2d": { "op": Op.AVG_POOL2D, @@ -5565,10 +6105,20 @@ class TosaTestGen: "qgen": TosaQuantGen.qgUnary, "types": TYPE_NARROW_INT_FP, "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), - "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero, - TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, - TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch) + "error_if_validators": ( + TosaErrorValidator.evKernelSmallerOne, + TosaErrorValidator.evStrideSmallerOne, + TosaErrorValidator.evPadSmallerZero, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evInputZeroPointNotZero, + TosaErrorValidator.evOutputZeroPointNotZero, + TosaErrorValidator.evPadLargerEqualKernel, + TosaErrorValidator.evPoolingOutputShapeMismatch, + ), }, # Templated operator. Filled in by createDynamicOpLists "conv2d_TEMPLATE": { @@ -5651,8 +6201,15 @@ class TosaTestGen: "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, - "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evInputZeroPointNotZero, + TosaErrorValidator.evWeightZeroPointNotZero, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "matmul": { "op": Op.MATMUL, @@ -5661,8 +6218,14 @@ class TosaTestGen: "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None), "qgen": TosaQuantGen.qgMatmul, "types": TYPE_NARROW_INT_FP, - "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, - TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evInputZeroPointNotZero, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "max_pool2d": { "op": Op.MAX_POOL2D, @@ -5671,9 +6234,18 @@ class TosaTestGen: "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling), "types": TYPE_NARROW_INT_FP, "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), - "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero, - TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch) + "error_if_validators": ( + TosaErrorValidator.evKernelSmallerOne, + TosaErrorValidator.evStrideSmallerOne, + TosaErrorValidator.evPadSmallerZero, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evPadLargerEqualKernel, + TosaErrorValidator.evPoolingOutputShapeMismatch, + ), }, # Templated operator. Filled in by createDynamicOpLists "transpose_conv2d_TEMPLATE": { @@ -5711,24 +6283,37 @@ class TosaTestGen: "operands": (1, 0), "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None), "types": TYPE_NARROW_INT_FP, - "error_if_validators": (TosaErrorValidator.evMaxSmallerMin, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evMaxSmallerMin, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "sigmoid": { "op": Op.SIGMOID, "operands": (1, 0), "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "tanh": { "op": Op.TANH, "operands": (1, 0), "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, # Elementwise Binary Operators "add": { @@ -5736,8 +6321,14 @@ class TosaTestGen: "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "arithmetic_right_shift": { "op": Op.ARITHMETIC_RIGHT_SHIFT, @@ -5748,120 +6339,210 @@ class TosaTestGen: TosaArgGen.agArithmeticRightShift, ), "types": TYPE_INT, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "bitwise_and": { "op": Op.BITWISE_AND, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "bitwise_or": { "op": Op.BITWISE_OR, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "bitwise_xor": { "op": Op.BITWISE_XOR, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "intdiv": { "op": Op.INTDIV, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": [DType.INT32], - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "logical_and": { "op": Op.LOGICAL_AND, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_BOOL, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "logical_left_shift": { "op": Op.LOGICAL_LEFT_SHIFT, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) - }, - "logical_right_shift": { - "op": Op.LOGICAL_RIGHT_SHIFT, + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), + }, + "logical_right_shift": { + "op": Op.LOGICAL_RIGHT_SHIFT, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_INT, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "logical_or": { "op": Op.LOGICAL_OR, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_BOOL, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "logical_xor": { "op": Op.LOGICAL_XOR, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_BOOL, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "maximum": { "op": Op.MAXIMUM, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "minimum": { "op": Op.MINIMUM, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "mul": { "op": Op.MUL, "operands": (2, 0), "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul), "types": TYPE_INT_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evDimensionMismatch, + ), }, "pow": { "op": Op.POW, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "sub": { "op": Op.SUB, "operands": (2, 0), "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "table": { "op": Op.TABLE, @@ -5871,8 +6552,12 @@ class TosaTestGen: "operands": (1, 0), "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable), "types": [DType.INT8, DType.INT16], - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, # Elementwise Unary operators "abs": { @@ -5880,64 +6565,96 @@ class TosaTestGen: "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "bitwise_not": { "op": Op.BITWISE_NOT, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_INT, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "ceil": { "op": Op.CEIL, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "clz": { "op": Op.CLZ, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": [DType.INT32], - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "exp": { "op": Op.EXP, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "floor": { "op": Op.FLOOR, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "log": { "op": Op.LOG, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "logical_not": { "op": Op.LOGICAL_NOT, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_BOOL, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "negate": { "op": Op.NEGATE, @@ -5945,25 +6662,38 @@ class TosaTestGen: "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "qgen": TosaQuantGen.qgUnary, "types": TYPE_INT_FP, - "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, - TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evInputZeroPointNotZero, + TosaErrorValidator.evOutputZeroPointNotZero, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "reciprocal": { "op": Op.RECIPROCAL, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "rsqrt": { "op": Op.RSQRT, "operands": (1, 0), "build_fcn": (build_unary, TosaTensorGen.tgBasic, None), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, # Elementwise Ternary operators "select": { @@ -5971,8 +6701,14 @@ class TosaTestGen: "operands": (3, 0), "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, # Comparison operators "equal": { @@ -5980,24 +6716,42 @@ class TosaTestGen: "operands": (2, 0), "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "greater_equal": { "op": Op.GREATER_EQUAL, "operands": (2, 0), "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, "greater": { "op": Op.GREATER, "operands": (2, 0), "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch) + "error_if_validators": ( + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evDimensionMismatch, + ), }, # Reduction operators "reduce_all": { @@ -6006,9 +6760,16 @@ class TosaTestGen: "rank": (1, 4), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_BOOL, - "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "reduce_any": { "op": Op.REDUCE_ANY, @@ -6016,9 +6777,16 @@ class TosaTestGen: "rank": (1, 4), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_BOOL, - "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "reduce_max": { "op": Op.REDUCE_MAX, @@ -6026,9 +6794,16 @@ class TosaTestGen: "rank": (1, 4), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_INT_FP, - "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "reduce_min": { "op": Op.REDUCE_MAX, @@ -6036,9 +6811,16 @@ class TosaTestGen: "rank": (1, 4), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_INT_FP, - "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "reduce_product": { "op": Op.REDUCE_PRODUCT, @@ -6046,9 +6828,16 @@ class TosaTestGen: "rank": (1, 4), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_FP, - "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "reduce_sum": { "op": Op.REDUCE_SUM, @@ -6056,9 +6845,16 @@ class TosaTestGen: "rank": (1, 4), "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_FI32, - "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evShapeOfAxisNotOne, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, # Data layout operators "concat": { @@ -6066,9 +6862,16 @@ class TosaTestGen: "operands": (2, 0), "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis), "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch, - TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, - TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evConcatInputRankMismatch, + TosaErrorValidator.evConcatShapeSumMismatch, + TosaErrorValidator.evConcatInputDimMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongOutputList, + ), }, "pad": { "op": Op.PAD, @@ -6077,24 +6880,40 @@ class TosaTestGen: "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad), "qgen": TosaQuantGen.qgPad, "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero, - TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evPadSmallerZero, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "reshape": { "op": Op.RESHAPE, "operands": (1, 0), "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape), "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evTensorSizeInputOutputMismatch, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "reverse": { "op": Op.REVERSE, "operands": (1, 0), "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis), "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evWrongInputType, - TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "slice": { "op": Op.SLICE, @@ -6102,17 +6921,30 @@ class TosaTestGen: "rank": (1, 4), "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice), "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds, - TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evStartSmallerZero, + TosaErrorValidator.evSizeSmallerEqualZero, + TosaErrorValidator.evStartSizeOutsideBounds, + TosaErrorValidator.evSizeOutputShapeMismatch, + TosaErrorValidator.evInputSizeStartLengthMismatch, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "tile": { "op": Op.TILE, "operands": (1, 0), "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile), "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "transpose": { "op": Op.TRANSPOSE, @@ -6124,8 +6956,14 @@ class TosaTestGen: TosaArgGen.agTranspose, ), "types": TYPE_FIB, - "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, - TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evIndexOutsideBounds, + TosaErrorValidator.evIndexUsedTwice, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, # Data nodes "const": { @@ -6148,19 +6986,29 @@ class TosaTestGen: "rank": (3, 3), "build_fcn": (build_gather, TosaTensorGen.tgBasic, None), "types": TYPE_INT_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + ), }, "scatter": { "op": Op.SCATTER, # Only specify 'values_in' tensor here. - #'indices' and 'input' are generated in op building stage + # 'indices' and 'input' are generated in op building stage "operands": (2, 0), "rank": (3, 3), "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None), "types": TYPE_INT_FP, - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + ), }, # Image operations "resize": { @@ -6169,12 +7017,28 @@ class TosaTestGen: "rank": (4, 4), "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize), "types": [DType.INT8, DType.INT16, DType.FLOAT], - "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride), - "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension, - TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax, - TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType, - TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, - TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch) + "invalid_test_validators": ( + TosaInvalidValidator.ivWrongDataTypeOrModeResize, + TosaInvalidValidator.ivBadStride, + ), + "error_if_validators": ( + TosaErrorValidator.evMaxDimExceeded, + TosaErrorValidator.evStrideSmallerEqualZero, + TosaErrorValidator.evStrideLargerDimension, + TosaErrorValidator.evStrideLargerEqualMax, + TosaErrorValidator.evOffsetSmallerEqualMin, + TosaErrorValidator.evOffsetLargerEqualMax, + TosaErrorValidator.evShiftNotZero, + TosaErrorValidator.evShiftSmallerOne, + TosaErrorValidator.evShiftLargerEleven, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evBatchMismatch, + TosaErrorValidator.evChannelMismatch, + ), }, # Type conversion "cast": { @@ -6182,18 +7046,30 @@ class TosaTestGen: "operands": (1, 0), "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast), "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL], - "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, "rescale": { "op": Op.RESCALE, "operands": (1, 0), - "rank": (1,4), + "rank": (1, 4), "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale), "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48], - "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue, - TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, - TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList) + "error_if_validators": ( + TosaErrorValidator.evInputZeroPointNotZero, + TosaErrorValidator.evOutputZeroPointNotZero, + TosaErrorValidator.evScaleTrue, + TosaErrorValidator.evScaleNotTrue, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), }, # Custom # Not implemented. @@ -6210,7 +7086,10 @@ class TosaTestGen: TosaArgGen.agCondIf, ), "types": [DType.BOOL], - "error_if_validators": (TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch) + "error_if_validators": ( + TosaErrorValidator.evOutputListThenGraphMismatch, + TosaErrorValidator.evOutputListElseGraphMismatch, + ), }, "cond_if_binary": { "op": Op.COND_IF, @@ -6221,8 +7100,12 @@ class TosaTestGen: TosaArgGen.agCondIf, ), "types": TYPE_INT_FP, - "error_if_validators": (TosaErrorValidator.evInputListThenGraphMismatch, TosaErrorValidator.evInputListElseGraphMismatch, - TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch) + "error_if_validators": ( + TosaErrorValidator.evInputListThenGraphMismatch, + TosaErrorValidator.evInputListElseGraphMismatch, + TosaErrorValidator.evOutputListThenGraphMismatch, + TosaErrorValidator.evOutputListElseGraphMismatch, + ), }, # while_loop "while_loop": { @@ -6234,9 +7117,13 @@ class TosaTestGen: TosaArgGen.agWhileLoop, ), "types": [DType.INT32], - "error_if_validators": (TosaErrorValidator.evInputListOutputListMismatch, TosaErrorValidator.evInputListCondGraphMismatch, - TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch, - TosaErrorValidator.evCondGraphOutputNotMatchingBool) + "error_if_validators": ( + TosaErrorValidator.evInputListOutputListMismatch, + TosaErrorValidator.evInputListCondGraphMismatch, + TosaErrorValidator.evInputListBodyGraphInputMismatch, + TosaErrorValidator.evInputListBodyGraphOutputMismatch, + TosaErrorValidator.evCondGraphOutputNotMatchingBool, + ), }, } @@ -6257,13 +7144,19 @@ class OutputShaper: shape = [] for i in range(len(a.shape)): - if a.shape[i] == 1 and error_name == None: + if a.shape[i] == 1 and error_name is None: shape.append(b.shape[i]) else: shape.append(a.shape[i]) if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6286,7 +7179,13 @@ class OutputShaper: @staticmethod def unaryOp(ser, rng, a, error_name=None): if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6302,13 +7201,19 @@ class OutputShaper: shape = [] for i in range(len(cond.shape)): - if cond.shape[i] == 1 and error_name == None: + if cond.shape[i] == 1 and error_name is None: shape.append(max(cond.shape[i], a.shape[i], b.shape[i])) else: shape.append(cond.shape[i]) if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6317,7 +7222,7 @@ class OutputShaper: return ser.addOutput(shape, outputDType) @staticmethod - def binaryComparisonOp(ser, rng, a, b , error_name=None): + def binaryComparisonOp(ser, rng, a, b, error_name=None): if error_name != ErrorIf.RankMismatch: assert len(a.shape) == len(b.shape) assert a.dtype == b.dtype @@ -6331,7 +7236,13 @@ class OutputShaper: shape.append(a.shape[i]) if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + wrong_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] outputDType = rng.choice(wrong_dtypes) else: outputDType = DType.BOOL @@ -6341,13 +7252,23 @@ class OutputShaper: @staticmethod def reduceOp(ser, rng, a, axis, error_name=None): shape = a.shape.copy() - if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]: + if error_name not in [ + ErrorIf.AxisSmallerZero, + ErrorIf.AxisLargerRank, + ErrorIf.ShapeOfAxisNotOne, + ]: shape[axis] = 1 if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1: shape[axis] = rng.integers(2, 10) if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6373,7 +7294,13 @@ class OutputShaper: shape[i] = shape[i] + rng.integers(1, 10) if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([DType.INT32])) outputDType = rng.choice(wrong_dtypes) else: @@ -6490,7 +7417,9 @@ class OutputShaper: return ser.addOutput(ofm_shape, out_dtype) @staticmethod - def depthwiseConv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None): + def depthwiseConv2dOp( + ser, rng, ifm, filter, strides, padding, dilations, error_name=None + ): # IFM: NHWC # Filter: HWCM # OFM: NHW C*M @@ -6553,7 +7482,13 @@ class OutputShaper: ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]] if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6571,11 +7506,29 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if input.dtype == DType.INT8: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT48, + DType.FLOAT, + ) elif input.dtype == DType.INT16: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.FLOAT, + ) elif input.dtype == DType.FLOAT: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ) out_dtype = rng.choice(a=incorrect_types) elif input.dtype == DType.INT8: out_dtype = DType.INT32 @@ -6601,11 +7554,29 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if a.dtype == DType.INT8: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT48, + DType.FLOAT, + ) elif a.dtype == DType.INT16: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.FLOAT, + ) elif a.dtype == DType.FLOAT: - incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48) + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ) out_dtype = rng.choice(a=incorrect_types) elif a.dtype == DType.INT8: out_dtype = DType.INT32 @@ -6641,7 +7612,13 @@ class OutputShaper: output_shape[axis] += rng.integers(5, 10) if error_name == ErrorIf.WrongOutputType: - all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT} + all_dtypes = { + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + } wrong_dtypes = list(all_dtypes - set([input1.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6662,7 +7639,13 @@ class OutputShaper: output_shape = [i if i >= 1 else 1 for i in output_shape] if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6694,7 +7677,13 @@ class OutputShaper: output_shape[i] = output_shape[i] + rng.integers(1, 10) if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6706,7 +7695,13 @@ class OutputShaper: def sliceOp(ser, rng, a, start, size, error_name=None): if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6718,7 +7713,9 @@ class OutputShaper: if output_shape[index] <= 2: output_shape[index] = output_shape[index] + rng.choice([1, 2]) else: - output_shape[index] = output_shape[index] + rng.choice([-2, -1, 1, 2]) + output_shape[index] = output_shape[index] + rng.choice( + [-2, -1, 1, 2] + ) else: output_shape = size.copy() @@ -6734,7 +7731,13 @@ class OutputShaper: output_shape[i] = a.shape[i] * multiples[i] if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6756,7 +7759,13 @@ class OutputShaper: output_shape[i] = a.shape[perms[i]] if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6774,7 +7783,13 @@ class OutputShaper: output_shape = [values.shape[0], indices.shape[1], values.shape[2]] if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([values.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6795,7 +7810,13 @@ class OutputShaper: output_shape = values_in.shape if error_name == ErrorIf.WrongOutputType: - all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype])) outputDType = rng.choice(wrong_dtypes) else: @@ -6810,7 +7831,13 @@ class OutputShaper: assert input.dtype == DType.INT16 or input.dtype == DType.INT8 output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8 if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT] + wrong_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, + ] wrong_dtypes.remove(output_dtype) output_dtype = rng.choice(wrong_dtypes) return ser.addOutput(input.shape, output_dtype) @@ -6829,17 +7856,37 @@ class OutputShaper: output_dims, input_dtype, output_dtype, - error_name = None + error_name=None, ): if error_name == ErrorIf.WrongRank: - output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]] + output_dims = [ + input.shape[0], + output_dims[0], + output_dims[0], + input.shape[0], + ] else: if error_name == ErrorIf.BatchMismatch: - output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]] + output_dims = [ + input.shape[0] + rng.integers(1, 10), + output_dims[0], + output_dims[1], + input.shape[3], + ] elif error_name == ErrorIf.ChannelMismatch: - output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)] + output_dims = [ + input.shape[0], + output_dims[0], + output_dims[1], + input.shape[3] + rng.integers(1, 10), + ] else: - output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]] + output_dims = [ + input.shape[0], + output_dims[0], + output_dims[1], + input.shape[3], + ] return serializer.addOutput(output_dims, output_dtype) diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py index 09ee238..50f4033 100644 --- a/verif/generator/tosa_verif_build_tests.py +++ b/verif/generator/tosa_verif_build_tests.py @@ -1,38 +1,12 @@ -# Copyright (c) 2020-2021, ARM Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - +# Copyright (c) 2020-2022, ARM Limited. +# SPDX-License-Identifier: Apache-2.0 import argparse -import sys import re -import os -import subprocess -import shlex -import json -import glob -import math -import queue -import threading -import traceback - - -from enum import IntEnum, Enum, unique -from datetime import datetime from generator.tosa_test_gen import TosaTestGen from serializer.tosa_serializer import dtype_str_to_val + # Used for parsing a comma-separated list of integers in a string # to an actual list of integers def str_to_list(in_s): @@ -189,7 +163,7 @@ def parseArgs(): parser.add_argument( "--test-type", dest="test_type", - choices=['positive', 'negative', 'both'], + choices=["positive", "negative", "both"], default="positive", type=str, help="type of tests produced, postive, negative, or both", @@ -205,8 +179,8 @@ def main(): ttg = TosaTestGen(args) - if args.test_type == 'both': - testType = ['positive', 'negative'] + if args.test_type == "both": + testType = ["positive", "negative"] else: testType = [args.test_type] results = [] @@ -220,7 +194,7 @@ def main(): shapeFilter=args.target_shapes, rankFilter=args.target_ranks, dtypeFilter=args.target_dtypes, - testType=test_type + testType=test_type, ) ) @@ -236,11 +210,12 @@ def main(): if args.verbose: print(testStr) - results.append(ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)) + results.append( + ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs) + ) print(f"Done creating {len(results)} tests") - if __name__ == "__main__": exit(main()) diff --git a/verif/tests/test_json2numpy.py b/verif/tests/test_json2numpy.py index aec555c..63bc2d9 100644 --- a/verif/tests/test_json2numpy.py +++ b/verif/tests/test_json2numpy.py @@ -6,7 +6,6 @@ import os import numpy as np import pytest - from json2numpy.json2numpy import main diff --git a/verif/tests/test_tosa_result_checker.py b/verif/tests/test_tosa_result_checker.py index bc8a2fc..efee23b 100644 --- a/verif/tests/test_tosa_result_checker.py +++ b/verif/tests/test_tosa_result_checker.py @@ -3,11 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path +import checker.tosa_result_checker as trc import numpy as np import pytest -import checker.tosa_result_checker as trc - def _create_data_file(name, npy_data): """Create numpy data file.""" diff --git a/verif/tests/test_tosa_run_tests_mocksut.py b/verif/tests/test_tosa_run_tests_mocksut.py index 98044e0..234f156 100644 --- a/verif/tests/test_tosa_run_tests_mocksut.py +++ b/verif/tests/test_tosa_run_tests_mocksut.py @@ -7,7 +7,6 @@ from pathlib import Path from xml.dom import minidom import pytest - from runner.tosa_verif_run_tests import main -- cgit v1.2.1