aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-01-13 15:04:21 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2022-01-13 15:04:25 +0000
commit5c1364c8a547127bc90e4d4a78dd876070eb1026 (patch)
tree70187ef07f549687918f8a93725b74f1f7fc2715
parentbe1a9408eb53871d96a022f59664f016926a8cf4 (diff)
downloadreference_model-5c1364c8a547127bc90e4d4a78dd876070eb1026.tar.gz
Add python pre-commit script checkers
Fix up issues in existing python scripts. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Id4adab404560c3129c66f31c21ff0ce148283c73
-rw-r--r--.pre-commit-config.yaml20
-rw-r--r--scripts/json2fbbin/json2fbbin.py8
-rw-r--r--setup.cfg6
-rw-r--r--verif/generator/tosa_error_if.py17
-rw-r--r--verif/generator/tosa_test_gen.py3055
-rw-r--r--verif/generator/tosa_verif_build_tests.py45
-rw-r--r--verif/tests/test_json2numpy.py1
-rw-r--r--verif/tests/test_tosa_result_checker.py3
-rw-r--r--verif/tests/test_tosa_run_tests_mocksut.py1
9 files changed, 2097 insertions, 1059 deletions
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..4a73a48
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,20 @@
+# Copyright (c) 2021-2022 Arm Limited.
+# SPDX-License-Identifier: Apache-2.0
+
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+- repo: https://github.com/asottile/reorder_python_imports
+ rev: v2.2.0
+ hooks:
+ - id: reorder-python-imports
+
+- repo: https://github.com/psf/black
+ rev: 20.8b1
+ hooks:
+ - id: black
+
+- repo: https://gitlab.com/pycqa/flake8
+ rev: 3.7.9
+ hooks:
+ - id: flake8
diff --git a/scripts/json2fbbin/json2fbbin.py b/scripts/json2fbbin/json2fbbin.py
index 957acb1..8f9f274 100644
--- a/scripts/json2fbbin/json2fbbin.py
+++ b/scripts/json2fbbin/json2fbbin.py
@@ -4,7 +4,8 @@
from pathlib import Path
from typing import Optional
-from runner.run_command import run_sh_command, RunShCommandError
+from runner.run_command import run_sh_command
+from runner.run_command import RunShCommandError
def fbbin_to_json(flatc: Path, fbs: Path, t_path: Path, o_path: Optional[Path] = None):
@@ -63,7 +64,10 @@ def main(argv=None):
parser.add_argument(
"--flatc",
type=Path,
- default="reference_model/build/thirdparty/serialization_lib/third_party/flatbuffers/flatc",
+ default=(
+ "reference_model/build/thirdparty/serialization_lib/"
+ "third_party/flatbuffers/flatc"
+ ),
help="the path to the flatc compiler program",
)
parser.add_argument(
diff --git a/setup.cfg b/setup.cfg
index f9e5331..4e3dc10 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -50,3 +50,9 @@ console_scripts =
[tool:pytest]
testpaths=verif/tests
+
+[flake8]
+ignore = D213, E203, E266, E501, W503
+max-line-length = 88
+select = B,E,F,W,T4
+exclude = .eggs
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 7c162be..7070205 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,16 +1,6 @@
-# Copyright (c) 2021, ARM Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# Copyright (c) 2021-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+
class ErrorIf(object):
MaxDimExceeded = "MaxDimExceeded"
@@ -68,4 +58,3 @@ class ErrorIf(object):
InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
-
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 0d5a881..239a64e 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,48 +1,21 @@
-#!/usr/bin/env python3
-
# Copyright (c) 2020-2022, ARM Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import numpy as np
-import argparse
-import sys
-import re
-import os
-import subprocess
-import shlex
-import json
-import glob
-import math
-import queue
-import threading
-import traceback
-import math
+# SPDX-License-Identifier: Apache-2.0
import itertools
+import math
+import os
from copy import deepcopy
-from enum import IntEnum, Enum, unique
-
+import numpy as np
import serializer.tosa_serializer as ts
-from serializer.tosa_serializer import *
-import tosa
from generator.tosa_error_if import ErrorIf
-
-# Convenience variables to the flatc-generated types that should be enums, but aren't
+from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode
+# DTypeNames, DType, Op and ResizeMode are convenience variables to the
+# flatc-generated types that should be enums, but aren't
+
def valueToName(item, value):
"""Get the name of an attribute with the given value.
@@ -70,7 +43,8 @@ def valueToName(item, value):
for attr in dir(item):
if getattr(item, attr) == value:
return attr
- raise ValueError(f'value ({value}) not found')
+ raise ValueError(f"value ({value}) not found")
+
def allDTypes(*, excludes=None):
"""Get a set of all DType values, optionally excluding some values.
@@ -87,9 +61,14 @@ def allDTypes(*, excludes=None):
A set of DType values
"""
excludes = () if not excludes else excludes
- return {getattr(DType, t) for t in dir(DType)
- if not callable(getattr(DType, t)) and not t.startswith('__')
- and getattr(DType, t) not in excludes}
+ return {
+ getattr(DType, t)
+ for t in dir(DType)
+ if not callable(getattr(DType, t))
+ and not t.startswith("__")
+ and getattr(DType, t) not in excludes
+ }
+
def usableDTypes(*, excludes=None):
"""Get a set of usable DType values, optionally excluding some values.
@@ -108,6 +87,7 @@ def usableDTypes(*, excludes=None):
omit.update(excludes if excludes else ())
return allDTypes(excludes=omit)
+
def product(shape):
value = 1
for n in shape:
@@ -116,7 +96,10 @@ def product(shape):
class TosaQuantGen:
- """QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator defintion"""
+ """QuantizedInfo random generator helper functions.
+
+ Specify with 'qgen': in the operator defintion.
+ """
def __init__(self):
pass
@@ -128,7 +111,11 @@ class TosaQuantGen:
return testGen.randInt(-128, 128)
elif dtype == DType.UINT8:
return testGen.randInt(0, 256)
- elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.WeightZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
+ elif error_name in [
+ ErrorIf.InputZeroPointNotZero,
+ ErrorIf.WeightZeroPointNotZero,
+ ErrorIf.OutputZeroPointNotZero,
+ ]:
zero_point = testGen.randInt(-128, 128)
if zero_point == 0:
zero_point = 1
@@ -140,15 +127,18 @@ class TosaQuantGen:
qinfo = ts.TosaSerializerQuantInfo()
if error_name == ErrorIf.InputZeroPointNotZero:
qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype)
+ TosaQuantGen.getQinfo(testGen, dtype, error_name),
+ TosaQuantGen.getQinfo(testGen, dtype),
)
elif error_name == ErrorIf.OutputZeroPointNotZero:
qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype, error_name)
+ TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype, error_name),
)
else:
qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+ TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype),
)
return qinfo
@@ -180,11 +170,13 @@ class TosaQuantGen:
qinfo = ts.TosaSerializerQuantInfo()
if error_name == ErrorIf.InputZeroPointNotZero:
qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype, error_name)
- )
+ TosaQuantGen.getQinfo(testGen, dtype, error_name),
+ TosaQuantGen.getQinfo(testGen, dtype, error_name),
+ )
else:
qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+ TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype),
)
return qinfo
@@ -221,7 +213,8 @@ class TosaQuantGen:
shift = shift + 1
shift = (-shift) + scaleBits
- #print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))
+ # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
+ # scaleFp, scaleBits, m, multiplier, shift))
# Adjust multiplier such that shift is in allowed value range.
if shift == 0:
@@ -242,7 +235,10 @@ class TosaQuantGen:
class TosaTensorGen:
"""Tensor generators create a shape list for the placeholder and const tensor
- data operands for the operator. The actual random data is generated separately for each test."""
+ data operands for the operator.
+
+ The actual random data is generated separately for each test.
+ """
def __init__(self):
pass
@@ -331,7 +327,7 @@ class TosaTensorGen:
# Choose one of the inputs to broadcast
# Note: Simplifies OutputShaper code if we don't change first shape for errors
- bcast_idx = testGen.randInt(0 if error_name == None else 1, pl + const)
+ bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
for i in range(pl + const):
shape_bcast = shape.copy()
@@ -343,7 +339,9 @@ class TosaTensorGen:
elif error_name == ErrorIf.RankMismatch:
# Add one rank to the shape (or more for rank of 1)
extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
- shape_bcast = np.concatenate((shape_bcast, testGen.makeShape(extra_ranks)))
+ shape_bcast = np.concatenate(
+ (shape_bcast, testGen.makeShape(extra_ranks))
+ )
if rank != 1:
# Either keep the extra rank, or remove it
new_len = testGen.rng.choice([-2, len(shape_bcast)])
@@ -371,7 +369,9 @@ class TosaTensorGen:
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
- ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
+ ifm_shape, max_dim=24, max_items=10000
+ )
# Get the filter height/width from the operator parameters
filter_hw = op["filter"]
@@ -403,7 +403,9 @@ class TosaTensorGen:
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
- ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
+ ifm_shape, max_dim=24, max_items=10000
+ )
# Get the filter depth/height/width from the operator parameters
filter_dhw = op["filter"]
@@ -437,7 +439,9 @@ class TosaTensorGen:
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
- ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
+ ifm_shape, max_dim=24, max_items=10000
+ )
# Get the filter height/width from the operator parameters
filter_hw = op["filter"]
@@ -470,7 +474,9 @@ class TosaTensorGen:
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
- ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
+ ifm_shape, max_dim=24, max_items=10000
+ )
# Get the filter height/width from the operator parameters
# Filter is KH, HW, C, M
@@ -571,7 +577,11 @@ class TosaTensorGen:
@staticmethod
def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
- if error_name in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ConcatInputRankMismatch]:
+ if error_name in [
+ ErrorIf.AxisSmallerZero,
+ ErrorIf.AxisLargerRank,
+ ErrorIf.ConcatInputRankMismatch,
+ ]:
return shapeList
# Split concat shape along axis to allow for multiple const inputs
@@ -613,10 +623,13 @@ class TosaTensorGen:
class TosaArgGen:
- """Argument generators create exhaustive or random lists of attributes for operators that take
- attributes or other parameters. The return value is a list of (descriptive_name, [arglist])
- tuples where the descriptive_name is appended to the test name and the arglist is expanded
- as arguments to the operator build function."""
+ """Argument generators create exhaustive or random lists of attributes for
+ operators that take attributes or other parameters.
+
+ The return value is a list of (descriptive_name, [arglist]) tuples where
+ the descriptive_name is appended to the test name and the arglist is expanded
+ as arguments to the operator build function.
+ """
def __init__(self):
pass
@@ -651,7 +664,7 @@ class TosaArgGen:
ifm_shape = shapeList[0]
filter_shape = shapeList[1]
- # determine the kernel shape from the operator name (e.g. "conv2d_3x3" => [3,3])
+ # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
k = [int(x) for x in opName.split("_")[-1].split("x")]
# Check the rank
@@ -687,11 +700,15 @@ class TosaArgGen:
# add some oversize argument values
if max(ifm_shape) < 64:
bigPadding = 9
- paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
+ paddings.update(
+ {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
+ )
bigStride = 8
strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
bigDilation = 7
- dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
+ dilations.update(
+ {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
+ )
# There are too many parameter combinations, so generate them sparsely,
# very sparse for negative tests
@@ -700,7 +717,8 @@ class TosaArgGen:
# If there are only a small number of tests, just select them all
if sparsity < 13:
sparsity = 1
- # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
+ # To get a variety of parameter combinations sparsity should not be a
+ # multiple of 2, 3 or 5
while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
sparsity += 1
@@ -708,15 +726,19 @@ class TosaArgGen:
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
for d in sorted(list(dilations)):
- if (n % sparsity == 0
+ if (
+ n % sparsity == 0
# padding must not exceed the kernel size ?
- # and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
+ # and p[0] < k[0] and p[1] < k[0]
+ # and p[2] < k[1] and p[3] < k[1]
# and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
# the padded shape must exceed the kernel size
- and (ifm_shape[1] + p[0] + p[1]) > k[0] and (ifm_shape[2] + p[2] + p[3]) > k[1]
+ and (ifm_shape[1] + p[0] + p[1]) > k[0]
+ and (ifm_shape[2] + p[2] + p[3]) > k[1]
and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
# the padded shape must exceed the dilation
- and (ifm_shape[1] + p[0] + p[1]) > d[0] and (ifm_shape[2] + p[2] + p[3]) > d[1]
+ and (ifm_shape[1] + p[0] + p[1]) > d[0]
+ and (ifm_shape[2] + p[2] + p[3]) > d[1]
and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
):
arg_list.append(
@@ -768,7 +790,9 @@ class TosaArgGen:
# add some oversize argument values
if max(ifm_shape) < 64:
bigPadding = 9
- paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
+ paddings.update(
+ {x for x in itertools.product(*([[0, bigPadding]] * 2))}
+ )
bigStride = 8
strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
bigDilation = 7
@@ -781,7 +805,8 @@ class TosaArgGen:
# If there are only a small number of tests, just select them all
if sparsity < 13:
sparsity = 1
- # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
+ # To get a variety of parameter combinations sparsity should not be a
+ # multiple of 2, 3 or 5
while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
sparsity += 1
@@ -887,8 +912,15 @@ class TosaArgGen:
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
for k in sorted(list(kernels)):
- if error_name in [ErrorIf.StrideSmallerOne, ErrorIf.KernelSmallerOne, ErrorIf.PadSmallerZero, ErrorIf.PadLargerEqualKernel]:
- sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(testGen, error_name, s, p, k)
+ if error_name in [
+ ErrorIf.StrideSmallerOne,
+ ErrorIf.KernelSmallerOne,
+ ErrorIf.PadSmallerZero,
+ ErrorIf.PadLargerEqualKernel,
+ ]:
+ sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
+ testGen, error_name, s, p, k
+ )
if None not in [sNew, pNew, kNew] and n % sparsity == 0:
arg_list.append(
(
@@ -900,11 +932,16 @@ class TosaArgGen:
[sNew, pNew, kNew],
)
)
- elif (n % sparsity == 0
+ elif (
+ n % sparsity == 0
# padding must not exceed the kernel size
- and p[0] < k[0] and p[1] < k[0] and p[2] < k[1] and p[3] < k[1]
+ and p[0] < k[0]
+ and p[1] < k[0]
+ and p[2] < k[1]
+ and p[3] < k[1]
# the padded shape must exceed the kernel size
- and (shape[1] + p[0] + p[1]) > k[0] and (shape[2] + p[2] + p[3]) > k[1]
+ and (shape[1] + p[0] + p[1]) > k[0]
+ and (shape[2] + p[2] + p[3]) > k[1]
):
arg_list.append(
(
@@ -954,31 +991,53 @@ class TosaArgGen:
# Enumerate the output types here
for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
- if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
+ if (
+ dtype in [DType.UINT8, DType.INT8]
+ and error_name == ErrorIf.OutputZeroPointNotZero
+ ):
continue
- if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
+ if (
+ inDtype == DType.UINT8
+ and dtype != DType.INT8
+ and error_name != ErrorIf.WrongOutputType
+ ):
# The only output dtype for UINT8 is INT8, skip all other combinations
continue
- if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
+ if (
+ inDtype != DType.INT8
+ and dtype == DType.UINT8
+ and error_name != ErrorIf.WrongOutputType
+ ):
# The only input dtype for UINT8 is INT8, skip all other combinations
continue
- if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
+ if (
+ error_name == ErrorIf.WrongOutputType
+ and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype)
+ ):
continue
for scale32 in [False, True]:
- if error_name == ErrorIf.ScaleTrue and scale32 == False:
+ if error_name == ErrorIf.ScaleTrue and not scale32:
continue
- elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
+ elif error_name == ErrorIf.ScaleNotTrue and scale32:
continue
for double_round in [False, True]:
- if error_name == ErrorIf.ScaleNotTrue and double_round == False:
+ if error_name == ErrorIf.ScaleNotTrue and not double_round:
continue
for per_channel in [False, True]:
- if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
+ if (
+ inDtype == DType.INT48
+ and scale32
+ and error_name != ErrorIf.ScaleTrue
+ ):
# Illegal condition. Must be scale32=False
continue
- if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
+ if (
+ double_round
+ and not scale32
+ and error_name != ErrorIf.ScaleNotTrue
+ ):
# Illegal condition. ERROR_IF(!scale32 && double_round)
continue
@@ -1093,12 +1152,13 @@ class TosaArgGen:
ifm_shape = shapeList[0]
-
if error_name == ErrorIf.IndexOutsideBounds:
- incorrect_large_index = range(len(ifm_shape)+1, 2*len(ifm_shape)+1)
+ incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
incorrect_small_index = range(-len(ifm_shape), 0)
permutations = [p for p in itertools.permutations(incorrect_large_index)]
- permutations.extend([p for p in itertools.permutations(incorrect_small_index)])
+ permutations.extend(
+ [p for p in itertools.permutations(incorrect_small_index)]
+ )
elif error_name == ErrorIf.IndexUsedTwice:
# Create list with a duplicated index
perm_range = list(range(len(ifm_shape)))
@@ -1106,7 +1166,6 @@ class TosaArgGen:
perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
permutations = [p for p in itertools.permutations(perm_range)]
-
else:
# Get all permutations
permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
@@ -1151,7 +1210,9 @@ class TosaArgGen:
if valid:
# If ERROR_IF test required then incorrect start, size will be returned
- start, size = TosaErrorIfArgGen.eiSliceErrorIf(testGen, error_name, ifm_shape, start, size)
+ start, size = TosaErrorIfArgGen.eiSliceErrorIf(
+ testGen, error_name, ifm_shape, start, size
+ )
arg_list.append(("perm{}".format(p), [start, size]))
return arg_list
@@ -1170,7 +1231,8 @@ class TosaArgGen:
multiples = []
for i in range(rank):
if ifm_shape[i] > 1000:
- # Multiple of 1 if ifm_shape dimension is large to reduce tensor size
+ # Multiple of 1 if ifm_shape dimension is large to reduce
+ # tensor size
multiples.append(1)
elif max(ifm_shape) > 1000:
multiples.append(2)
@@ -1212,9 +1274,9 @@ class TosaArgGen:
# A output_dim of 1 will cause offset to exceed allowed range
# so minimum value 2 produced below
output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
- while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
+ while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
output_dims[0] += 1
- while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
+ while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
output_dims[1] += 1
in_center_h = (ifm_shape[1] - 1) / 2.0
@@ -1229,7 +1291,10 @@ class TosaArgGen:
if outputDType == DType.FLOAT:
float_op = True
- arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
+ arg_str = (
+ "mode{}_shift{}_odim{}x{}_out{}"
+ "_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
+ )
shift = 0
stride = [0, 0]
offset = [0, 0]
@@ -1239,11 +1304,11 @@ class TosaArgGen:
else:
float_op = False
arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}"
- shift = testGen.randInt(1,12)
+ shift = testGen.randInt(1, 12)
# Now search for a shift value (1 to 11) that will produce
# a valid and predictable resize operation
count = 0
- while (count < 12):
+ while count < 12:
unit = float(1 << shift)
stride_y = int(round(fp_stride_y * unit))
stride_x = int(round(fp_stride_x * unit))
@@ -1265,20 +1330,26 @@ class TosaArgGen:
shift = (shift % 11) + 1
continue
- def RESIZE_REQUIRE_CALC(length_in, length_out, stride, offset, shift):
+ def RESIZE_REQUIRE_CALC(
+ length_in, length_out, stride, offset, shift
+ ):
# Perform the pseudo loop to look for out of bounds
- for pos in range(0,length_out):
+ for pos in range(0, length_out):
a = pos * stride + offset
ia = a >> shift
ia0 = max(ia, 0)
- ia1 = min(ia+1, length_in-1)
+ ia1 = min(ia + 1, length_in - 1)
if ia0 > ia1:
# Found a problem value
break
return ia0, ia1
- iy0, iy1 = RESIZE_REQUIRE_CALC(ifm_shape[1], output_dims[0], stride_y, offset_y, shift)
- ix0, ix1 = RESIZE_REQUIRE_CALC(ifm_shape[2], output_dims[1], stride_x, offset_x, shift)
+ iy0, iy1 = RESIZE_REQUIRE_CALC(
+ ifm_shape[1], output_dims[0], stride_y, offset_y, shift
+ )
+ ix0, ix1 = RESIZE_REQUIRE_CALC(
+ ifm_shape[2], output_dims[1], stride_x, offset_x, shift
+ )
if ix0 > ix1 or iy0 > iy1:
# Change the shift value and check again
count += 1
@@ -1298,7 +1369,14 @@ class TosaArgGen:
# Common for all data types
if error_name is not None:
- shift, stride, stride_fp, offset, offset_fp, outputDTypeNew = TosaErrorIfArgGen.eiResizeErrorIf(
+ (
+ shift,
+ stride,
+ stride_fp,
+ offset,
+ offset_fp,
+ outputDTypeNew,
+ ) = TosaErrorIfArgGen.eiResizeErrorIf(
testGen,
error_name,
mode,
@@ -1309,7 +1387,7 @@ class TosaArgGen:
stride,
stride_fp,
offset,
- offset_fp
+ offset_fp,
)
else:
outputDTypeNew = outputDType
@@ -1325,7 +1403,7 @@ class TosaArgGen:
stride_fp[0] if float_op else stride[0],
stride_fp[1] if float_op else stride[1],
offset_fp[0] if float_op else offset[0],
- offset_fp[1] if float_op else offset[1]
+ offset_fp[1] if float_op else offset[1],
),
[
mode,
@@ -1384,14 +1462,26 @@ class TosaArgGen:
return arg_list
-class TosaErrorIfArgGen:
+class TosaErrorIfArgGen:
@staticmethod
- def eiResizeErrorIf(testGen, error_name, mode, dtype, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
+ def eiResizeErrorIf(
+ testGen,
+ error_name,
+ mode,
+ dtype,
+ shapeList,
+ outputDType,
+ shift,
+ stride,
+ stride_fp,
+ offset,
+ offset_fp,
+ ):
if outputDType == DType.FLOAT:
if error_name == ErrorIf.StrideSmallerEqualZero:
- stride_fp = testGen.rng.random(size=[2]) - 2
+ stride_fp = testGen.rng.random(size=[2]) - 2
elif error_name == ErrorIf.ShiftNotZero:
shift = testGen.rng.integers(1, 5)
elif error_name == ErrorIf.StrideLargerDimension:
@@ -1407,11 +1497,23 @@ class TosaErrorIfArgGen:
elif error_name == ErrorIf.ShiftSmallerOne:
shift = testGen.rng.integers(-3, 1)
if shift <= 0:
- stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
- offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
+ stride = [
+ (16 >> -shift) - 1,
+ (16 >> -shift) - 1,
+ ] # avoids other ERROR_IF checks
+ offset = [
+ (16 >> -shift) - 1,
+ (16 >> -shift) - 1,
+ ] # avoids other ERROR_IF checks
else:
- stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
- offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
+ stride = [
+ (16 << shift) - 1,
+ (16 << shift) - 1,
+ ] # avoids other ERROR_IF checks
+ offset = [
+ (16 << shift) - 1,
+ (16 << shift) - 1,
+ ] # avoids other ERROR_IF checks
elif error_name == ErrorIf.ShiftLargerEleven:
shift = np.int16(testGen.rng.integers(12, 15))
elif error_name == ErrorIf.StrideLargerDimension:
@@ -1428,49 +1530,91 @@ class TosaErrorIfArgGen:
elif error_name == ErrorIf.OffsetSmallerEqualMin:
offset = [(-16 << shift) - 1, (-16 << shift) - 1]
-
if error_name == ErrorIf.WrongOutputType:
if mode == ResizeMode.NEAREST and dtype == DType.INT8:
- incorrect_types = (DType.INT4, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ )
elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT32, DType.INT48, DType.FLOAT)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ )
elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT48,
+ DType.FLOAT,
+ )
elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ )
elif dtype == DType.FLOAT:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ )
outputDType = testGen.rng.choice(a=incorrect_types)
return shift, stride, stride_fp, offset, offset_fp, outputDType
-
@staticmethod
def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
- if (error_name == ErrorIf.StrideSmallerOne
+ if (
+ error_name == ErrorIf.StrideSmallerOne
# padding must not exceed the kernel size
- and pad[0] < kernel[0] and pad[1] < kernel[0] and pad[2] < kernel[1] and pad[3] < kernel[1]):
- wrongStride = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
+ and pad[0] < kernel[0]
+ and pad[1] < kernel[0]
+ and pad[2] < kernel[1]
+ and pad[3] < kernel[1]
+ ):
+ wrongStride = (
+ testGen.rng.choice([0, -1, -2, -3]),
+ testGen.rng.choice([0, -1, -2, -3]),
+ )
return wrongStride, pad, kernel
elif error_name == ErrorIf.PadSmallerZero:
- wrongPad = (testGen.rng.choice([-1, -2, -3]),
- testGen.rng.choice([-1, -2, -3]),
- testGen.rng.choice([-1, -2, -3]),
- testGen.rng.choice([-1, -2, -3]))
+ wrongPad = (
+ testGen.rng.choice([-1, -2, -3]),
+ testGen.rng.choice([-1, -2, -3]),
+ testGen.rng.choice([-1, -2, -3]),
+ testGen.rng.choice([-1, -2, -3]),
+ )
return stride, wrongPad, kernel
elif error_name == ErrorIf.KernelSmallerOne:
- wrongKernel = (testGen.rng.choice([0, -1, -2, -3]), testGen.rng.choice([0, -1, -2, -3]))
+ wrongKernel = (
+ testGen.rng.choice([0, -1, -2, -3]),
+ testGen.rng.choice([0, -1, -2, -3]),
+ )
return stride, pad, wrongKernel
elif error_name == ErrorIf.PadLargerEqualKernel:
- wrongPad = (testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
- testGen.rng.choice([kernel[0], kernel[0]+1, kernel[0]+2]),
- testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]),
- testGen.rng.choice([kernel[1], kernel[1]+1, kernel[1]+2]))
+ wrongPad = (
+ testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
+ testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
+ testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
+ testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
+ )
return stride, wrongPad, kernel
else:
return None, None, None
-
@staticmethod
def eiRescaleWrongOutputType(input_dtype, output_dtype):
if input_dtype == DType.INT8:
@@ -1487,27 +1631,28 @@ class TosaErrorIfArgGen:
return True
return False
-
@staticmethod
def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
# Mess up input/output tensors for ERROR_IF checks
if error_name == "WrongInputList":
add_input = testGen.rng.choice([True, False])
if add_input:
- input_list.append('eiDummyInput')
+ input_list.append("eiDummyInput")
else:
input_list = input_list[:-1]
elif error_name == "WrongOutputList":
add_output = testGen.rng.choice([True, False])
if add_output:
- output_list.append('eiDummyOutput')
+ output_list.append("eiDummyOutput")
else:
output_list = []
return input_list, output_list
@staticmethod
def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
- """Restrict the dimensions and overall size of a shape to max_dim and max_items."""
+ """Restrict the dimensions and overall size of a shape to
+ max_dim and max_items.
+ """
new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
while product(new_shape) > max_items:
new_shape = [max(d - 1, 1) for d in new_shape]
@@ -1527,7 +1672,7 @@ class TosaErrorIfArgGen:
elif error_name == ErrorIf.StartSizeOutsideBounds:
newStart, newSize = [], []
for i in range(len(input_shape)):
- newStart.append(input_shape[i]-1)
+ newStart.append(input_shape[i] - 1)
newSize.append(testGen.rng.choice([2, 3, 4]))
return newStart, newSize
elif error_name == ErrorIf.InputSizeStartLengthMismatch:
@@ -1556,7 +1701,6 @@ class TosaErrorIfArgGen:
class TosaErrorValidator:
-
@staticmethod
def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
"""Check ERROR_IF statements are caught and set the expected result.
@@ -1572,9 +1716,9 @@ class TosaErrorValidator:
overall_result = True
for val_fcn in validator_fcns:
val_result = val_fcn(True, **kwargs)
- validator_name = val_result['error_name']
- error_result = val_result['error_result']
- error_reason = val_result['error_reason']
+ validator_name = val_result["error_name"]
+ error_result = val_result["error_result"]
+ error_reason = val_result["error_reason"]
# expect an error IFF the error_name and validator_name match
expected_result = error_result == (error_name == validator_name)
@@ -1583,18 +1727,22 @@ class TosaErrorValidator:
if expected_result and error_result:
serializer.setExpectedReturnCode(2, True, desc=error_reason)
elif error_result: # and not expected_result
- print(f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
- f" Expected: {error_name}, Got: {validator_name}")
- elif not expected_result: # and not error_result
- print(f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
- f" Expected: {error_name}")
+ print(
+ f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
+ f" Expected: {error_name}, Got: {validator_name}"
+ )
+ elif not expected_result: # and not error_result
+ print(
+ f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
+ f" Expected: {error_name}"
+ )
if not expected_result:
for k, v in sorted(kwargs.items()):
- if k != 'op':
- if k.endswith('dtype'):
+ if k != "op":
+ if k.endswith("dtype"):
v = valueToName(DType, v)
- print(f' {k} = {v}')
+ print(f" {k} = {v}")
return overall_result
@@ -1603,24 +1751,26 @@ class TosaErrorValidator:
error_result = False
# Find the unsupported input data types
- op = kwargs['op']
- input_dtypes = op['types']
- allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
+ op = kwargs["op"]
+ input_dtypes = op["types"]
+ allowed_input_dtypes = {
+ t[0] if isinstance(t, list) else t for t in input_dtypes
+ }
wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
- if op['op'] == Op.CLAMP:
+ if op["op"] == Op.CLAMP:
wrong_input_dtypes.remove(DType.INT48)
if check:
- input_dtype = kwargs['input_dtype']
+ input_dtype = kwargs["input_dtype"]
if input_dtype not in allowed_input_dtypes:
error_result = True
info_dict = {
"error_name": ErrorIf.WrongInputType,
"error_result": error_result,
- "error_reason": f"Input data type not supported for this operator",
- "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
+ "error_reason": "Input data type not supported for this operator",
+ "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
}
return info_dict
@@ -1629,24 +1779,45 @@ class TosaErrorValidator:
error_result = False
if check:
- input_dtype = kwargs['input_dtype']
- output_dtype = kwargs['output_dtype']
- op = kwargs['op']
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
+ op = kwargs["op"]
- if op['op'] == Op.RESIZE:
- mode = kwargs['mode']
+ if op["op"] == Op.RESIZE:
+ mode = kwargs["mode"]
if (
- (mode == ResizeMode.NEAREST and input_dtype == DType.INT8 and output_dtype != DType.INT8) or
- (mode == ResizeMode.NEAREST and input_dtype == DType.INT16 and output_dtype != DType.INT16) or
- (mode == ResizeMode.BILINEAR and input_dtype == DType.INT8 and output_dtype != DType.INT32) or
- (mode == ResizeMode.BILINEAR and input_dtype == DType.INT16 and output_dtype != DType.INT48) or
- (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ (
+ mode == ResizeMode.NEAREST
+ and input_dtype == DType.INT8
+ and output_dtype != DType.INT8
+ )
+ or (
+ mode == ResizeMode.NEAREST
+ and input_dtype == DType.INT16
+ and output_dtype != DType.INT16
+ )
+ or (
+ mode == ResizeMode.BILINEAR
+ and input_dtype == DType.INT8
+ and output_dtype != DType.INT32
+ )
+ or (
+ mode == ResizeMode.BILINEAR
+ and input_dtype == DType.INT16
+ and output_dtype != DType.INT48
+ )
+ or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
- elif op['op'] == Op.RESCALE:
+ elif op["op"] == Op.RESCALE:
if input_dtype == DType.INT8:
- if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ if output_dtype not in [
+ DType.UINT8,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ ]:
error_result = True
if input_dtype in [DType.INT16, DType.INT32]:
if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
@@ -1658,49 +1829,78 @@ class TosaErrorValidator:
if output_dtype != DType.INT8:
error_result = True
- elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
+ elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
if (
- (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
- (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
- (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ (input_dtype == DType.INT8 and output_dtype != DType.INT32)
+ or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
+ or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
- elif op['op'] == Op.ARGMAX:
- if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
+ elif op["op"] == Op.ARGMAX:
+ if (
+ input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
+ and output_dtype != DType.INT32
+ ):
error_result = True
- elif op['op'] == Op.MUL:
+ elif op["op"] == Op.MUL:
if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
error_result = True
elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
error_result = True
- elif op['op'] == Op.TABLE:
+ elif op["op"] == Op.TABLE:
if input_dtype == DType.INT8 and output_dtype != DType.INT8:
error_result = True
elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
error_result = True
- elif op['op'] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
+ elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
if output_dtype != DType.BOOL:
error_result = True
- elif op['op'] == Op.CAST:
+ elif op["op"] == Op.CAST:
if (
- (input_dtype == DType.BOOL and output_dtype not in [DType.INT8, DType.INT16, DType.INT32])
- or (input_dtype == DType.INT8 and output_dtype not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT])
- or (input_dtype == DType.INT16 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT])
- or (input_dtype == DType.INT32 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT])
- or (input_dtype == DType.FLOAT and output_dtype not in [DType.INT8, DType.INT16, DType.INT32])
+ (
+ input_dtype == DType.BOOL
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ )
+ or (
+ input_dtype == DType.INT8
+ and output_dtype
+ not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ )
+ or (
+ input_dtype == DType.INT16
+ and output_dtype
+ not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ )
+ or (
+ input_dtype == DType.INT32
+ and output_dtype
+ not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ )
+ or (
+ input_dtype == DType.FLOAT
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ )
):
error_result = True
- elif op['op'] in {Op.CONV2D, Op.CONV3D, Op.DEPTHWISE_CONV2D, Op.TRANSPOSE_CONV2D}:
+ elif op["op"] in {
+ Op.CONV2D,
+ Op.CONV3D,
+ Op.DEPTHWISE_CONV2D,
+ Op.TRANSPOSE_CONV2D,
+ }:
if (
- input_dtype == DType.INT8 and output_dtype != DType.INT32
- or input_dtype == DType.INT16 and output_dtype != DType.INT48
- or input_dtype == DType.FLOAT and output_dtype != DType.FLOAT
+ input_dtype == DType.INT8
+ and output_dtype != DType.INT32
+ or input_dtype == DType.INT16
+ and output_dtype != DType.INT48
+ or input_dtype == DType.FLOAT
+ and output_dtype != DType.FLOAT
):
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
@@ -1712,8 +1912,10 @@ class TosaErrorValidator:
info_dict = {
"error_name": ErrorIf.WrongOutputType,
"error_result": error_result,
- "error_reason": "Output data type not supported for this configuration of operator",
- "param_reqs": {"rank": None, "dtype": None, "shape": None}
+ "error_reason": (
+ "Output data type not supported for this configuration of operator"
+ ),
+ "param_reqs": {"rank": None, "dtype": None, "shape": None},
}
return info_dict
@@ -1722,19 +1924,19 @@ class TosaErrorValidator:
all_ranks = (1, 2, 3, 4, 5)
# Make a list of incorrect ranks
- assert 'op' in kwargs
- op = kwargs['op']
- rmin, rmax = op['rank']
+ assert "op" in kwargs
+ op = kwargs["op"]
+ rmin, rmax = op["rank"]
rank_range = range(rmin, rmax + 1)
incorrect_ranks = list(set(all_ranks) - set(rank_range))
# Remove small incorrect ranks to avoid index errors
incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
# Set minimum incorrect rank to 3 to avoid index error
- if op['op'] in [Op.RESIZE]:
+ if op["op"] in [Op.RESIZE]:
incorrect_ranks = [3, 5]
- elif op['op'] in [Op.TRANSPOSE]:
+ elif op["op"] in [Op.TRANSPOSE]:
incorrect_ranks = [7, 8]
- elif op['op'] in [Op.CONV3D]:
+ elif op["op"] in [Op.CONV3D]:
incorrect_ranks = [6, 7]
error_name = ErrorIf.WrongRank
@@ -1743,13 +1945,16 @@ class TosaErrorValidator:
error_reason = "Rank not supported for this operator"
if check:
- input_shape = kwargs['input_shape']
+ input_shape = kwargs["input_shape"]
- if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
+ if (
+ op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
+ and len(input_shape) != 4
+ ):
error_result = True
- elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2:
+ elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
error_result = True
- elif op['op'] == Op.MATMUL and len(input_shape) != 3:
+ elif op["op"] == Op.MATMUL and len(input_shape) != 3:
error_result = True
else:
if len(input_shape) not in rank_range:
@@ -1759,7 +1964,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -1771,10 +1976,10 @@ class TosaErrorValidator:
error_reason = "Op input list does not match expected input"
if check:
- op = kwargs['op']
- input_list = kwargs['input_list']
- num_operands = kwargs['num_operands']
- if op['op'] in [Op.SCATTER, Op.GATHER]:
+ op = kwargs["op"]
+ input_list = kwargs["input_list"]
+ num_operands = kwargs["num_operands"]
+ if op["op"] in [Op.SCATTER, Op.GATHER]:
# SCATTER/GATHER add an indices input tensor in their build functions
num_operands += 1
if len(input_list) != num_operands:
@@ -1784,7 +1989,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -1796,7 +2001,7 @@ class TosaErrorValidator:
error_reason = "Op output list does not match expected output"
if check:
- output_list = kwargs['output_list']
+ output_list = kwargs["output_list"]
# Note this will be incorrect if an operator returns more than one output
if len(output_list) != 1:
error_result = True
@@ -1805,7 +2010,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -1813,45 +2018,51 @@ class TosaErrorValidator:
def evMaxDimExceeded(check=False, **kwargs):
error_name = ErrorIf.MaxDimExceeded
param_reqs = {
- "rank": [4,4],
+ "rank": [4, 4],
"dtype": [DType.INT8],
- "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]]
- }
+ "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
+ }
error_result = False
- error_reason = "At least one maximum dimension is greater than or equal to 16384"
+ error_reason = (
+ "At least one maximum dimension is greater than or equal to 16384"
+ )
if check:
- input_shape = kwargs['input_shape']
- output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
- if ((input_shape[1] >= 16384) or
- (input_shape[2] >= 16384) or
- (output_shape[0] >= 16384) or
- (output_shape[1] >= 16384)):
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
+ if (
+ (input_shape[1] >= 16384)
+ or (input_shape[2] >= 16384)
+ or (output_shape[0] >= 16384)
+ or (output_shape[1] >= 16384)
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evBatchMismatch(check=False, **kwargs):
error_name = ErrorIf.BatchMismatch
- param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
+ param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
error_result = False
error_reason = "Input batch size not equal to output batch size"
- assert 'op' in kwargs
- op = kwargs['op']
- rmin, rmax = op['rank']
+ assert "op" in kwargs
+ op = kwargs["op"]
+ rmin, rmax = op["rank"]
rank_range = range(rmin, rmax + 1)
if check:
- input_shape = kwargs['input_shape']
- output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs[
+ "result_tensor"
+ ].shape # Note this is just (N, OH, OW, C)
if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
error_result = True
@@ -1860,25 +2071,27 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evChannelMismatch(check=False, **kwargs):
error_name = ErrorIf.ChannelMismatch
- param_reqs = {"rank": [4,4], "dtype": None, "shape": None}
+ param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
error_result = False
error_reason = "Input channel size not equal to output channel size"
- assert 'op' in kwargs
- op = kwargs['op']
- rmin, rmax = op['rank']
+ assert "op" in kwargs
+ op = kwargs["op"]
+ rmin, rmax = op["rank"]
rank_range = range(rmin, rmax + 1)
if check:
- input_shape = kwargs['input_shape']
- output_shape = kwargs['result_tensor'].shape # Note this is just (N, OH, OW, C)
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs[
+ "result_tensor"
+ ].shape # Note this is just (N, OH, OW, C)
if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
error_result = True
@@ -1886,7 +2099,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -1898,16 +2111,18 @@ class TosaErrorValidator:
error_reason = "Stride value smaller than or equal zero"
if check:
- input_dtype = kwargs['input_dtype']
- output_dtype = kwargs['output_dtype']
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
- stride = kwargs['stride'] # Work around wrong input/output type tests
+ stride = kwargs["stride"] # Work around wrong input/output type tests
elif output_dtype == DType.FLOAT:
- stride = kwargs['stride_fp']
+ stride = kwargs["stride_fp"]
elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
- stride = kwargs['stride_fp'] # Work around wrong input/output type tests
+ stride = kwargs[
+ "stride_fp"
+ ] # Work around wrong input/output type tests
else:
- stride = kwargs['stride']
+ stride = kwargs["stride"]
if min(stride) <= 0:
error_result = True
@@ -1916,7 +2131,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -1928,24 +2143,27 @@ class TosaErrorValidator:
error_reason = "Stride value larger than or equal to maximum value"
if check:
- shift = kwargs['shift']
- input_dtype = kwargs['input_dtype']
- stride = kwargs['stride']
+ shift = kwargs["shift"]
+ input_dtype = kwargs["input_dtype"]
+ stride = kwargs["stride"]
if input_dtype in [DType.INT8, DType.INT16]:
- if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
+ if shift >= 0 and (
+ stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
+ ):
error_result = True
- elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
+ elif shift < 0 and (
+ stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evStrideLargerDimension(check=False, **kwargs):
error_name = ErrorIf.StrideLargerDimension
@@ -1954,22 +2172,25 @@ class TosaErrorValidator:
error_reason = "Stride value larger than or equal to H/W dimension"
if check:
- shape = kwargs['input_shape']
- input_dtype = kwargs['input_dtype']
- stride = kwargs['stride_fp']
+ shape = kwargs["input_shape"]
+ input_dtype = kwargs["input_dtype"]
+ stride = kwargs["stride_fp"]
- if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
+ if (
+ input_dtype == DType.FLOAT
+ and (stride[0] > shape[1])
+ or (stride[1] > shape[2])
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evOffsetSmallerEqualMin(check=False, **kwargs):
error_name = ErrorIf.OffsetSmallerEqualMin
@@ -1978,23 +2199,27 @@ class TosaErrorValidator:
error_reason = "Offset value smaller than or equal to minimum value"
if check:
- shift = kwargs['shift']
- output_dtype = kwargs['output_dtype']
+ shift = kwargs["shift"]
+ output_dtype = kwargs["output_dtype"]
if output_dtype == DType.FLOAT:
- offset = kwargs['offset_fp']
+ offset = kwargs["offset_fp"]
else:
- offset = kwargs['offset']
+ offset = kwargs["offset"]
- if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
+ if shift >= 0 and (
+ offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
+ ):
error_result = True
- elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
+ elif shift < 0 and (
+ offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2006,27 +2231,31 @@ class TosaErrorValidator:
error_reason = "Offset value larger than or equal to maximum value"
if check:
- shift = kwargs['shift']
- output_dtype = kwargs['output_dtype']
+ shift = kwargs["shift"]
+ output_dtype = kwargs["output_dtype"]
if output_dtype == DType.FLOAT:
- offset = kwargs['offset_fp']
+ offset = kwargs["offset_fp"]
else:
- offset = kwargs['offset']
+ offset = kwargs["offset"]
if shift >= 0:
if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
error_result = True
- if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
+ if shift >= 0 and (
+ offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
+ ):
error_result = True
- elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
+ elif shift < 0 and (
+ offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2038,21 +2267,24 @@ class TosaErrorValidator:
error_reason = "Shift value must be zero for float input"
if check:
- shift = kwargs['shift']
- input_dtype = kwargs['input_dtype']
- output_dtype = kwargs['output_dtype']
- if input_dtype == DType.FLOAT and output_dtype == DType.FLOAT and shift != 0:
+ shift = kwargs["shift"]
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
+ if (
+ input_dtype == DType.FLOAT
+ and output_dtype == DType.FLOAT
+ and shift != 0
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evShiftSmallerOne(check=False, **kwargs):
error_name = ErrorIf.ShiftSmallerOne
@@ -2061,9 +2293,9 @@ class TosaErrorValidator:
error_reason = "Shift value smaller than one"
if check:
- shift = kwargs['shift']
- input_dtype = kwargs['input_dtype']
- output_dtype = kwargs['output_dtype']
+ shift = kwargs["shift"]
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
error_result = True
@@ -2071,7 +2303,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2083,7 +2315,7 @@ class TosaErrorValidator:
error_reason = "Shift value larger than eleven"
if check:
- shift = kwargs['shift']
+ shift = kwargs["shift"]
if shift > 11:
error_result = True
@@ -2091,11 +2323,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evRankMismatch(check=False, **kwargs):
error_name = ErrorIf.RankMismatch
@@ -2104,23 +2335,25 @@ class TosaErrorValidator:
error_reason = "Input Rank does not match output rank"
if check:
- input1_shape = kwargs['input1'].shape
- input2_shape = kwargs['input2'].shape
+ input1_shape = kwargs["input1"].shape
+ input2_shape = kwargs["input2"].shape
# In case of SELECT op
- input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
- output_shape = kwargs['result_tensor'].shape
+ input3_shape = (
+ kwargs["input3"].shape if "input3" in kwargs else input2_shape
+ )
+ output_shape = kwargs["result_tensor"].shape
if (
- (len(input1_shape) != len(output_shape)) or
- (len(input2_shape) != len(output_shape)) or
- (len(input3_shape) != len(output_shape))
- ):
+ (len(input1_shape) != len(output_shape))
+ or (len(input2_shape) != len(output_shape))
+ or (len(input3_shape) != len(output_shape))
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2132,30 +2365,34 @@ class TosaErrorValidator:
error_reason = "Input Dimensions do not match output"
if check:
- input1_shape = kwargs['input1'].shape
- input2_shape = kwargs['input2'].shape
+ input1_shape = kwargs["input1"].shape
+ input2_shape = kwargs["input2"].shape
# In case of SELECT op
- input3_shape = kwargs['input3'].shape if 'input3' in kwargs else input2_shape
- output_shape = kwargs['result_tensor'].shape
- for i in range(min(len(input1_shape), len(input2_shape), len(input3_shape))):
+ input3_shape = (
+ kwargs["input3"].shape if "input3" in kwargs else input2_shape
+ )
+ output_shape = kwargs["result_tensor"].shape
+ for i in range(
+ min(len(input1_shape), len(input2_shape), len(input3_shape))
+ ):
if (
- (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) or
- (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) or
- (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
- ):
+ (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
+ or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
+ or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
- op = kwargs['op']
+ op = kwargs["op"]
error_result = False
# Quantizable types
@@ -2163,26 +2400,27 @@ class TosaErrorValidator:
# This does not apply to quantizable types
inputDtypes = [
- dtype for dtype in op['types']
- if (isinstance(dtype, list) and dtype[0] not in qTypes) or
- (not isinstance(dtype, list) and dtype not in qTypes)
+ dtype
+ for dtype in op["types"]
+ if (isinstance(dtype, list) and dtype[0] not in qTypes)
+ or (not isinstance(dtype, list) and dtype not in qTypes)
]
if check:
- input_dtype = kwargs['input_dtype']
- if isinstance(kwargs['qinfo'], tuple):
- qinfo = kwargs['qinfo']
+ input_dtype = kwargs["input_dtype"]
+ if isinstance(kwargs["qinfo"], tuple):
+ qinfo = kwargs["qinfo"]
input_zero_point = qinfo[0]
else:
# For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs['qinfo'].ints
+ qinfo = kwargs["qinfo"].ints
input_zero_point = qinfo[0][1]
- if op['op'] == Op.MATMUL:
- qinfo = kwargs['qinfo'].ints
+ if op["op"] == Op.MATMUL:
+ qinfo = kwargs["qinfo"].ints
for dtype, zp in (
- (kwargs['input_dtype'], qinfo[0][1]),
- (kwargs['input2_dtype'], qinfo[1][1]),
+ (kwargs["input_dtype"], qinfo[0][1]),
+ (kwargs["input2_dtype"], qinfo[1][1]),
):
if dtype not in qTypes and zp != 0:
error_result = True
@@ -2194,32 +2432,28 @@ class TosaErrorValidator:
"error_name": ErrorIf.InputZeroPointNotZero,
"error_result": error_result,
"error_reason": "Input DType not INT8 and zero point not 0",
- "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None}
+ "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
}
return info_dict
-
@staticmethod
def evWeightZeroPointNotZero(check=False, **kwargs):
- op = kwargs['op']
+ op = kwargs["op"]
# exclude inputs with INT8 weights
- inputDtypes = [t for t in op['types']
- if not isinstance(t, list) or t[1] != DType.INT8]
+ inputDtypes = [
+ t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
+ ]
error_name = ErrorIf.WeightZeroPointNotZero
- param_reqs = {
- "rank": None,
- "dtype": inputDtypes,
- "shape": None
- }
+ param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
error_result = False
error_reason = "Weight DType not INT8 and zero point not 0"
if check:
- weight_dtype = kwargs['weight_dtype']
+ weight_dtype = kwargs["weight_dtype"]
# For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
- qinfo = kwargs['qinfo'].ints
+ qinfo = kwargs["qinfo"].ints
weight_zero_point = qinfo[1][1]
if weight_dtype != DType.INT8 and weight_zero_point != 0:
error_result = True
@@ -2228,50 +2462,47 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
- op = kwargs['op']
- inputDtypes = op['types'].copy()
+ op = kwargs["op"]
+ inputDtypes = op["types"].copy()
if DType.INT8 in inputDtypes:
inputDtypes.remove(DType.INT8)
if DType.UINT8 in inputDtypes:
inputDtypes.remove(DType.UINT8)
error_name = ErrorIf.OutputZeroPointNotZero
- param_reqs = {
- "rank": None,
- "dtype": inputDtypes,
- "shape": None
- }
+ param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
error_result = False
error_reason = "Output DType not INT8 and zero point not 0"
if check:
- input_dtype = kwargs['input_dtype']
- output_dtype = kwargs['output_dtype']
- if isinstance(kwargs['qinfo'], tuple):
- qinfo = kwargs['qinfo']
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
+ if isinstance(kwargs["qinfo"], tuple):
+ qinfo = kwargs["qinfo"]
output_zero_point = qinfo[1]
else:
# For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs['qinfo'].ints
+ qinfo = kwargs["qinfo"].ints
output_zero_point = qinfo[1][1]
- if op['op'] == Op.AVG_POOL2D:
+ if op["op"] == Op.AVG_POOL2D:
if input_dtype != DType.INT8 and output_zero_point != 0:
error_result = True
- elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
+ elif (
+ output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2283,7 +2514,7 @@ class TosaErrorValidator:
error_reason = "Axis smaller than zero"
if check:
- axis = kwargs['axis']
+ axis = kwargs["axis"]
if axis < 0:
error_result = True
@@ -2291,11 +2522,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evAxisLargerRank(check=False, **kwargs):
error_name = ErrorIf.AxisLargerRank
@@ -2304,8 +2534,8 @@ class TosaErrorValidator:
error_reason = "Axis larger than rank"
if check:
- axis = kwargs['axis']
- shape = kwargs['input_shape']
+ axis = kwargs["axis"]
+ shape = kwargs["input_shape"]
if axis > len(shape):
error_result = True
@@ -2313,11 +2543,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evShapeOfAxisNotOne(check=False, **kwargs):
error_name = ErrorIf.ShapeOfAxisNotOne
@@ -2326,8 +2555,8 @@ class TosaErrorValidator:
error_reason = "shape[axis] is not equal to 1"
if check:
- axis = kwargs['axis']
- shape = kwargs['output_shape']
+ axis = kwargs["axis"]
+ shape = kwargs["output_shape"]
if (0 <= axis < len(shape)) and shape[axis] != 1:
error_result = True
@@ -2335,11 +2564,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evPadSmallerZero(check=False, **kwargs):
error_name = ErrorIf.PadSmallerZero
@@ -2348,9 +2576,9 @@ class TosaErrorValidator:
error_reason = "At least one pad is smaller than zero"
if check:
- op = kwargs['op']
- pad = kwargs['pad']
- if op['op'] == Op.PAD:
+ op = kwargs["op"]
+ pad = kwargs["pad"]
+ if op["op"] == Op.PAD:
for padding in pad:
if min(padding) < 0:
error_result = True
@@ -2362,11 +2590,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evPadLargerEqualKernel(check=False, **kwargs):
error_name = ErrorIf.PadLargerEqualKernel
@@ -2375,17 +2602,22 @@ class TosaErrorValidator:
error_reason = "At least one pad is larger than kernel dimension"
if check:
- pad = kwargs['pad']
- kernel = kwargs['kernel']
+ pad = kwargs["pad"]
+ kernel = kwargs["kernel"]
if min(pad) > 0 and min(kernel) > 1:
- if pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]:
+ if (
+ pad[0] >= kernel[0]
+ or pad[1] >= kernel[0]
+ or pad[2] >= kernel[1]
+ or pad[3] >= kernel[1]
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2394,32 +2626,47 @@ class TosaErrorValidator:
error_name = ErrorIf.PoolingOutputShapeMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
error_result = False
- error_reason = "Mismatch between output shape provided and expected output shape"
+ error_reason = (
+ "Mismatch between output shape provided and expected output shape"
+ )
if check:
- pad = kwargs['pad']
+ pad = kwargs["pad"]
pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
- kernel = kwargs['kernel']
+ kernel = kwargs["kernel"]
kernel_y, kernel_x = kernel[0], kernel[1]
- input_shape = kwargs['input_shape']
+ input_shape = kwargs["input_shape"]
IH, IW = input_shape[1], input_shape[2]
- output_shape = kwargs['output_shape']
+ output_shape = kwargs["output_shape"]
OH, OW = output_shape[1], output_shape[2]
- stride = kwargs['stride']
+ stride = kwargs["stride"]
stride_y, stride_x = stride[0], stride[1]
# calculate correct height, width dimensions
if stride_x != 0 and stride_y != 0:
- y_correct = (IH + pad_top + pad_bottom + stride_y - kernel_y) // stride_y
- x_correct = (IW + pad_left + pad_right + stride_x - kernel_x) // stride_x
+ y_correct = (
+ IH + pad_top + pad_bottom + stride_y - kernel_y
+ ) // stride_y
+ x_correct = (
+ IW + pad_left + pad_right + stride_x - kernel_x
+ ) // stride_x
# ensure parameters are valid
- params_valid = (min(kernel) >= 1 and min(stride) >= 1 and min(pad) >= 0
- and not (pad[0] >= kernel[0] or pad[1] >= kernel[0] or pad[2] >= kernel[1] or pad[3] >= kernel[1]))
+ params_valid = (
+ min(kernel) >= 1
+ and min(stride) >= 1
+ and min(pad) >= 0
+ and not (
+ pad[0] >= kernel[0]
+ or pad[1] >= kernel[0]
+ or pad[2] >= kernel[1]
+ or pad[3] >= kernel[1]
+ )
+ )
if params_valid and (OH != y_correct or OW != x_correct):
error_result = True
@@ -2428,21 +2675,23 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evArgmaxOutputShapeMismatch(check=False, **kwargs):
error_name = ErrorIf.ArgmaxOutputShapeMismatch
- param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
error_result = False
- error_reason = "Mismatch between output shape provided and expected output shape"
+ error_reason = (
+ "Mismatch between output shape provided and expected output shape"
+ )
if check:
- output_shape = kwargs['output_shape']
- input_shape = kwargs['input_shape']
- axis = kwargs['axis']
+ output_shape = kwargs["output_shape"]
+ input_shape = kwargs["input_shape"]
+ axis = kwargs["axis"]
dimension_match = True
axis_shift = 0
@@ -2463,7 +2712,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2472,12 +2721,14 @@ class TosaErrorValidator:
error_name = ErrorIf.ArgmaxOutputRankMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
error_result = False
- error_reason = "Mismatch between output shape provided and expected output shape"
+ error_reason = (
+ "Mismatch between output shape provided and expected output shape"
+ )
if check:
- output_shape = kwargs['output_shape']
- input_shape = kwargs['input_shape']
- axis = kwargs['axis']
+ output_shape = kwargs["output_shape"]
+ input_shape = kwargs["input_shape"]
+ axis = kwargs["axis"]
valid_params = axis >= 0 and axis < len(input_shape)
if valid_params and (len(input_shape) - 1) != len(output_shape):
@@ -2487,11 +2738,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evKernelSmallerOne(check=False, **kwargs):
error_name = ErrorIf.KernelSmallerOne
@@ -2500,7 +2750,7 @@ class TosaErrorValidator:
error_reason = "At least one kernel dimension is smaller than zero"
if check:
- kernel = kwargs['kernel']
+ kernel = kwargs["kernel"]
if min(kernel) < 1:
error_result = True
@@ -2508,7 +2758,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2520,7 +2770,7 @@ class TosaErrorValidator:
error_reason = "At least one stride dimension is smaller than zero"
if check:
- stride = kwargs['stride']
+ stride = kwargs["stride"]
if min(stride) < 1:
error_result = True
@@ -2528,18 +2778,18 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evDilationSmallerOne(check=False, **kwargs):
- error_result = check and min(kwargs['dilation']) < 1
+ error_result = check and min(kwargs["dilation"]) < 1
return {
"error_name": ErrorIf.DilationSmallerOne,
"error_reason": "At least one dilation is smaller than one",
"param_reqs": {"rank": None, "dtype": None, "shape": None},
- "error_result": error_result
+ "error_result": error_result,
}
@staticmethod
@@ -2550,8 +2800,8 @@ class TosaErrorValidator:
error_reason = "Scale set to true but input type is INT48"
if check:
- input_dtype = kwargs['input_dtype']
- scale32 = kwargs['scale32']
+ input_dtype = kwargs["input_dtype"]
+ scale32 = kwargs["scale32"]
if scale32 and input_dtype == DType.INT48:
error_result = True
@@ -2559,7 +2809,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2571,8 +2821,8 @@ class TosaErrorValidator:
error_reason = "Scale set to false but double round set to true"
if check:
- scale32 = kwargs['scale32']
- double_round = kwargs['double_round']
+ scale32 = kwargs["scale32"]
+ double_round = kwargs["double_round"]
if not scale32 and double_round:
error_result = True
@@ -2580,7 +2830,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2592,8 +2842,8 @@ class TosaErrorValidator:
error_reason = "Input tensor size does not match output tensor size"
if check:
- input_shape = kwargs['input_shape']
- output_shape = kwargs['output_shape']
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
input_size = np.prod(input_shape)
output_size = np.prod(output_shape)
if input_size != output_size:
@@ -2603,7 +2853,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2615,8 +2865,8 @@ class TosaErrorValidator:
error_reason = "Starting point smaller than zero"
if check:
- input_shape = kwargs['input_shape']
- start = kwargs['start']
+ input_shape = kwargs["input_shape"]
+ start = kwargs["start"]
rank = len(input_shape)
if len(start) == rank:
for index in range(rank):
@@ -2627,11 +2877,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evSizeSmallerEqualZero(check=False, **kwargs):
error_name = ErrorIf.SizeSmallerEqualZero
@@ -2640,8 +2889,8 @@ class TosaErrorValidator:
error_reason = "Size smaller than or equal to zero"
if check:
- input_shape = kwargs['input_shape']
- size = kwargs['size']
+ input_shape = kwargs["input_shape"]
+ size = kwargs["size"]
rank = len(input_shape)
if len(size) == rank:
for index in range(rank):
@@ -2652,11 +2901,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evStartSizeOutsideBounds(check=False, **kwargs):
error_name = ErrorIf.StartSizeOutsideBounds
@@ -2665,9 +2913,9 @@ class TosaErrorValidator:
error_reason = "starting point plus size larger than input dimension"
if check:
- input_shape = kwargs['input_shape']
- start = kwargs['start']
- size = kwargs['size']
+ input_shape = kwargs["input_shape"]
+ start = kwargs["start"]
+ size = kwargs["size"]
rank = len(input_shape)
if len(start) == rank and len(size) == rank:
for index in range(rank):
@@ -2678,11 +2926,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evSizeOutputShapeMismatch(check=False, **kwargs):
error_name = ErrorIf.SizeOutputShapeMismatch
@@ -2691,9 +2938,9 @@ class TosaErrorValidator:
error_reason = "Size does not match output dimension"
if check:
- input_shape = kwargs['input_shape']
- output_shape = kwargs['output_shape']
- size = kwargs['size']
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ size = kwargs["size"]
rank = len(input_shape)
if len(size) == rank:
for index in range(rank):
@@ -2704,7 +2951,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2716,9 +2963,9 @@ class TosaErrorValidator:
error_reason = "rank of input not equal to length of start or size"
if check:
- input_shape = kwargs['input_shape']
- start = kwargs['start']
- size = kwargs['size']
+ input_shape = kwargs["input_shape"]
+ start = kwargs["start"]
+ size = kwargs["size"]
rank = len(input_shape)
if rank != len(start) or rank != len(size):
error_result = True
@@ -2727,7 +2974,7 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2739,8 +2986,8 @@ class TosaErrorValidator:
error_reason = "Index outside of allowed bounds"
if check:
- input_shape = kwargs['input_shape']
- perms = kwargs['perms']
+ input_shape = kwargs["input_shape"]
+ perms = kwargs["perms"]
rank = len(input_shape)
for index in perms:
@@ -2751,21 +2998,19 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evIndexUsedTwice(check=False, **kwargs):
error_name = ErrorIf.IndexUsedTwice
- param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
error_result = False
error_reason = "Index used multiple times"
if check:
- input_shape = kwargs['input_shape']
- perms = kwargs['perms']
- rank = len(input_shape)
+ perms = kwargs["perms"]
unique_indices = []
for index in perms:
@@ -2778,42 +3023,41 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evMaxSmallerMin(check=False, **kwargs):
error_name = ErrorIf.MaxSmallerMin
- param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
error_result = False
error_reason = "Max value smaller than min value"
if check:
- max_val = kwargs['max_val']
- min_val = kwargs['min_val']
+ max_val = kwargs["max_val"]
+ min_val = kwargs["min_val"]
if max_val < min_val:
error_result = True
-
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evConcatInputRankMismatch(check=False, **kwargs):
error_name = ErrorIf.ConcatInputRankMismatch
- param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
error_result = False
error_reason = "Input ranks are not identical"
if check:
- inputs = kwargs['inputs']
- input_shape = kwargs['input_shape']
+ inputs = kwargs["inputs"]
+ input_shape = kwargs["input_shape"]
for input in inputs:
if len(input.shape) != len(input_shape):
error_result = True
@@ -2822,21 +3066,21 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evConcatInputDimMismatch(check=False, **kwargs):
error_name = ErrorIf.ConcatInputDimMismatch
- param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
error_result = False
error_reason = "Input dimensions differ on too many axes"
if check:
- inputs = kwargs['inputs']
- input_shape = kwargs['input_shape']
- axis = kwargs['axis']
+ inputs = kwargs["inputs"]
+ input_shape = kwargs["input_shape"]
+ axis = kwargs["axis"]
# Ensure rank is valid before checking dims.
valid_rank = True
@@ -2854,22 +3098,22 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@staticmethod
def evConcatShapeSumMismatch(check=False, **kwargs):
error_name = ErrorIf.ConcatShapeSumMismatch
- param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
error_result = False
error_reason = "Sum of dimensions on axis not equal to output dimension"
if check:
- inputs = kwargs['inputs']
- input_shape = kwargs['input_shape']
- output_shape = kwargs['output_shape']
- axis = kwargs['axis']
+ inputs = kwargs["inputs"]
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ axis = kwargs["axis"]
# Ensure rank is valid before checking dims.
valid_params = True
@@ -2887,12 +3131,11 @@ class TosaErrorValidator:
if axis_dim_sum != output_shape[axis]:
error_result = True
-
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
@@ -2904,24 +3147,25 @@ class TosaErrorValidator:
error_reason = "Input list shape does not match then-graph shape"
if check:
- a = kwargs['a']
- b = kwargs['b']
- basicBlocks = kwargs['basicBlocks']
+ a = kwargs["a"]
+ b = kwargs["b"]
+ basicBlocks = kwargs["basicBlocks"]
then_block = basicBlocks[1]
then_inputs = then_block.inputs
then_tens = then_block.tensors
- if (a.shape != then_tens[then_inputs[0]].shape) or (b.shape != then_tens[then_inputs[1]].shape):
+ if (a.shape != then_tens[then_inputs[0]].shape) or (
+ b.shape != then_tens[then_inputs[1]].shape
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evInputListElseGraphMismatch(check=False, **kwargs):
error_name = ErrorIf.CondIfInputListElseGraphMismatch
@@ -2930,24 +3174,25 @@ class TosaErrorValidator:
error_reason = "Input list shape does not match else-graph shape"
if check:
- a = kwargs['a']
- b = kwargs['b']
- basicBlocks = kwargs['basicBlocks']
+ a = kwargs["a"]
+ b = kwargs["b"]
+ basicBlocks = kwargs["basicBlocks"]
else_block = basicBlocks[2]
else_inputs = else_block.inputs
else_tens = else_block.tensors
- if (a.shape != else_tens[else_inputs[0]].shape) or (b.shape != else_tens[else_inputs[1]].shape):
+ if (a.shape != else_tens[else_inputs[0]].shape) or (
+ b.shape != else_tens[else_inputs[1]].shape
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evOutputListThenGraphMismatch(check=False, **kwargs):
error_name = ErrorIf.CondIfOutputListThenGraphMismatch
@@ -2956,7 +3201,7 @@ class TosaErrorValidator:
error_reason = "Output list shape does not match then-graph shape"
if check:
- basicBlocks = kwargs['basicBlocks']
+ basicBlocks = kwargs["basicBlocks"]
cond_block = basicBlocks[0]
cond_outputs = cond_block.outputs
cond_tens = cond_block.tensors
@@ -2970,11 +3215,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evOutputListElseGraphMismatch(check=False, **kwargs):
error_name = ErrorIf.CondIfOutputListElseGraphMismatch
@@ -2983,7 +3227,7 @@ class TosaErrorValidator:
error_reason = "Output list shape does not match else-graph shape"
if check:
- basicBlocks = kwargs['basicBlocks']
+ basicBlocks = kwargs["basicBlocks"]
cond_block = basicBlocks[0]
cond_outputs = cond_block.outputs
cond_tens = cond_block.tensors
@@ -2997,11 +3241,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evInputListOutputListMismatch(check=False, **kwargs):
error_name = ErrorIf.InputListOutputListMismatch
@@ -3010,7 +3253,7 @@ class TosaErrorValidator:
error_reason = "Input list does not match output list"
if check:
- basicBlocks = kwargs['basicBlocks']
+ basicBlocks = kwargs["basicBlocks"]
while_block = basicBlocks[0]
while_inputs = while_block.inputs
while_outputs = while_block.outputs
@@ -3022,11 +3265,10 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evInputListCondGraphMismatch(check=False, **kwargs):
error_name = ErrorIf.InputListCondGraphMismatch
@@ -3035,26 +3277,26 @@ class TosaErrorValidator:
error_reason = "Input list does not match cond graph"
if check:
- basicBlocks = kwargs['basicBlocks']
+ basicBlocks = kwargs["basicBlocks"]
while_block = basicBlocks[0]
while_inputs = while_block.inputs
while_tens = while_block.tensors
cond_block = basicBlocks[1]
cond_inputs = cond_block.inputs
cond_tens = cond_block.tensors
- if ((while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape) or
- (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape)):
+ if (
+ while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
+ ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evInputListBodyGraphInputMismatch(check=False, **kwargs):
error_name = ErrorIf.InputListBodyGraphInputMismatch
@@ -3063,26 +3305,28 @@ class TosaErrorValidator:
error_reason = "Input list does not match body graph input"
if check:
- basicBlocks = kwargs['basicBlocks']
+ basicBlocks = kwargs["basicBlocks"]
while_block = basicBlocks[0]
while_inputs = while_block.inputs
while_tens = while_block.tensors
body_block = basicBlocks[2]
body_outputs = body_block.inputs
body_tens = body_block.tensors
- if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
- (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
+ if (
+ while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
+ ) or (
+ while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
error_name = ErrorIf.InputListBodyGraphOutputMismatch
@@ -3091,25 +3335,27 @@ class TosaErrorValidator:
error_reason = "Input list does not match body graph output"
if check:
- basicBlocks = kwargs['basicBlocks']
+ basicBlocks = kwargs["basicBlocks"]
while_block = basicBlocks[0]
while_inputs = while_block.inputs
while_tens = while_block.tensors
body_block = basicBlocks[2]
body_outputs = body_block.outputs
body_tens = body_block.tensors
- if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
- (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
+ if (
+ while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
+ ) or (
+ while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
+ ):
error_result = True
info_dict = {
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
-
@staticmethod
def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
error_name = ErrorIf.CondGraphOutputNotMatchingBool
@@ -3118,7 +3364,7 @@ class TosaErrorValidator:
error_reason = "Cond graph output is not a match list of booleans"
if check:
- basicBlocks = kwargs['basicBlocks']
+ basicBlocks = kwargs["basicBlocks"]
cond_block = basicBlocks[1]
cond_outputs = cond_block.outputs
cond_tens = cond_block.tensors
@@ -3129,35 +3375,31 @@ class TosaErrorValidator:
"error_name": error_name,
"error_result": error_result,
"error_reason": error_reason,
- "param_reqs": param_reqs
+ "param_reqs": param_reqs,
}
return info_dict
class TosaInvalidValidator:
-
@staticmethod
def ivWrongDataTypeOrModeResize(**kwargs):
input_dtype = kwargs["input_dtype"]
args = kwargs["args"]
mode = args[0]
- stride = args[1]
- stride_fp = args[4]
output_dtype = args[8]
if mode == ResizeMode.BILINEAR:
# Invalid output data type / Invalid input datatype
return (
- not (input_dtype == DType.INT8 and output_dtype == DType.INT32) or
- not (input_dtype == DType.INT16 and output_dtype == DType.INT48) or
- not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) or
- (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
+ not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
+ or not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
+ or not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
+ or (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
)
elif mode == ResizeMode.NEAREST:
# Invalid output data type / Invalid input datatype
- return (
- (input_dtype != output_dtype) or
- (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
+ return (input_dtype != output_dtype) or (
+ input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]
)
else:
# Invalid resize mode
@@ -3184,20 +3426,24 @@ class TosaInvalidValidator:
@staticmethod
def ivHeightWidthInvalid(**kwargs):
- opName = kwargs['opName']
+ opName = kwargs["opName"]
- inputShapes = kwargs['shapeList']
+ inputShapes = kwargs["shapeList"]
input_shape = inputShapes[0]
- args = kwargs['args']
+ args = kwargs["args"]
strides = args[0]
padding = args[1]
if opName.endswith("pool2d"):
# avg_pool2d, max_pool2d
kernel_shape = args[2]
- h = (input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]) // strides[0]
- w = (input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]) // strides[1]
+ h = (
+ input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
+ ) // strides[0]
+ w = (
+ input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
+ ) // strides[1]
# return True if any dimension is < 1
return h < 1 or w < 1
@@ -3226,17 +3472,31 @@ class TosaInvalidValidator:
the output size
"""
dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- return (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
+ return (
+ (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
+ )
for pad_h, pad_w in (
- (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
- (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
- (0, 0) # VALID padding
+ (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
+ (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
+ (0, 0), # VALID padding
):
- h = get_out_size(input_shape[1], strides[0], kernel_shape[0], dilations[0],
- padding[0], pad_h)
- w = get_out_size(input_shape[2], strides[1], kernel_shape[1], dilations[1],
- padding[1], pad_w)
+ h = get_out_size(
+ input_shape[1],
+ strides[0],
+ kernel_shape[0],
+ dilations[0],
+ padding[0],
+ pad_h,
+ )
+ w = get_out_size(
+ input_shape[2],
+ strides[1],
+ kernel_shape[1],
+ dilations[1],
+ padding[1],
+ pad_w,
+ )
if output_shape[1] == h and output_shape[2] == w:
return False
@@ -3247,7 +3507,11 @@ class TosaInvalidValidator:
# conv2d, conv3d, depthwise_conv2d
dilations = args[2]
filter_shape = inputShapes[1]
- kernel_shape = filter_shape[0:2] if opName.startswith("depthwise_conv2d") else filter_shape[1:-1]
+ kernel_shape = (
+ filter_shape[0:2]
+ if opName.startswith("depthwise_conv2d")
+ else filter_shape[1:-1]
+ )
for i in range(len(kernel_shape)):
dim = (
@@ -3266,7 +3530,7 @@ class TosaInvalidValidator:
@staticmethod
def ivNonPositiveOutputShape(**kwargs):
- args = kwargs['args']
+ args = kwargs["args"]
output_shape = args[3]
if output_shape[1] <= 0 or output_shape[2] <= 0:
# Negative output shape
@@ -3310,13 +3574,12 @@ class TosaTestGen:
fd.write(self.ser.writeJson("{}.tosa".format(testName)))
def resetRNG(self, seed=None):
- if seed == None:
+ if seed is None:
seed = self.random_seed + 1
self.rng = np.random.default_rng(seed)
def getRandTensor(self, shape, dtype):
if dtype == DType.BOOL:
- np_dt = np.bool
return np.bool_(self.rng.choice(a=[False, True], size=shape))
# TOSA specific INT4 weight range from -7 to 7
elif dtype == DType.INT4:
@@ -3469,8 +3732,8 @@ class TosaTestGen:
if isinstance(op, int):
self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
return result_tens
- elif op['op'] == Op.IDENTITY:
- self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
+ elif op["op"] == Op.IDENTITY:
+ self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
return result_tens
# Ensure new output type has correct qinfo
@@ -3478,7 +3741,8 @@ class TosaTestGen:
if result_tens.dtype not in [DType.INT8, DType.UINT8]:
qinfo = ts.TosaSerializerQuantInfo()
qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ TosaQuantGen.getQinfo(self, a.dtype),
+ TosaQuantGen.getQinfo(self, result_tens.dtype),
)
# Invalidate Input/Output list for error if checks.
@@ -3486,7 +3750,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -3495,72 +3761,81 @@ class TosaTestGen:
op=op,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- qinfo = qinfo,
- result_tensor = result_tens,
+ qinfo=qinfo,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
return result_tens
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
- result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
-
+ result_tens = OutputShaper.binaryBroadcastOp(
+ self.ser, self.rng, a, b, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input1 = a,
- input2 = b,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input1=a,
+ input2=b,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list)
+ self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
- self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
+ self.ser.addOperator(op["op"], [a.name, b.name], [result_tens.name])
return result_tens
- def build_arithmetic_right_shift(self, op, a, b, round, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
+ def build_arithmetic_right_shift(
+ self, op, a, b, round, validator_fcns=None, error_name=None
+ ):
+ result_tens = OutputShaper.binaryBroadcastOp(
+ self.ser, self.rng, a, b, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input1 = a,
- input2 = b,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input1=a,
+ input2=b,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -3570,11 +3845,13 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.ArithmeticRightShiftAttribute(round)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
+ result_tens = OutputShaper.binaryBroadcastOp(
+ self.ser, self.rng, a, b, error_name
+ )
# Special for multiply:
# Force the result to INT32 for INT types
@@ -3590,18 +3867,20 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input1 = a,
- input2 = b,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input1=a,
+ input2=b,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -3611,7 +3890,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.MulAttribute(shift)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_table(self, op, a, table, validator_fcns=None, error_name=None):
@@ -3625,24 +3904,26 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = a.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=a.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -3654,58 +3935,72 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input1 = cond,
- input2 = a,
- input3 = b,
- input_shape = a.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input1=cond,
+ input2=a,
+ input3=b,
+ input_shape=a.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list,)
+ self.ser.addOperator(
+ op["op"],
+ input_list,
+ output_list,
+ )
return result_tens
def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)
+ result_tens = OutputShaper.binaryComparisonOp(
+ self.ser, self.rng, a, b, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input1 = a,
- input2 = b,
- input_shape = a.shape,
- input_dtype = a.dtype,
- output_shape = result_tens.shape,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input1=a,
+ input2=b,
+ input_shape=a.shape,
+ input_dtype=a.dtype,
+ output_shape=result_tens.shape,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list,)
+ self.ser.addOperator(
+ op["op"],
+ input_list,
+ output_list,
+ )
return result_tens
def build_argmax(self, op, a, axis, validator_fcns, error_name):
@@ -3716,7 +4011,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -3724,11 +4021,11 @@ class TosaTestGen:
error_name,
op=op,
axis=axis,
- input_shape = a.shape,
- input_dtype = a.dtype,
- output_shape = result_tens.shape,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=a.shape,
+ input_dtype=a.dtype,
+ output_shape=result_tens.shape,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -3738,18 +4035,31 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
- def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
- result_tens = OutputShaper.pool2dOp(self.ser, self.rng, input, kernel, stride, pad, error_name)
+ def build_pool2d(
+ self,
+ op,
+ input,
+ stride,
+ pad,
+ kernel,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
+ ):
+ result_tens = OutputShaper.pool2dOp(
+ self.ser, self.rng, input, kernel, stride, pad, error_name
+ )
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType:
if input.dtype not in [DType.INT8, DType.UINT8]:
qinfo = ts.TosaSerializerQuantInfo()
qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ TosaQuantGen.getQinfo(self, input.dtype),
+ TosaQuantGen.getQinfo(self, result_tens.dtype),
)
# Invalidate Input/Output list for error if checks.
@@ -3757,7 +4067,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -3771,8 +4083,8 @@ class TosaTestGen:
kernel=kernel,
stride=stride,
pad=pad,
- qinfo = qinfo,
- result_tensor = result_tens,
+ qinfo=qinfo,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -3782,27 +4094,45 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(kernel, stride, pad)
- self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
return result_tens
- def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
+ def build_conv2d(
+ self,
+ op,
+ ifm,
+ filter,
+ bias,
+ strides,
+ padding,
+ dilations,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
+ ):
assert len(padding) == 4
result_tens = OutputShaper.conv2dOp(
self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
)
# Ensure new output type has correct qinfo
- if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
+ if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
+ DType.INT8,
+ DType.UINT8,
+ ):
qinfo = ts.TosaSerializerQuantInfo()
qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ TosaQuantGen.getQinfo(self, ifm.dtype),
+ TosaQuantGen.getQinfo(self, result_tens.dtype),
)
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
num_operands = sum(op["operands"])
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -3826,29 +4156,45 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
- self.ser.addOperator(
- op['op'], input_list, output_list, attr, qinfo
- )
+ self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
return result_tens
- def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
+ def build_conv3d(
+ self,
+ op,
+ ifm,
+ filter,
+ bias,
+ strides,
+ padding,
+ dilations,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
+ ):
assert len(padding) == 6
result_tens = OutputShaper.conv3dOp(
self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
)
# Ensure new output type has correct qinfo
- if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
+ if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
+ DType.INT8,
+ DType.UINT8,
+ ):
qinfo = ts.TosaSerializerQuantInfo()
qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ TosaQuantGen.getQinfo(self, ifm.dtype),
+ TosaQuantGen.getQinfo(self, result_tens.dtype),
)
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
num_operands = sum(op["operands"])
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -3872,29 +4218,46 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
- self.ser.addOperator(
- op['op'], input_list, output_list, attr, qinfo
- )
+ self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
return result_tens
def build_transpose_conv2d(
- self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, validator_fcns=None, error_name=None, qinfo=None
+ self,
+ op,
+ ifm,
+ filter,
+ bias,
+ stride,
+ outpad,
+ dilation,
+ output_shape,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
):
assert len(outpad) == 2
- result_tens = OutputShaper.transposeConv2DOp(self.ser, self.rng, ifm, output_shape, error_name)
+ result_tens = OutputShaper.transposeConv2DOp(
+ self.ser, self.rng, ifm, output_shape, error_name
+ )
# Ensure new output type has correct qinfo
- if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
+ if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
+ DType.INT8,
+ DType.UINT8,
+ ):
qinfo = ts.TosaSerializerQuantInfo()
qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ TosaQuantGen.getQinfo(self, ifm.dtype),
+ TosaQuantGen.getQinfo(self, result_tens.dtype),
)
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
num_operands = sum(op["operands"])
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -3918,30 +4281,44 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
- self.ser.addOperator(
- op['op'], input_list, output_list, attr, qinfo
- )
+ self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
return result_tens
def build_depthwise_conv2d(
- self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None
+ self,
+ op,
+ ifm,
+ filter,
+ bias,
+ strides,
+ padding,
+ dilations,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
):
result_tens = OutputShaper.depthwiseConv2dOp(
self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
)
# Ensure new output type has correct qinfo
- if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
+ if error_name == ErrorIf.WrongInputType and ifm.dtype not in (
+ DType.INT8,
+ DType.UINT8,
+ ):
qinfo = ts.TosaSerializerQuantInfo()
qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ TosaQuantGen.getQinfo(self, ifm.dtype),
+ TosaQuantGen.getQinfo(self, result_tens.dtype),
)
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
num_operands = sum(op["operands"])
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -3965,20 +4342,24 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
- self.ser.addOperator(
- op['op'], input_list, output_list, attr, qinfo
- )
+ self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
return result_tens
- def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
- result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)
+ def build_fully_connected(
+ self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ result_tens = OutputShaper.fullyConnectedOp(
+ self.ser, self.rng, ifm, filter, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [ifm.name, filter.name, bias.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -3990,17 +4371,15 @@ class TosaTestGen:
weight_dtype=filter.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
- qinfo = qinfo,
- result_tensor = result_tens,
+ qinfo=qinfo,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(
- op['op'], input_list, output_list, None, qinfo
- )
+ self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
return result_tens
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
@@ -4011,7 +4390,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -4024,15 +4405,15 @@ class TosaTestGen:
input2_dtype=b.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
- qinfo = qinfo,
- result_tensor = result_tens,
+ qinfo=qinfo,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
return result_tens
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
@@ -4043,19 +4424,21 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- axis = axis,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ axis=axis,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -4065,7 +4448,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
@@ -4088,7 +4471,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -4097,11 +4482,11 @@ class TosaTestGen:
op=op,
max_val=max_val,
min_val=min_val,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -4114,7 +4499,7 @@ class TosaTestGen:
else:
attr.ClampAttribute(min_val, max_val, 0, 0)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
@@ -4123,14 +4508,14 @@ class TosaTestGen:
attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
- self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
+ self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
return result_tens
# Needs an additional type/input
def build_prelu(self, op, a, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
- self.ser.addOperator(op['op'], [a.name], [result_tens.name])
+ self.ser.addOperator(op["op"], [a.name], [result_tens.name])
return result_tens
def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
@@ -4141,25 +4526,27 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list)
+ self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_tanh(self, op, a, validator_fcns=None, error_name=None):
@@ -4170,25 +4557,27 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list)
+ self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
def build_concat(self, op, *a, validator_fcns=None, error_name=None):
@@ -4199,7 +4588,9 @@ class TosaTestGen:
axis = a[-1]
a = a[:-1]
- result_tens = OutputShaper.concatOp(self.ser, self.rng, axis, *a, error_name=error_name)
+ result_tens = OutputShaper.concatOp(
+ self.ser, self.rng, axis, *a, error_name=error_name
+ )
input_tensor_names = []
for tensor in a:
@@ -4210,7 +4601,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -4218,12 +4611,12 @@ class TosaTestGen:
error_name,
op=op,
axis=axis,
- input_shape = a[0].shape,
- output_shape = result_tens.shape,
- input_dtype = a[0].dtype,
- output_dtype = result_tens.dtype,
+ input_shape=a[0].shape,
+ output_shape=result_tens.shape,
+ input_dtype=a[0].dtype,
+ output_dtype=result_tens.dtype,
inputs=a,
- result_tensor = result_tens,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -4233,11 +4626,20 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
-
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
- def build_pad(self, op, a, padding, pad_const_int, pad_const_float, validator_fcns=None, error_name=None, qinfo=None):
+ def build_pad(
+ self,
+ op,
+ a,
+ padding,
+ pad_const_int,
+ pad_const_float,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
+ ):
result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
attr = ts.TosaSerializerAttribute()
@@ -4248,51 +4650,55 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
pad=padding,
qinfo=qinfo,
- result_tensor = result_tens,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(
- op['op'], input_list, output_list, attr, qinfo
- )
+ self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
return result_tens
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.reshapeOp(self.ser, self.rng, a, newShape, error_name)
+ result_tens = OutputShaper.reshapeOp(
+ self.ser, self.rng, a, newShape, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -4302,7 +4708,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.ReshapeAttribute(newShape)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
@@ -4313,7 +4719,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -4321,11 +4729,11 @@ class TosaTestGen:
error_name,
op=op,
axis=axis,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -4335,7 +4743,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
@@ -4349,51 +4757,56 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = a.shape,
- output_shape = result_tens.shape,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
perms=perms,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
-
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_slice(self, op, a, start, size, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.sliceOp(self.ser, self.rng, a, start, size, error_name)
+ result_tens = OutputShaper.sliceOp(
+ self.ser, self.rng, a, start, size, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
start=start,
size=size,
- result_tensor = result_tens,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -4403,7 +4816,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.SliceAttribute(start, size)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
@@ -4414,18 +4827,20 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = a.shape,
- output_shape = result_tens.shape,
- input_dtype = a.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=a.shape,
+ output_shape=result_tens.shape,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -4435,7 +4850,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.TileAttribute(multiples)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_gather(self, op, values, validator_fcns=None, error_name=None):
@@ -4452,32 +4867,36 @@ class TosaTestGen:
) # (N, W)
indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
- result_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indicies, error_name)
+ result_tens = OutputShaper.gatherOp(
+ self.ser, self.rng, values, indicies, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [values.name, indicies.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = values.shape,
- output_shape = result_tens.shape,
- input_dtype = values.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=values.shape,
+ output_shape=result_tens.shape,
+ input_dtype=values.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list)
+ self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
@@ -4493,36 +4912,39 @@ class TosaTestGen:
) # (N, W)
indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
- result_tens = OutputShaper.scatterOp(self.ser, self.rng, values_in, indicies, input, error_name)
+ result_tens = OutputShaper.scatterOp(
+ self.ser, self.rng, values_in, indicies, input, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [values_in.name, indicies.name, input.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = values_in.shape,
- output_shape = result_tens.shape,
- input_dtype = values_in.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=values_in.shape,
+ output_shape=result_tens.shape,
+ input_dtype=values_in.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list)
+ self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
-
def build_resize(
self,
op,
@@ -4537,7 +4959,7 @@ class TosaTestGen:
input_dtype,
output_dtype,
validator_fcns,
- error_name = None,
+ error_name=None,
):
result_tens = OutputShaper.resizeOp(
self.ser,
@@ -4552,7 +4974,7 @@ class TosaTestGen:
output_dims,
input_dtype,
output_dtype,
- error_name
+ error_name,
)
# Invalidate Input/Output list for error if checks.
@@ -4560,7 +4982,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -4590,7 +5014,7 @@ class TosaTestGen:
output_dims, stride, offset, shift, stride_fp, offset_fp, mode
)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
@@ -4607,36 +5031,52 @@ class TosaTestGen:
# Type Conversion
def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)
+ result_tens = OutputShaper.typeConversionOp(
+ self.ser, self.rng, val, out_dtype, error_name
+ )
# Invalidate Input/Output list for error if checks.
input_list = [val.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
validator_fcns,
error_name,
op=op,
- input_shape = val.shape,
- output_shape = result_tens.shape,
- input_dtype = val.dtype,
- output_dtype = result_tens.dtype,
- result_tensor = result_tens,
+ input_shape=val.shape,
+ output_shape=result_tens.shape,
+ input_dtype=val.dtype,
+ output_dtype=result_tens.dtype,
+ result_tensor=result_tens,
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
):
return None
- self.ser.addOperator(op['op'], input_list, output_list)
+ self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
- def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
- result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)
+ def build_rescale(
+ self,
+ op,
+ val,
+ out_dtype,
+ scale32,
+ double_round,
+ per_channel,
+ validator_fcns,
+ error_name,
+ ):
+ result_tens = OutputShaper.typeConversionOp(
+ self.ser, self.rng, val, out_dtype, error_name
+ )
if per_channel:
nc = val.shape[-1]
@@ -4705,7 +5145,9 @@ class TosaTestGen:
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
- input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_list, output_list
+ )
qinfo = (input_zp, output_zp)
if not TosaErrorValidator.evValidateErrorIfs(
@@ -4717,8 +5159,8 @@ class TosaTestGen:
output_dtype=out_dtype,
input_shape=val.shape,
qinfo=qinfo,
- scale32 = scale32,
- double_round = double_round,
+ scale32=scale32,
+ double_round=double_round,
input_list=input_list,
output_list=output_list,
result_tensor=result_tens,
@@ -4737,10 +5179,12 @@ class TosaTestGen:
per_channel,
)
- self.ser.addOperator(op['op'], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
- def build_cond_if_const(self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None):
+ def build_cond_if_const(
+ self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
+ ):
# For cond_if with constants, we're supplied with then/else tensors that we ignore
# (except for the generated shap) and the condition. Build Then/Else blocks
# and fill them with const nodes for the body.
@@ -4752,10 +5196,17 @@ class TosaTestGen:
out_shape = then_tens.shape
# Create an incorrect output shape for error_if tests
- if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]:
+ if error_name in [
+ ErrorIf.CondIfOutputListThenGraphMismatch,
+ ErrorIf.CondIfOutputListElseGraphMismatch,
+ ]:
incorrect_shape = deepcopy(then_tens.shape)
for i in range(len(incorrect_shape)):
- incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3]) if incorrect_shape[i] > 3 else self.rng.choice([1, 2, 4])
+ incorrect_shape[i] += (
+ self.rng.choice([-3, -2, 2, 3])
+ if incorrect_shape[i] > 3
+ else self.rng.choice([1, 2, 4])
+ )
incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))
then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
@@ -4771,7 +5222,7 @@ class TosaTestGen:
attr.CondIfAttribute(then_block, else_block)
# Finally, build the op and the two blocks
- self.ser.addOperator(op['op'], [cond_tens.name], [result_tens.name], attr)
+ self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
self.ser.startBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
@@ -4793,13 +5244,15 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks
+ basicBlocks=self.ser.basicBlocks,
):
return None
return result_tens
- def build_cond_if_binary(self, op, a, b, cond, validator_fcns=None, error_name=None):
+ def build_cond_if_binary(
+ self, op, a, b, cond, validator_fcns=None, error_name=None
+ ):
# For cond_if with a binary op in the then/else blocks, take a and b and
# alternately add or subtract them based on the condition
@@ -4814,18 +5267,21 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.CondIfAttribute(then_block, else_block)
- if error_name in [ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch,
- ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch]:
+ if error_name in [
+ ErrorIf.CondIfInputListThenGraphMismatch,
+ ErrorIf.CondIfInputListElseGraphMismatch,
+ ErrorIf.CondIfOutputListElseGraphMismatch,
+ ErrorIf.CondIfOutputListThenGraphMismatch,
+ ]:
incorrect_shape = a.shape.copy()
for i in range(len(incorrect_shape)):
incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
incorrect_block_input = deepcopy(a)
incorrect_block_input.shape = incorrect_shape
-
# Finally, build the op and the two blocks
self.ser.addOperator(
- op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
+ op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
if a.dtype in (DType.FLOAT, DType.INT32):
@@ -4837,13 +5293,23 @@ class TosaTestGen:
for block, op in ((then_block, then_op), (else_block, else_op)):
self.ser.startBasicBlock(block)
- if ((error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block) or
- (error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block)):
+ if (
+ error_name == ErrorIf.CondIfInputListThenGraphMismatch
+ and block == then_block
+ ) or (
+ error_name == ErrorIf.CondIfInputListElseGraphMismatch
+ and block == else_block
+ ):
self.ser.addInputTensor(incorrect_block_input)
self.ser.addInputTensor(b)
tens = self.ser.addOutput(a.shape, a.dtype)
- elif ((error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block) or
- (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block)):
+ elif (
+ error_name == ErrorIf.CondIfOutputListThenGraphMismatch
+ and block == then_block
+ ) or (
+ error_name == ErrorIf.CondIfOutputListElseGraphMismatch
+ and block == else_block
+ ):
self.ser.addInputTensor(a)
self.ser.addInputTensor(b)
tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
@@ -4860,7 +5326,7 @@ class TosaTestGen:
op=op,
a=a,
b=b,
- basicBlocks=self.ser.basicBlocks
+ basicBlocks=self.ser.basicBlocks,
):
return None
@@ -4893,14 +5359,18 @@ class TosaTestGen:
# While_loop operator
self.ser.addOperator(
- op['op'],
+ op["op"],
[iter.name, a.name, acc.name],
[iter_out.name, a_out.name, acc_out.name],
attr,
)
self.ser.addOutputTensor(acc_out)
- if error_name in [ErrorIf.InputListCondGraphMismatch, ErrorIf.InputListBodyGraphInputMismatch, ErrorIf.InputListBodyGraphOutputMismatch]:
+ if error_name in [
+ ErrorIf.InputListCondGraphMismatch,
+ ErrorIf.InputListBodyGraphInputMismatch,
+ ErrorIf.InputListBodyGraphOutputMismatch,
+ ]:
incorrect_iter = deepcopy(iter)
for i in range(len(incorrect_iter.shape)):
incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
@@ -4924,7 +5394,9 @@ class TosaTestGen:
zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
- cond_tens = self.ser.addOutput([], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]))
+ cond_tens = self.ser.addOutput(
+ [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
+ )
else:
cond_tens = self.ser.addOutput([], DType.BOOL)
@@ -4945,8 +5417,12 @@ class TosaTestGen:
one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
- iter_body_out = self.ser.addIntermediate(incorrect_iter.shape, incorrect_iter.dtype)
- acc_body_out = self.ser.addIntermediate(incorrect_acc.shape, incorrect_acc.dtype)
+ iter_body_out = self.ser.addIntermediate(
+ incorrect_iter.shape, incorrect_iter.dtype
+ )
+ acc_body_out = self.ser.addIntermediate(
+ incorrect_acc.shape, incorrect_acc.dtype
+ )
else:
iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
@@ -4962,13 +5438,15 @@ class TosaTestGen:
validator_fcns,
error_name,
op=op,
- basicBlocks=self.ser.basicBlocks
+ basicBlocks=self.ser.basicBlocks,
):
return None
return acc_out
- def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
+ def create_filter_lists(
+ self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
+ ):
# Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
default_test_rank_range = range(1, 5)
if not shapeFilter:
@@ -4986,7 +5464,11 @@ class TosaTestGen:
# Ensure default behaviour is bounded by default range or by operator,
# whichever is the smaller range of ranks.
opRankRange = range(rmin, rmax + 1)
- cleanRankFilter = opRankRange if len(opRankRange) <= len(default_test_rank_range) else default_test_rank_range
+ cleanRankFilter = (
+ opRankRange
+ if len(opRankRange) <= len(default_test_rank_range)
+ else default_test_rank_range
+ )
else:
cleanRankFilter = range(rmin, rmax + 1)
@@ -4996,57 +5478,65 @@ class TosaTestGen:
cleanDtypeFilter = []
# Create list of operator dtypes filtered by requested dtypes
for dtype in dtypes:
- if dtype in dtypeFilter or (isinstance(dtype, list) and dtype[0] in dtypeFilter):
+ if dtype in dtypeFilter or (
+ isinstance(dtype, list) and dtype[0] in dtypeFilter
+ ):
cleanDtypeFilter.append(dtype)
else:
cleanDtypeFilter = dtypes
- if testType == 'positive':
+ if testType == "positive":
filterDict = {
- 'shapeFilter': shapeFilter,
- 'rankFilter': cleanRankFilter,
- 'dtypeFilter': cleanDtypeFilter
+ "shapeFilter": shapeFilter,
+ "rankFilter": cleanRankFilter,
+ "dtypeFilter": cleanDtypeFilter,
}
return filterDict
- elif testType == 'negative':
+ elif testType == "negative":
if validator is not None:
validator_info = validator(check=False, op=op)
else:
return None
- error_arguments = validator_info['param_reqs']
+ error_arguments = validator_info["param_reqs"]
- #Set parameters as required
- if error_arguments['rank'] != None:
- rankFilter = error_arguments['rank']
+ # Set parameters as required
+ if error_arguments["rank"] is not None:
+ rankFilter = error_arguments["rank"]
else:
rankFilter = cleanRankFilter
- if error_arguments['dtype'] != None:
- dtypeFilter = error_arguments['dtype']
+ if error_arguments["dtype"] is not None:
+ dtypeFilter = error_arguments["dtype"]
else:
dtypeFilter = cleanDtypeFilter
- if error_arguments['shape'] != None:
- shapeFilter = error_arguments['shape']
+ if error_arguments["shape"] is not None:
+ shapeFilter = error_arguments["shape"]
else:
- shapeFilter = shapeFilter[:2] # Reduce number of shapes to keep test numbers small
+ shapeFilter = shapeFilter[
+ :2
+ ] # Reduce number of shapes to keep test numbers small
filterDict = {
- 'shapeFilter': shapeFilter,
- 'rankFilter': rankFilter,
- 'dtypeFilter': dtypeFilter
+ "shapeFilter": shapeFilter,
+ "rankFilter": rankFilter,
+ "dtypeFilter": dtypeFilter,
}
return filterDict
-
def genOpTestList(
- self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
+ self,
+ opName,
+ shapeFilter=[None],
+ rankFilter=None,
+ dtypeFilter=None,
+ testType="positive",
):
try:
op = self.TOSA_OP_LIST[opName]
- except KeyError as e:
+ except KeyError:
raise Exception("Cannot find op with name {}".format(opName))
# Initialize a new random number generator
@@ -5057,24 +5547,26 @@ class TosaTestGen:
# Test list consists of a tuple of:
# (opName, testNameStr, dtype, shapeList, argumentsList)
testList = []
- if testType == 'negative' and "error_if_validators" in op:
+ if testType == "negative" and "error_if_validators" in op:
error_if_validators = op["error_if_validators"]
else:
error_if_validators = [None]
for validator in error_if_validators:
if validator is not None:
- error_name = validator(check=False, op=op)['error_name']
+ error_name = validator(check=False, op=op)["error_name"]
else:
error_name = None
- filterDict = self.create_filter_lists(op, shapeFilter, rankFilter, dtypeFilter, testType, validator)
- if filterDict == None:
+ filterDict = self.create_filter_lists(
+ op, shapeFilter, rankFilter, dtypeFilter, testType, validator
+ )
+ if filterDict is None:
return []
- cleanRankFilter = filterDict['rankFilter']
- cleanDtypeFilter = filterDict['dtypeFilter']
- cleanShapeFilter = filterDict['shapeFilter']
- #print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
+ cleanRankFilter = filterDict["rankFilter"]
+ cleanDtypeFilter = filterDict["dtypeFilter"]
+ cleanShapeFilter = filterDict["shapeFilter"]
+ # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
for r in cleanRankFilter:
for t in cleanDtypeFilter:
@@ -5096,24 +5588,30 @@ class TosaTestGen:
argList = [("", [])]
for argStr, args in argList:
- if testType == 'positive':
+ if testType == "positive":
if argStr:
testStr = "{}_{}_{}_{}".format(
opName, shapeStr, typeStr, argStr
)
else:
- testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
- elif testType == 'negative':
+ testStr = "{}_{}_{}".format(
+ opName, shapeStr, typeStr
+ )
+ elif testType == "negative":
if argStr:
testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
opName, error_name, shapeStr, typeStr, argStr
)
else:
- testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
+ testStr = "{}_ERRORIF_{}_{}_{}".format(
+ opName, error_name, shapeStr, typeStr
+ )
- testList.append((opName, testStr, t, error_name, shapeList, args))
+ testList.append(
+ (opName, testStr, t, error_name, shapeList, args)
+ )
- if testType == 'positive':
+ if testType == "positive":
# Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
if "invalid_test_validators" in op:
invalid_test_validators = op["invalid_test_validators"]
@@ -5121,7 +5619,12 @@ class TosaTestGen:
for test in testList:
for validator_fcn in invalid_test_validators:
remove_test = False
- if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
+ if validator_fcn(
+ opName=test[0],
+ input_dtype=test[2],
+ shapeList=test[4],
+ args=test[5],
+ ):
remove_test = True
if not remove_test:
clean_testList.append(test)
@@ -5129,11 +5632,12 @@ class TosaTestGen:
return testList
-
- def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
+ def serializeTest(
+ self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
+ ):
try:
op = self.TOSA_OP_LIST[opName]
- except KeyError as e:
+ except KeyError:
raise Exception("Cannot find op with name {}".format(opName))
# Create a serializer
@@ -5190,9 +5694,24 @@ class TosaTestGen:
resultName = build_fcn(self, op, *tens, *testArgs)
else:
if qinfo is not None:
- resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name, qinfo=qinfo)
+ resultName = build_fcn(
+ self,
+ op,
+ *tens,
+ *testArgs,
+ validator_fcns=error_if_validators,
+ error_name=error_name,
+ qinfo=qinfo,
+ )
else:
- resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name)
+ resultName = build_fcn(
+ self,
+ op,
+ *tens,
+ *testArgs,
+ validator_fcns=error_if_validators,
+ error_name=error_name,
+ )
except TypeError as e:
print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
raise e
@@ -5204,19 +5723,22 @@ class TosaTestGen:
# The test is not valid
print(f"Invalid ERROR_IF test created: {opName} {testStr}")
-
def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
tens = []
- if (op["op"] == Op.ADD or op["op"] == Op.SUB) and dtypeList[0] == DType.INT32 and error_name == None:
+ if (
+ (op["op"] == Op.ADD or op["op"] == Op.SUB)
+ and dtypeList[0] == DType.INT32
+ and error_name is None
+ ):
# Make sure the operation does not cause value saturation - where
# the number wraps due to limited number of bits to store the answer
assert (
pCount == 2 and cCount == 0
), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
placeholders = []
- add = (op["op"] == Op.ADD)
+ add = op["op"] == Op.ADD
a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
if add:
@@ -5225,7 +5747,7 @@ class TosaTestGen:
res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
# Work out the saturation limits
- max_i32 = (1 << 31)-1
+ max_i32 = (1 << 31) - 1
min_i32 = -(1 << 31)
max_arr = np.full(shapeList[1], max_i32)
min_arr = np.full(shapeList[1], min_i32)
@@ -5246,7 +5768,9 @@ class TosaTestGen:
# Reduce axes in unsaturated tensor to match original tensor
for axis, dim in enumerate(b_arr.shape):
if dim != b_unsat_arr.shape[axis]:
- assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
+ assert (
+ dim == 1
+ ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
if (sat_min_arr != 0).any():
@@ -5255,7 +5779,9 @@ class TosaTestGen:
# Reduce axes in unsaturated tensor to match original tensor
for axis, dim in enumerate(b_arr.shape):
if dim != b_unsat_arr.shape[axis]:
- assert ( dim == 1 ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
+ assert (
+ dim == 1
+ ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
placeholders.append(
@@ -5266,15 +5792,19 @@ class TosaTestGen:
)
tens.extend(placeholders)
- elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
+ elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[
+ 0
+ ] == DType.INT32:
# Limit input tensors with cond_if_binary or while_loop to stop
# saturation of add/sub ops
pRemain = pCount
placeholders = []
- for idx, shape in enumerate(shapeList[:]):
+ for idx, shape in enumerate(shapeList[:]):
arr = self.getRandTensor(shapeList[idx], DType.INT16)
if pRemain > 0:
- placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
+ placeholders.append(
+ self.ser.addPlaceholder(shape, dtypeList[idx], arr)
+ )
pRemain -= 1
else:
placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
@@ -5311,7 +5841,7 @@ class TosaTestGen:
self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
)
tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
- elif op["op"] == Op.INTDIV and error_name == None:
+ elif op["op"] == Op.INTDIV and error_name is None:
assert (
pCount == 2 and cCount == 0
), "Op.INTDIV must have 2 placeholders, 0 consts"
@@ -5341,7 +5871,7 @@ class TosaTestGen:
)
tens.extend(placeholders)
- elif op["op"] == Op.MUL and error_name == None:
+ elif op["op"] == Op.MUL and error_name is None:
assert (
pCount == 2 and cCount == 0
), "Op.MUL must have 2 placeholders, 0 consts"
@@ -5414,7 +5944,9 @@ class TosaTestGen:
# Ensure axis is an int
testArgs[0] = int(testArgs[0])
- shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0], error_name)
+ shapeList = TosaTensorGen.tgConcatConstInput(
+ self, shapeList, testArgs[0], error_name
+ )
tens.extend(
self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
@@ -5466,7 +5998,7 @@ class TosaTestGen:
keyList = []
for k in self.TOSA_OP_LIST:
try:
- if self.TOSA_OP_LIST[k]["template"] == True:
+ if self.TOSA_OP_LIST[k]["template"]:
keyList.append(k)
continue
except KeyError:
@@ -5498,22 +6030,22 @@ class TosaTestGen:
)
try:
- types = self.TOSA_OP_LIST[op]["types"]
- except KeyError as e:
+ _ = self.TOSA_OP_LIST[op]["types"]
+ except KeyError:
raise Exception(
"Op {} is missing a valid type list in TOSA_OP_LIST".format(op)
)
try:
- opcode = self.TOSA_OP_LIST[op]["op"]
- except KeyError as e:
+ _ = self.TOSA_OP_LIST[op]["op"]
+ except KeyError:
raise Exception(
"Op {} is missing the Op field in TOSA_OP_LIST".format(op)
)
# Put in default rank range, if missing
try:
- rank = self.TOSA_OP_LIST[op]["rank"]
+ _ = self.TOSA_OP_LIST[op]["rank"]
except KeyError:
self.TOSA_OP_LIST[op]["rank"] = self.DEFAULT_RANK_RANGE
@@ -5553,9 +6085,17 @@ class TosaTestGen:
"rank": (1, 4),
"build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_NARROW_INT_FP,
- "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
- TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
- TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evArgmaxOutputRankMismatch,
+ TosaErrorValidator.evArgmaxOutputShapeMismatch,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"avg_pool2d": {
"op": Op.AVG_POOL2D,
@@ -5565,10 +6105,20 @@ class TosaTestGen:
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_NARROW_INT_FP,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
- "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
- TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
- TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evKernelSmallerOne,
+ TosaErrorValidator.evStrideSmallerOne,
+ TosaErrorValidator.evPadSmallerZero,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evOutputZeroPointNotZero,
+ TosaErrorValidator.evPadLargerEqualKernel,
+ TosaErrorValidator.evPoolingOutputShapeMismatch,
+ ),
},
# Templated operator. Filled in by createDynamicOpLists
"conv2d_TEMPLATE": {
@@ -5651,8 +6201,15 @@ class TosaTestGen:
"build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
- "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evWeightZeroPointNotZero,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"matmul": {
"op": Op.MATMUL,
@@ -5661,8 +6218,14 @@ class TosaTestGen:
"build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
"qgen": TosaQuantGen.qgMatmul,
"types": TYPE_NARROW_INT_FP,
- "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
- TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"max_pool2d": {
"op": Op.MAX_POOL2D,
@@ -5671,9 +6234,18 @@ class TosaTestGen:
"build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
"types": TYPE_NARROW_INT_FP,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
- "error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
- TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evKernelSmallerOne,
+ TosaErrorValidator.evStrideSmallerOne,
+ TosaErrorValidator.evPadSmallerZero,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evPadLargerEqualKernel,
+ TosaErrorValidator.evPoolingOutputShapeMismatch,
+ ),
},
# Templated operator. Filled in by createDynamicOpLists
"transpose_conv2d_TEMPLATE": {
@@ -5711,24 +6283,37 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
"types": TYPE_NARROW_INT_FP,
- "error_if_validators": (TosaErrorValidator.evMaxSmallerMin, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evMaxSmallerMin,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"sigmoid": {
"op": Op.SIGMOID,
"operands": (1, 0),
"build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"tanh": {
"op": Op.TANH,
"operands": (1, 0),
"build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
# Elementwise Binary Operators
"add": {
@@ -5736,8 +6321,14 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"arithmetic_right_shift": {
"op": Op.ARITHMETIC_RIGHT_SHIFT,
@@ -5748,120 +6339,210 @@ class TosaTestGen:
TosaArgGen.agArithmeticRightShift,
),
"types": TYPE_INT,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"bitwise_and": {
"op": Op.BITWISE_AND,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"bitwise_or": {
"op": Op.BITWISE_OR,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"bitwise_xor": {
"op": Op.BITWISE_XOR,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"intdiv": {
"op": Op.INTDIV,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": [DType.INT32],
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"logical_and": {
"op": Op.LOGICAL_AND,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"logical_left_shift": {
"op": Op.LOGICAL_LEFT_SHIFT,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"logical_right_shift": {
"op": Op.LOGICAL_RIGHT_SHIFT,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_INT,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"logical_or": {
"op": Op.LOGICAL_OR,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"logical_xor": {
"op": Op.LOGICAL_XOR,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_BOOL,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"maximum": {
"op": Op.MAXIMUM,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"minimum": {
"op": Op.MINIMUM,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"mul": {
"op": Op.MUL,
"operands": (2, 0),
"build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
"types": TYPE_INT_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"pow": {
"op": Op.POW,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"sub": {
"op": Op.SUB,
"operands": (2, 0),
"build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"table": {
"op": Op.TABLE,
@@ -5871,8 +6552,12 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
"types": [DType.INT8, DType.INT16],
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
# Elementwise Unary operators
"abs": {
@@ -5880,64 +6565,96 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"bitwise_not": {
"op": Op.BITWISE_NOT,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_INT,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"ceil": {
"op": Op.CEIL,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"clz": {
"op": Op.CLZ,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": [DType.INT32],
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"exp": {
"op": Op.EXP,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"floor": {
"op": Op.FLOOR,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"log": {
"op": Op.LOG,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"logical_not": {
"op": Op.LOGICAL_NOT,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_BOOL,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"negate": {
"op": Op.NEGATE,
@@ -5945,25 +6662,38 @@ class TosaTestGen:
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_INT_FP,
- "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
- TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evOutputZeroPointNotZero,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"reciprocal": {
"op": Op.RECIPROCAL,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"rsqrt": {
"op": Op.RSQRT,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
# Elementwise Ternary operators
"select": {
@@ -5971,8 +6701,14 @@ class TosaTestGen:
"operands": (3, 0),
"build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
# Comparison operators
"equal": {
@@ -5980,24 +6716,42 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"greater_equal": {
"op": Op.GREATER_EQUAL,
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
"greater": {
"op": Op.GREATER,
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evRankMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evDimensionMismatch,
+ ),
},
# Reduction operators
"reduce_all": {
@@ -6006,9 +6760,16 @@ class TosaTestGen:
"rank": (1, 4),
"build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_BOOL,
- "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evShapeOfAxisNotOne,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"reduce_any": {
"op": Op.REDUCE_ANY,
@@ -6016,9 +6777,16 @@ class TosaTestGen:
"rank": (1, 4),
"build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_BOOL,
- "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evShapeOfAxisNotOne,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"reduce_max": {
"op": Op.REDUCE_MAX,
@@ -6026,9 +6794,16 @@ class TosaTestGen:
"rank": (1, 4),
"build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_INT_FP,
- "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evShapeOfAxisNotOne,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"reduce_min": {
"op": Op.REDUCE_MAX,
@@ -6036,9 +6811,16 @@ class TosaTestGen:
"rank": (1, 4),
"build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_INT_FP,
- "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evShapeOfAxisNotOne,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"reduce_product": {
"op": Op.REDUCE_PRODUCT,
@@ -6046,9 +6828,16 @@ class TosaTestGen:
"rank": (1, 4),
"build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_FP,
- "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evShapeOfAxisNotOne,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"reduce_sum": {
"op": Op.REDUCE_SUM,
@@ -6056,9 +6845,16 @@ class TosaTestGen:
"rank": (1, 4),
"build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_FI32,
- "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evShapeOfAxisNotOne,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evShapeOfAxisNotOne,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
# Data layout operators
"concat": {
@@ -6066,9 +6862,16 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
- TosaErrorValidator.evConcatShapeSumMismatch, TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType,
- TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evConcatInputRankMismatch,
+ TosaErrorValidator.evConcatShapeSumMismatch,
+ TosaErrorValidator.evConcatInputDimMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"pad": {
"op": Op.PAD,
@@ -6077,24 +6880,40 @@ class TosaTestGen:
"build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
"qgen": TosaQuantGen.qgPad,
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero,
- TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evPadSmallerZero,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"reshape": {
"op": Op.RESHAPE,
"operands": (1, 0),
"build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evTensorSizeInputOutputMismatch,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"reverse": {
"op": Op.REVERSE,
"operands": (1, 0),
"build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evWrongInputType,
- TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evAxisSmallerZero,
+ TosaErrorValidator.evAxisLargerRank,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"slice": {
"op": Op.SLICE,
@@ -6102,17 +6921,30 @@ class TosaTestGen:
"rank": (1, 4),
"build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evStartSmallerZero, TosaErrorValidator.evSizeSmallerEqualZero, TosaErrorValidator.evStartSizeOutsideBounds,
- TosaErrorValidator.evSizeOutputShapeMismatch, TosaErrorValidator.evInputSizeStartLengthMismatch, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evStartSmallerZero,
+ TosaErrorValidator.evSizeSmallerEqualZero,
+ TosaErrorValidator.evStartSizeOutsideBounds,
+ TosaErrorValidator.evSizeOutputShapeMismatch,
+ TosaErrorValidator.evInputSizeStartLengthMismatch,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"tile": {
"op": Op.TILE,
"operands": (1, 0),
"build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"transpose": {
"op": Op.TRANSPOSE,
@@ -6124,8 +6956,14 @@ class TosaTestGen:
TosaArgGen.agTranspose,
),
"types": TYPE_FIB,
- "error_if_validators": (TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice,
- TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evIndexOutsideBounds,
+ TosaErrorValidator.evIndexUsedTwice,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
# Data nodes
"const": {
@@ -6148,19 +6986,29 @@ class TosaTestGen:
"rank": (3, 3),
"build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
"types": TYPE_INT_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ ),
},
"scatter": {
"op": Op.SCATTER,
# Only specify 'values_in' tensor here.
- #'indices' and 'input' are generated in op building stage
+ # 'indices' and 'input' are generated in op building stage
"operands": (2, 0),
"rank": (3, 3),
"build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
"types": TYPE_INT_FP,
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evWrongRank)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ ),
},
# Image operations
"resize": {
@@ -6169,12 +7017,28 @@ class TosaTestGen:
"rank": (4, 4),
"build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
"types": [DType.INT8, DType.INT16, DType.FLOAT],
- "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
- "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
- TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
- TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven, TosaErrorValidator.evWrongInputType,
- TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList,
- TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evChannelMismatch)
+ "invalid_test_validators": (
+ TosaInvalidValidator.ivWrongDataTypeOrModeResize,
+ TosaInvalidValidator.ivBadStride,
+ ),
+ "error_if_validators": (
+ TosaErrorValidator.evMaxDimExceeded,
+ TosaErrorValidator.evStrideSmallerEqualZero,
+ TosaErrorValidator.evStrideLargerDimension,
+ TosaErrorValidator.evStrideLargerEqualMax,
+ TosaErrorValidator.evOffsetSmallerEqualMin,
+ TosaErrorValidator.evOffsetLargerEqualMax,
+ TosaErrorValidator.evShiftNotZero,
+ TosaErrorValidator.evShiftSmallerOne,
+ TosaErrorValidator.evShiftLargerEleven,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evBatchMismatch,
+ TosaErrorValidator.evChannelMismatch,
+ ),
},
# Type conversion
"cast": {
@@ -6182,18 +7046,30 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
"types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
- "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
"rescale": {
"op": Op.RESCALE,
"operands": (1, 0),
- "rank": (1,4),
+ "rank": (1, 4),
"build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
"types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
- "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
- TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
- TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
+ "error_if_validators": (
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evOutputZeroPointNotZero,
+ TosaErrorValidator.evScaleTrue,
+ TosaErrorValidator.evScaleNotTrue,
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
},
# Custom
# Not implemented.
@@ -6210,7 +7086,10 @@ class TosaTestGen:
TosaArgGen.agCondIf,
),
"types": [DType.BOOL],
- "error_if_validators": (TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evOutputListThenGraphMismatch,
+ TosaErrorValidator.evOutputListElseGraphMismatch,
+ ),
},
"cond_if_binary": {
"op": Op.COND_IF,
@@ -6221,8 +7100,12 @@ class TosaTestGen:
TosaArgGen.agCondIf,
),
"types": TYPE_INT_FP,
- "error_if_validators": (TosaErrorValidator.evInputListThenGraphMismatch, TosaErrorValidator.evInputListElseGraphMismatch,
- TosaErrorValidator.evOutputListThenGraphMismatch, TosaErrorValidator.evOutputListElseGraphMismatch)
+ "error_if_validators": (
+ TosaErrorValidator.evInputListThenGraphMismatch,
+ TosaErrorValidator.evInputListElseGraphMismatch,
+ TosaErrorValidator.evOutputListThenGraphMismatch,
+ TosaErrorValidator.evOutputListElseGraphMismatch,
+ ),
},
# while_loop
"while_loop": {
@@ -6234,9 +7117,13 @@ class TosaTestGen:
TosaArgGen.agWhileLoop,
),
"types": [DType.INT32],
- "error_if_validators": (TosaErrorValidator.evInputListOutputListMismatch, TosaErrorValidator.evInputListCondGraphMismatch,
- TosaErrorValidator.evInputListBodyGraphInputMismatch, TosaErrorValidator.evInputListBodyGraphOutputMismatch,
- TosaErrorValidator.evCondGraphOutputNotMatchingBool)
+ "error_if_validators": (
+ TosaErrorValidator.evInputListOutputListMismatch,
+ TosaErrorValidator.evInputListCondGraphMismatch,
+ TosaErrorValidator.evInputListBodyGraphInputMismatch,
+ TosaErrorValidator.evInputListBodyGraphOutputMismatch,
+ TosaErrorValidator.evCondGraphOutputNotMatchingBool,
+ ),
},
}
@@ -6257,13 +7144,19 @@ class OutputShaper:
shape = []
for i in range(len(a.shape)):
- if a.shape[i] == 1 and error_name == None:
+ if a.shape[i] == 1 and error_name is None:
shape.append(b.shape[i])
else:
shape.append(a.shape[i])
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6286,7 +7179,13 @@ class OutputShaper:
@staticmethod
def unaryOp(ser, rng, a, error_name=None):
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6302,13 +7201,19 @@ class OutputShaper:
shape = []
for i in range(len(cond.shape)):
- if cond.shape[i] == 1 and error_name == None:
+ if cond.shape[i] == 1 and error_name is None:
shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
else:
shape.append(cond.shape[i])
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6317,7 +7222,7 @@ class OutputShaper:
return ser.addOutput(shape, outputDType)
@staticmethod
- def binaryComparisonOp(ser, rng, a, b , error_name=None):
+ def binaryComparisonOp(ser, rng, a, b, error_name=None):
if error_name != ErrorIf.RankMismatch:
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
@@ -6331,7 +7236,13 @@ class OutputShaper:
shape.append(a.shape[i])
if error_name == ErrorIf.WrongOutputType:
- wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
outputDType = rng.choice(wrong_dtypes)
else:
outputDType = DType.BOOL
@@ -6341,13 +7252,23 @@ class OutputShaper:
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
shape = a.shape.copy()
- if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ShapeOfAxisNotOne]:
+ if error_name not in [
+ ErrorIf.AxisSmallerZero,
+ ErrorIf.AxisLargerRank,
+ ErrorIf.ShapeOfAxisNotOne,
+ ]:
shape[axis] = 1
if error_name == ErrorIf.ShapeOfAxisNotOne and shape[axis] == 1:
shape[axis] = rng.integers(2, 10)
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6373,7 +7294,13 @@ class OutputShaper:
shape[i] = shape[i] + rng.integers(1, 10)
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6490,7 +7417,9 @@ class OutputShaper:
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
- def depthwiseConv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
+ def depthwiseConv2dOp(
+ ser, rng, ifm, filter, strides, padding, dilations, error_name=None
+ ):
# IFM: NHWC
# Filter: HWCM
# OFM: NHW C*M
@@ -6553,7 +7482,13 @@ class OutputShaper:
ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6571,11 +7506,29 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if input.dtype == DType.INT8:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT48,
+ DType.FLOAT,
+ )
elif input.dtype == DType.INT16:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ )
elif input.dtype == DType.FLOAT:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ )
out_dtype = rng.choice(a=incorrect_types)
elif input.dtype == DType.INT8:
out_dtype = DType.INT32
@@ -6601,11 +7554,29 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if a.dtype == DType.INT8:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT48,
+ DType.FLOAT,
+ )
elif a.dtype == DType.INT16:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ )
elif a.dtype == DType.FLOAT:
- incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ )
out_dtype = rng.choice(a=incorrect_types)
elif a.dtype == DType.INT8:
out_dtype = DType.INT32
@@ -6641,7 +7612,13 @@ class OutputShaper:
output_shape[axis] += rng.integers(5, 10)
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
+ all_dtypes = {
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ }
wrong_dtypes = list(all_dtypes - set([input1.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6662,7 +7639,13 @@ class OutputShaper:
output_shape = [i if i >= 1 else 1 for i in output_shape]
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6694,7 +7677,13 @@ class OutputShaper:
output_shape[i] = output_shape[i] + rng.integers(1, 10)
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6706,7 +7695,13 @@ class OutputShaper:
def sliceOp(ser, rng, a, start, size, error_name=None):
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6718,7 +7713,9 @@ class OutputShaper:
if output_shape[index] <= 2:
output_shape[index] = output_shape[index] + rng.choice([1, 2])
else:
- output_shape[index] = output_shape[index] + rng.choice([-2, -1, 1, 2])
+ output_shape[index] = output_shape[index] + rng.choice(
+ [-2, -1, 1, 2]
+ )
else:
output_shape = size.copy()
@@ -6734,7 +7731,13 @@ class OutputShaper:
output_shape[i] = a.shape[i] * multiples[i]
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6756,7 +7759,13 @@ class OutputShaper:
output_shape[i] = a.shape[perms[i]]
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6774,7 +7783,13 @@ class OutputShaper:
output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6795,7 +7810,13 @@ class OutputShaper:
output_shape = values_in.shape
if error_name == ErrorIf.WrongOutputType:
- all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ all_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
@@ -6810,7 +7831,13 @@ class OutputShaper:
assert input.dtype == DType.INT16 or input.dtype == DType.INT8
output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
if error_name == ErrorIf.WrongOutputType:
- wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ ]
wrong_dtypes.remove(output_dtype)
output_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(input.shape, output_dtype)
@@ -6829,17 +7856,37 @@ class OutputShaper:
output_dims,
input_dtype,
output_dtype,
- error_name = None
+ error_name=None,
):
if error_name == ErrorIf.WrongRank:
- output_dims = [input.shape[0], output_dims[0], output_dims[0], input.shape[0]]
+ output_dims = [
+ input.shape[0],
+ output_dims[0],
+ output_dims[0],
+ input.shape[0],
+ ]
else:
if error_name == ErrorIf.BatchMismatch:
- output_dims = [input.shape[0] + rng.integers(1, 10), output_dims[0], output_dims[1], input.shape[3]]
+ output_dims = [
+ input.shape[0] + rng.integers(1, 10),
+ output_dims[0],
+ output_dims[1],
+ input.shape[3],
+ ]
elif error_name == ErrorIf.ChannelMismatch:
- output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3] + rng.integers(1, 10)]
+ output_dims = [
+ input.shape[0],
+ output_dims[0],
+ output_dims[1],
+ input.shape[3] + rng.integers(1, 10),
+ ]
else:
- output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
+ output_dims = [
+ input.shape[0],
+ output_dims[0],
+ output_dims[1],
+ input.shape[3],
+ ]
return serializer.addOutput(output_dims, output_dtype)
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 09ee238..50f4033 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -1,38 +1,12 @@
-# Copyright (c) 2020-2021, ARM Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
import argparse
-import sys
import re
-import os
-import subprocess
-import shlex
-import json
-import glob
-import math
-import queue
-import threading
-import traceback
-
-
-from enum import IntEnum, Enum, unique
-from datetime import datetime
from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
+
# Used for parsing a comma-separated list of integers in a string
# to an actual list of integers
def str_to_list(in_s):
@@ -189,7 +163,7 @@ def parseArgs():
parser.add_argument(
"--test-type",
dest="test_type",
- choices=['positive', 'negative', 'both'],
+ choices=["positive", "negative", "both"],
default="positive",
type=str,
help="type of tests produced, postive, negative, or both",
@@ -205,8 +179,8 @@ def main():
ttg = TosaTestGen(args)
- if args.test_type == 'both':
- testType = ['positive', 'negative']
+ if args.test_type == "both":
+ testType = ["positive", "negative"]
else:
testType = [args.test_type]
results = []
@@ -220,7 +194,7 @@ def main():
shapeFilter=args.target_shapes,
rankFilter=args.target_ranks,
dtypeFilter=args.target_dtypes,
- testType=test_type
+ testType=test_type,
)
)
@@ -236,11 +210,12 @@ def main():
if args.verbose:
print(testStr)
- results.append(ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs))
+ results.append(
+ ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)
+ )
print(f"Done creating {len(results)} tests")
-
if __name__ == "__main__":
exit(main())
diff --git a/verif/tests/test_json2numpy.py b/verif/tests/test_json2numpy.py
index aec555c..63bc2d9 100644
--- a/verif/tests/test_json2numpy.py
+++ b/verif/tests/test_json2numpy.py
@@ -6,7 +6,6 @@ import os
import numpy as np
import pytest
-
from json2numpy.json2numpy import main
diff --git a/verif/tests/test_tosa_result_checker.py b/verif/tests/test_tosa_result_checker.py
index bc8a2fc..efee23b 100644
--- a/verif/tests/test_tosa_result_checker.py
+++ b/verif/tests/test_tosa_result_checker.py
@@ -3,11 +3,10 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
+import checker.tosa_result_checker as trc
import numpy as np
import pytest
-import checker.tosa_result_checker as trc
-
def _create_data_file(name, npy_data):
"""Create numpy data file."""
diff --git a/verif/tests/test_tosa_run_tests_mocksut.py b/verif/tests/test_tosa_run_tests_mocksut.py
index 98044e0..234f156 100644
--- a/verif/tests/test_tosa_run_tests_mocksut.py
+++ b/verif/tests/test_tosa_run_tests_mocksut.py
@@ -7,7 +7,6 @@ from pathlib import Path
from xml.dom import minidom
import pytest
-
from runner.tosa_verif_run_tests import main