aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-04-07 11:29:20 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2022-04-11 10:25:51 +0100
commit9a66abbd1da6547fd2cba1512d2f07fd1525de4d (patch)
treee06d3d728708f0ed4a57c369d4d1a48abdb5f607
parent7bebea8c086dc406d774e5a4419914748912089e (diff)
downloadreference_model-9a66abbd1da6547fd2cba1512d2f07fd1525de4d.tar.gz
Refactor verif/generator/tosa_test_gen.py into different files
Move all error & validation into tosa_error_if.py Move all argument and tensor generation into tosa_arg_gen.py Move utility functions into tosa_utils.py Create new TosaTensorValuesGen class for specialising tensor value generation. Change-Id: Ib9ac65e2308b14471a567c6f11d775c76585bc5b Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
-rw-r--r--verif/generator/tosa_arg_gen.py1809
-rw-r--r--verif/generator/tosa_error_if.py2082
-rw-r--r--verif/generator/tosa_test_gen.py4314
-rw-r--r--verif/generator/tosa_utils.py81
4 files changed, 4351 insertions, 3935 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
new file mode 100644
index 0000000..e3492cd
--- /dev/null
+++ b/verif/generator/tosa_arg_gen.py
@@ -0,0 +1,1809 @@
+# Copyright (c) 2021-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import itertools
+import math
+
+import numpy as np
+import serializer.tosa_serializer as ts
+from generator.tosa_error_if import ErrorIf
+from generator.tosa_error_if import TosaErrorIfArgGen
+from serializer.tosa_serializer import DTypeNames
+from tosa.DType import DType
+from tosa.Op import Op
+from tosa.ResizeMode import ResizeMode
+
+# DTypeNames, DType, Op and ResizeMode are convenience variables to the
+# flatc-generated types that should be enums, but aren't
+
+
+class TosaQuantGen:
+ """QuantizedInfo random generator helper functions.
+
+ Specify with 'qgen': in the operator defintion.
+ """
+
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def getQinfo(testGen, dtype, error_name=None):
+
+ if dtype == DType.INT8:
+ return testGen.randInt(-128, 128)
+ elif dtype == DType.UINT8:
+ return testGen.randInt(0, 256)
+ elif error_name in [
+ ErrorIf.InputZeroPointNotZero,
+ ErrorIf.WeightZeroPointNotZero,
+ ErrorIf.OutputZeroPointNotZero,
+ ]:
+ zero_point = testGen.randInt(-128, 128)
+ if zero_point == 0:
+ zero_point = 1
+ return zero_point
+ return 0
+
+ @staticmethod
+ def qgUnary(testGen, op, dtype, error_name=None):
+ qinfo = ts.TosaSerializerQuantInfo()
+ if error_name == ErrorIf.InputZeroPointNotZero:
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype, error_name),
+ TosaQuantGen.getQinfo(testGen, dtype),
+ )
+ elif error_name == ErrorIf.OutputZeroPointNotZero:
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype, error_name),
+ )
+ else:
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype),
+ )
+ return qinfo
+
+ @staticmethod
+ def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
+ qinfo = ts.TosaSerializerQuantInfo()
+ if isinstance(dtype_or_dtypeList, list):
+ # a list of [input, weights, accumulator] dtypes
+ dtypeList = dtype_or_dtypeList
+ else:
+ # an int, [input, weights, accumulator] dtypes are the same
+ dtypeList = [dtype_or_dtypeList] * 3
+
+ if error_name == ErrorIf.InputZeroPointNotZero:
+ input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
+ weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+ elif error_name == ErrorIf.WeightZeroPointNotZero:
+ input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
+ weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
+ else:
+ input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
+ weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+
+ qinfo.ConvQuantInfo(input_zp, weights_zp)
+ return qinfo
+
+ @staticmethod
+ def qgMatmul(testGen, op, dtype, error_name=None):
+ qinfo = ts.TosaSerializerQuantInfo()
+ if error_name == ErrorIf.InputZeroPointNotZero:
+ qinfo.MatMulQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype, error_name),
+ TosaQuantGen.getQinfo(testGen, dtype, error_name),
+ )
+ else:
+ qinfo.MatMulQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype),
+ )
+ return qinfo
+
+ @staticmethod
+ def qgPad(testGen, op, dtype, error_name=None):
+ qinfo = ts.TosaSerializerQuantInfo()
+ if error_name == ErrorIf.InputZeroPointNotZero:
+ qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
+ else:
+ qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
+ return qinfo
+
+ @staticmethod
+ def computeMultiplierAndShift(scaleFp, scale32):
+ # Derived from computeMultiplierAndShiftTosaScale32
+ # Provide a floating-point scaling factor and the scale32 parameter
+ # to compute the multiplier and shift
+
+ if scale32:
+ scaleBits = 31
+ else:
+ scaleBits = 15
+
+ m, shift = math.frexp(scaleFp)
+
+ if scaleFp < 0.0:
+ m = -m
+
+ multiplier = round(m * (1 << scaleBits))
+ assert multiplier <= (1 << scaleBits)
+
+ if multiplier == (1 << scaleBits):
+ multiplier = multiplier // 2
+ shift = shift + 1
+
+ shift = (-shift) + scaleBits
+ # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
+ # scaleFp, scaleBits, m, multiplier, shift))
+
+ # Adjust multiplier such that shift is in allowed value range.
+ if shift == 0:
+ multiplier = multiplier // 4
+ shift = shift + 2
+ elif shift == 1:
+ multiplier = multiplier // 2
+ shift = shift + 1
+ elif shift == 63:
+ multiplier = multiplier * 2
+ shift = shift - 1
+
+ assert multiplier <= (1 << scaleBits)
+ assert shift >= 2 and shift <= 62
+
+ return multiplier, shift
+
+
+class TosaTensorGen:
+ """Tensor generators create a shape list for the placeholder and const tensor
+ data operands for the operator.
+
+ The actual random data is generated separately for each test.
+ """
+
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def tgBasic(testGen, opName, rank, error_name=None):
+ pl, const = opName["operands"]
+ shape = testGen.makeShape(rank)
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
+
+ shape_list = []
+ for i in range(pl + const):
+ shape_list.append(shape.copy())
+
+ if error_name == ErrorIf.RankMismatch:
+ if rank == 1 and i != 1:
+ shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
+ elif i != 1:
+ shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
+
+ return shape_list
+
+ @staticmethod
+ def tgNHWC(testGen, opName, rank, error_name=None):
+ pl, const = opName["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 4
+
+ shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name and error_name != ErrorIf.MaxDimExceeded:
+ shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
+
+ shape_list = []
+ for i in range(pl + const):
+ shape_list.append(shape.copy())
+
+ return shape_list
+
+ @staticmethod
+ def tgScatter(testGen, opName, rank, error_name=None):
+ pl, const = opName["operands"]
+
+ assert pl == 2
+ assert const == 0
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 3
+
+ values_in_shape = testGen.makeShape(rank)
+
+ # ignore max batch size if target shape is set
+ if testGen.args.max_batch_size and not testGen.args.target_shapes:
+ values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
+
+ W = testGen.randInt(
+ testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
+ )
+ # Constrict W if one dimension is too large to keep tensor size reasonable
+ if max(values_in_shape) > 5000:
+ W = testGen.randInt(0, 16)
+
+ input_shape = [values_in_shape[0], W, values_in_shape[2]]
+
+ shape_list = []
+ shape_list.append(values_in_shape.copy())
+ shape_list.append(input_shape.copy())
+
+ return shape_list
+
+ @staticmethod
+ def tgBroadcastFuzz(testGen, op, rank, error_name=None):
+ shape = testGen.makeShape(rank)
+
+ pl, const = op["operands"]
+
+ shape_list = []
+
+ # Choose one of the inputs to broadcast
+ # Note: Simplifies OutputShaper code if we don't change first shape for errors
+ bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
+ for i in range(pl + const):
+ shape_bcast = shape.copy()
+
+ # If the chosen input, pick a random index to broadcast
+ if i == bcast_idx:
+ fuzz_idx = testGen.randInt(0, rank)
+ if error_name == ErrorIf.DimensionMismatch:
+ shape_bcast[fuzz_idx] += 1
+ elif error_name == ErrorIf.RankMismatch:
+ # Add one rank to the shape (or more for rank of 1)
+ extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
+ shape_bcast = np.concatenate(
+ (shape_bcast, testGen.makeShape(extra_ranks))
+ )
+ if rank != 1:
+ # Either keep the extra rank, or remove it
+ new_len = testGen.rng.choice([-2, len(shape_bcast)])
+ shape_bcast = shape_bcast[:new_len]
+ else:
+ shape_bcast[fuzz_idx] = 1
+
+ shape_list.append(shape_bcast)
+
+ return shape_list
+
+ @staticmethod
+ def tgConv2D(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 4
+
+ # IFM dimensions are NHWC
+ ifm_shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
+ ifm_shape, max_dim=24, max_items=10000
+ )
+
+ # Get the filter height/width from the operator parameters
+ filter_hw = op["filter"]
+
+ # Generate a random OFM depth
+ ofm_depth = testGen.makeShape(1)[0]
+
+ # The filter dimensions are OHWI
+ filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
+
+ # The bias is OC
+ bias_shape = np.asarray([ofm_depth])
+
+ return [ifm_shape, filter_shape, bias_shape]
+
+ @staticmethod
+ def tgConv3D(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 5
+
+ # IFM dimensions are NDHWC
+ ifm_shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
+ ifm_shape, max_dim=24, max_items=10000
+ )
+
+ # Get the filter depth/height/width from the operator parameters
+ filter_dhw = op["filter"]
+
+ # Generate a random OFM channel
+ ofm_channel = testGen.makeShape(1)[0]
+
+ # The filter dimensions are ODHWI
+ filter_shape = np.asarray(
+ [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
+ )
+
+ # The bias is OC
+ bias_shape = np.asarray([ofm_channel])
+
+ return [ifm_shape, filter_shape, bias_shape]
+
+ @staticmethod
+ def tgTransposeConv2D(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 4
+
+ # IFM dimensions are NHWC
+ ifm_shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
+ ifm_shape, max_dim=24, max_items=10000
+ )
+
+ # Get the filter height/width from the operator parameters
+ filter_hw = op["filter"]
+
+ # Generate a random OFM depth
+ ofm_depth = testGen.makeShape(1)[0]
+
+ # The filter dimensions are OHWI
+ filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
+
+ # The bias is OC
+ bias_shape = np.asarray([ofm_depth])
+
+ return [ifm_shape, filter_shape, bias_shape]
+
+ @staticmethod
+ def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 4
+ assert pl == 1 and const == 2
+
+ # IFM dimensions are NHWC
+ ifm_shape = testGen.makeShape(rank)
+
+ # Constrict the batch size?
+ if testGen.args.max_batch_size:
+ ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
+ ifm_shape, max_dim=24, max_items=10000
+ )
+
+ # Get the filter height/width from the operator parameters
+ # Filter is KH, HW, C, M
+ filter_hw = op["filter"]
+
+ # Generate a random OFM depth, but don't let it get too big because
+ # the output depth is M * C
+ filter_m = (
+ testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
+ ) + 1
+
+ # The filter dimensions are HWCM
+ filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
+
+ # The bias is M * C
+ bias_shape = np.asarray([ifm_shape[3] * filter_m])
+
+ return [ifm_shape, filter_shape, bias_shape]
+
+ @staticmethod
+ def tgFullyConnected(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 2
+
+ input_shape = testGen.makeShape(rank)
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
+
+ filter_oc = testGen.rng.integers(
+ low=testGen.args.tensor_shape_range[0],
+ high=testGen.args.tensor_shape_range[1],
+ size=1,
+ )[0]
+ filter_shape = np.asarray([filter_oc, input_shape[1]])
+
+ bias_shape = np.asarray([filter_oc])
+
+ return [input_shape, filter_shape, bias_shape]
+
+ @staticmethod
+ def tgMatmul(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 3
+ assert pl == 2 and const == 0
+
+ a_shape = testGen.makeShape(rank)
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
+
+ # Get a random number for b_oc even if target shape is defined
+ b_oc = np.int32(
+ testGen.rng.integers(
+ low=testGen.args.tensor_shape_range[0],
+ high=testGen.args.tensor_shape_range[1],
+ size=1,
+ )
+ )[0]
+ # If N or H is large let b_oc be 1 to reduce output tensor size
+ if max(a_shape) > 1000:
+ b_oc = 1
+
+ b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
+ return [a_shape, b_shape]
+
+ @staticmethod
+ def tgConcat(testGen, opName, rank, error_name=None):
+ pl, const = opName["operands"]
+ shape = testGen.makeShape(rank)
+
+ # Create extra tensors to concat.
+ # Take into account value of pl when getting maximum number of concats
+ num_tensors = testGen.randInt(0, 4)
+ shape_list = []
+ for i in range(pl + const + num_tensors):
+ if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
+ remove = testGen.rng.choice([True, False])
+ wrongShape = shape.copy()
+
+ if remove and len(shape) > 1:
+ wrongShape = wrongShape[1:]
+ else:
+ wrongShape = list(wrongShape)
+ wrongShape.append(testGen.rng.integers(1, 10))
+
+ shape_list.append(wrongShape)
+ else:
+ shape_list.append(shape.copy())
+
+ return shape_list
+
+ @staticmethod
+ def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
+ if error_name in [
+ ErrorIf.AxisSmallerZero,
+ ErrorIf.AxisLargerRank,
+ ErrorIf.ConcatInputRankMismatch,
+ ]:
+ return shapeList
+
+ # Split concat shape along axis to allow for multiple const inputs
+ # without making too many large tensors
+ if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
+ # If axis can't be split we still need to invalidate other dimensions
+ if error_name == ErrorIf.ConcatInputDimMismatch:
+ for shape in shapeList[1:]:
+ # Negative test shapeLists are created individually for each test,
+ # so no need to copy the shape before altering it.
+ shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
+ return shapeList
+
+ # Create copy of shape we are going to split (so we don't alter shapeList)
+ shape = shapeList[0].copy()
+ # Add original shape as first input
+ new_shapeList = [shape.copy()]
+ length_on_axis = shape[axis]
+ remaining_length = length_on_axis
+ for i in range(len(shapeList) - 2):
+ # Calculate split on axis and remaining value
+ split_shape_val = int(shape[axis] / 2)
+ remaining_length = remaining_length - split_shape_val
+
+ # Append new shape, and set remaining shape
+ shape[axis] = split_shape_val
+ new_shapeList.append(shape.copy())
+
+ # invalidate dimensions
+ if error_name == ErrorIf.ConcatInputDimMismatch:
+ shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
+ else:
+ shape[axis] = remaining_length
+
+ if i == len(shapeList) - 3:
+ new_shapeList.append(shape.copy())
+
+ return new_shapeList
+
+
+class TosaTensorValuesGen:
+ """Tensor Value generators create the random data for each test."""
+
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ pCount, cCount = op["operands"]
+
+ tens = []
+ tens.extend(
+ testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
+ )
+ tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
+
+ return tens
+
+ @staticmethod
+ def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ if dtypeList[0] != DType.FLOAT and error_name is None:
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 1 and cCount == 0
+ ), "Op.NEGATE must have 1 placeholders, 0 consts"
+ # Must create tensors with values within negatable ranges
+ if dtypeList[0] == DType.INT8:
+ # Must be within int8, adjustable by input_zp and then negatable
+ # and be within int8
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ max_val = min(127, 127 + qinfo.ints[0][1])
+ min_val = max(-127, -127 + qinfo.ints[0][1])
+ elif dtypeList[0] == DType.INT16:
+ max_val = 32767
+ min_val = -max_val
+ else:
+ assert (
+ dtypeList[0] == DType.INT32
+ ), "Op.NEGATE found with unsupported input type"
+ max_val = (1 << 31) - 1
+ min_val = -max_val
+ arr = np.int32(
+ testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
+ )
+ placeholders = []
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
+ )
+ return placeholders
+ else:
+ return TosaTensorValuesGen.tvgDefault(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
+ @staticmethod
+ def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ if dtypeList[0] == DType.INT32 and error_name is None:
+ # Make sure the operation does not cause value saturation - where
+ # the number wraps due to limited number of bits to store the answer
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
+ placeholders = []
+ add = op["op"] == Op.ADD
+ a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
+ b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
+ if add:
+ res_arr = np.add(a_arr, b_arr, dtype=np.int64)
+ else:
+ res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
+
+ # Work out the saturation limits
+ max_i32 = (1 << 31) - 1
+ min_i32 = -(1 << 31)
+ max_arr = np.full(shapeList[1], max_i32)
+ min_arr = np.full(shapeList[1], min_i32)
+
+ # Find how much values exceed the maximum/minimums
+ sat_max_arr = np.maximum(res_arr - max_arr, 0)
+ sat_min_arr = np.minimum(res_arr - min_arr, 0)
+
+ if not add:
+ # Swap saturation values and negate values as we need to perform opposite operations
+ sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
+
+ # Create new array of unsaturated values by clipping values as needed
+ b_unsat_arr = b_arr
+ if (sat_max_arr != 0).any():
+ # Clip values that cause saturation
+ b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
+ # Reduce axes in unsaturated tensor to match original tensor
+ for axis, dim in enumerate(b_arr.shape):
+ if dim != b_unsat_arr.shape[axis]:
+ assert (
+ dim == 1
+ ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
+ b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
+
+ if (sat_min_arr != 0).any():
+ # Clip values that cause saturation
+ b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
+ # Reduce axes in unsaturated tensor to match original tensor
+ for axis, dim in enumerate(b_arr.shape):
+ if dim != b_unsat_arr.shape[axis]:
+ assert (
+ dim == 1
+ ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
+ b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
+
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ )
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
+ )
+
+ return placeholders
+ else:
+ return TosaTensorValuesGen.tvgDefault(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
+ @staticmethod
+ def tvgCondIfWhileLoop(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ ):
+ if dtypeList[0] in (
+ DType.INT32,
+ DType.INT16,
+ DType.INT8,
+ ):
+ # Limit input tensors with cond_if_binary or while_loop to stop
+ # saturation of add/sub ops with int32 and keep all logical shift
+ # values between 0 to 31 for int16 or int8
+ pCount, cCount = op["operands"]
+ pRemain = pCount
+ placeholders = []
+ for idx, shape in enumerate(shapeList[:]):
+ if dtypeList[0] == DType.INT32:
+ arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
+ else:
+ arr = np.int32(
+ testGen.rng.integers(low=0, high=32, size=shapeList[idx])
+ )
+ if pRemain > 0:
+ placeholders.append(
+ testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
+ )
+ pRemain -= 1
+ else:
+ placeholders.append(
+ testGen.ser.addConst(shape, dtypeList[idx], arr)
+ )
+
+ return placeholders
+ else:
+ return TosaTensorValuesGen.tvgDefault(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
+ @staticmethod
+ def tvgArithmeticRightShift(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ ):
+ pCount, cCount = op["operands"]
+ # Force value of operand[1] to be within [0, num_bits]
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
+
+ placeholders = []
+ for idx, shape in enumerate(shapeList[:]):
+ if idx == 1:
+ if dtypeList[idx] == DType.INT8:
+ arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
+ elif dtypeList[idx] == DType.INT16:
+ arr = np.int32(testGen.rng.integers(low=0, high=16, size=shape))
+ elif dtypeList[idx] == DType.INT32:
+ arr = np.int32(testGen.rng.integers(low=0, high=32, size=shape))
+ elif error_name == ErrorIf.WrongInputType:
+ arr = np.int32(testGen.rng.integers(low=0, high=8, size=shape))
+ else:
+ raise Exception("OpArithmeticRightShift: invalid input dtype")
+ else:
+ arr = testGen.getRandTensor(shape, dtypeList[idx])
+ placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
+
+ return placeholders
+
+ @staticmethod
+ def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ # Set datatype of condition tensor to boolean
+ dtypeList[0] = DType.BOOL
+
+ return TosaTensorValuesGen.tvgDefault(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
+ @staticmethod
+ def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ if error_name is None:
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.INTDIV must have 2 placeholders, 0 consts"
+
+ placeholders = []
+
+ # Two invalid cases for Op.INTDIV:
+ # 1. divisor == 0
+ # 2. dividend == -(1<<31) and divisor == -1
+ while True:
+ dividend_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
+ divisor_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
+
+ if (divisor_arr == 0).any():
+ continue
+
+ if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
+ continue
+
+ break
+
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
+ )
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
+ )
+
+ return placeholders
+ else:
+ return TosaTensorValuesGen.tvgDefault(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
+ @staticmethod
+ def tvgMul(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ if error_name is None:
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.MUL must have 2 placeholders, 0 consts"
+
+ tens = []
+ if dtypeList[0] == DType.FLOAT:
+ tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
+ else:
+ placeholders = []
+
+ # Make sure multiply result in int32 range
+ shift = testArgs[0]
+ if dtypeList[0] == DType.INT8:
+ num_bits = 8
+ elif dtypeList[0] == DType.INT16:
+ num_bits = 16
+ elif dtypeList[0] == DType.INT32:
+ num_bits = 32
+ elif error_name == ErrorIf.WrongInputType:
+ num_bits = 8
+ else:
+ raise Exception("OpMul: invalid input dtype")
+
+ for idx, shape in enumerate(shapeList[:]):
+ low = -(2 ** (num_bits - 1))
+ high = (2 ** (num_bits - 1)) - 1
+
+ a_arr = np.int32(
+ testGen.rng.integers(low=low, high=high, size=shapeList[0])
+ )
+ b_arr = np.int32(
+ testGen.rng.integers(low=low, high=high, size=shapeList[1])
+ )
+
+ i = 0
+ while True:
+
+ a_arr_64 = a_arr.astype(np.int64)
+ b_arr_64 = b_arr.astype(np.int64)
+
+ if shift > 0:
+ rounding = 1 << (shift - 1)
+ result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
+ else:
+ result_arr = a_arr_64 * b_arr_64
+
+ if (result_arr > -(2**31)).all() and (
+ result_arr <= ((2**31) - 1)
+ ).all():
+ break
+
+ i = i + 1
+ a_arr = a_arr // 2
+ b_arr = b_arr // 2
+
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ )
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+ )
+
+ tens.extend(placeholders)
+
+ return tens
+ else:
+ return TosaTensorValuesGen.tvgDefault(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
+ @staticmethod
+ def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ count = len(shapeList) - testGen.args.num_const_inputs_concat
+ if count < 1:
+ count = 1
+ if testGen.args.num_const_inputs_concat == 0:
+ count = len(shapeList)
+
+ # Ensure axis is an int
+ testArgs[0] = int(testArgs[0])
+
+ shapeList = TosaTensorGen.tgConcatConstInput(
+ testGen, shapeList, testArgs[0], error_name
+ )
+
+ tens = []
+ tens.extend(
+ testGen.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
+ )
+ tens.extend(testGen.buildConstTensors(shapeList[count:], dtypeList[count:]))
+
+ return tens
+
+ @staticmethod
+ def tvgLogicalShift(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ ):
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
+ values_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
+ shift_arr = np.int32(testGen.rng.integers(low=0, high=32, size=shapeList[1]))
+ placeholders = []
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
+ )
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
+ )
+
+ return placeholders
+
+ @staticmethod
+ def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ if error_name is None:
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.EQUAL must have 2 placeholders, 0 consts"
+ a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
+ b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
+ # Using random numbers means that it will be very unlikely that
+ # there are any matching (equal) values, therefore force that
+ # there are twice the number of matching values as the tensor rank
+ for num in range(0, len(shapeList[0]) * 2):
+ a_index = []
+ b_index = []
+ # Choose an index in each axis for the whole shape
+ for axis in range(0, len(shapeList[0])):
+ # Index can be up to the largest dimension in both shapes
+ index = np.int32(
+ testGen.rng.integers(
+ 0, max(shapeList[0][axis], shapeList[1][axis])
+ )
+ )
+ # Reduce the index down to a shape's dim for broadcasting
+ a_index.append(min(shapeList[0][axis] - 1, index))
+ b_index.append(min(shapeList[1][axis] - 1, index))
+
+ a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
+
+ placeholders = []
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ )
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+ )
+ return placeholders
+ else:
+ return TosaTensorValuesGen.tvgDefault(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
+ @staticmethod
+ def tvgReduceSum(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ ):
+ if dtypeList[0] == DType.INT32:
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 1 and cCount == 0
+ ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
+ # Limit values so that the sum cannot exceed the range of an int32 during
+ # summation of any axis
+ range_val = int((1 << 31) / max(shapeList[0]))
+ values_arr = np.int32(
+ testGen.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
+ )
+ placeholders = []
+ placeholders.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
+ )
+ return placeholders
+ else:
+ return TosaTensorValuesGen.tvgDefault(
+ testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ )
+
+
+class TosaArgGen:
+ """Argument generators create exhaustive or random lists of attributes for
+ operators that take attributes or other parameters.
+
+ The return value is a list of (descriptive_name, [arglist]) tuples where
+ the descriptive_name is appended to the test name and the arglist is expanded
+ as arguments to the operator build function.
+ """
+
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def agNone(testGen, opName, shapeList, dtype, error_name=None):
+ """A trivial argument generator for operators that don't take any
+ non-tensor arguments"""
+ return [("", [])]
+
+ @staticmethod
+ def agAxis(testGen, opName, shapeList, dtype, error_name=None):
+ """Build the axis argument for operators that take a single axis"""
+ axes = []
+ shape = shapeList[0]
+
+ if error_name == ErrorIf.AxisSmallerZero:
+ small_axis = testGen.rng.integers(-5, 0)
+ axes.append(("axis{}".format(small_axis), [small_axis]))
+ elif error_name == ErrorIf.AxisLargerRank:
+ large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
+ axes.append(("axis{}".format(large_axis), [large_axis]))
+ else:
+ for a in range(0, len(shape)):
+ axes.append(("axis{}".format(a), [a]))
+
+ return axes
+
+ @staticmethod
+ def agConv(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ filter_shape = shapeList[1]
+ # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
+ k = [int(x) for x in opName.split("_")[-1].split("x")]
+
+ # Check the rank
+ rank = 5 if opName.startswith("conv3d") else 4
+ if error_name != ErrorIf.WrongRank:
+ assert len(ifm_shape) == rank
+ assert len(filter_shape) == rank
+
+ # kernel rank omits batch and channels
+ k_rank = rank - 2
+ assert len(k) == k_rank
+
+ # Generate comprehensive argument lists
+ # - except for named errors, which use specific invalid value(s)
+ if error_name == ErrorIf.PadSmallerZero:
+ p_vals = [testGen.rng.choice(range(-5, 0))]
+ else:
+ p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
+ paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
+ if error_name == ErrorIf.StrideSmallerOne:
+ # Can't use stride=0, as it is used to derive output shape, as a divisor
+ s_vals = [testGen.rng.choice(range(-5, 0))]
+ else:
+ s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
+ strides = {x for x in itertools.product(*([s_vals] * k_rank))}
+ if error_name == ErrorIf.DilationSmallerOne:
+ d_vals = [testGen.rng.choice(range(-5, 1))]
+ else:
+ d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
+ dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
+
+ if not error_name and testGen.args.oversize:
+ # add some oversize argument values
+ if max(ifm_shape) < 64:
+ bigPadding = 9
+ paddings.update(
+ {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
+ )
+ bigStride = 8
+ strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
+ bigDilation = 7
+ dilations.update(
+ {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
+ )
+
+ # There are too many parameter combinations, so generate them sparsely,
+ # very sparse for negative tests
+ sparsity_factor = 2 if error_name else 100
+ sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
+ # If there are only a small number of tests, just select them all
+ if sparsity < 13:
+ sparsity = 1
+ # To get a variety of parameter combinations sparsity should not be a
+ # multiple of 2, 3 or 5
+ while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
+ sparsity += 1
+
+ n = 0
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ for d in sorted(list(dilations)):
+ if (
+ n % sparsity == 0
+ # padding must not exceed the kernel size ?
+ # and p[0] < k[0] and p[1] < k[0]
+ # and p[2] < k[1] and p[3] < k[1]
+ # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
+ # the padded shape must exceed the kernel size
+ and (ifm_shape[1] + p[0] + p[1]) > k[0]
+ and (ifm_shape[2] + p[2] + p[3]) > k[1]
+ and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
+ # the padded shape must exceed the dilation
+ and (ifm_shape[1] + p[0] + p[1]) > d[0]
+ and (ifm_shape[2] + p[2] + p[3]) > d[1]
+ and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
+ ):
+ arg_list.append(
+ (
+ "st{}_pad{}_dilat{}".format(
+ "".join([str(x) for x in s]),
+ "".join([str(x) for x in p]),
+ "".join([str(x) for x in d]),
+ ),
+ [s, p, d],
+ )
+ )
+ n += 1
+
+ return arg_list
+
+ @staticmethod
+ def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ filter_shape = shapeList[1]
+
+ # Must be rank 4
+ if error_name != ErrorIf.WrongRank:
+ assert len(ifm_shape) == 4
+ assert len(filter_shape) == 4
+
+ # Generate comprehensive argument lists
+ # - except for named errors, which use specific invalid value(s)
+ if error_name == ErrorIf.PadSmallerZero:
+ p_vals = [testGen.rng.choice(range(-5, 0))]
+ else:
+ p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
+ paddings = {x for x in itertools.product(*([p_vals] * 2))}
+ if error_name == ErrorIf.StrideSmallerOne:
+ # Can't use stride=0, as it is used to derive output shape, as a divisor
+ s_vals = [testGen.rng.choice(range(-5, 0))]
+ else:
+ s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
+ strides = {x for x in itertools.product(*([s_vals] * 2))}
+ if error_name == ErrorIf.DilationSmallerOne:
+ d_vals = [testGen.rng.choice(range(-5, 1))]
+ else:
+ d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
+ dilations = {x for x in itertools.product(*([d_vals] * 2))}
+
+ if not error_name:
+ # add some oversize argument values
+ if max(ifm_shape) < 64:
+ bigPadding = 9
+ paddings.update(
+ {x for x in itertools.product(*([[0, bigPadding]] * 2))}
+ )
+ bigStride = 8
+ strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
+ bigDilation = 7
+ dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
+
+ # There are too many parameter combinations, so generate them sparsely,
+ # very sparse for negative tests
+ sparsity_factor = 2 if error_name else 100
+ sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
+ # If there are only a small number of tests, just select them all
+ if sparsity < 13:
+ sparsity = 1
+ # To get a variety of parameter combinations sparsity should not be a
+ # multiple of 2, 3 or 5
+ while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
+ sparsity += 1
+
+ n = 0
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ for d in sorted(list(dilations)):
+ if n % sparsity == 0:
+ # Determine the output shape
+ oh = (
+ ifm_shape[1]
+ - filter_shape[1]
+ - (filter_shape[1] - 1) * (d[0] - 1)
+ + 2 * p[0]
+ ) // s[0] + 1
+ ow = (
+ ifm_shape[2]
+ - filter_shape[2]
+ - (filter_shape[2] - 1) * (d[1] - 1)
+ + 2 * p[1]
+ ) // s[1] + 1
+ os = [ifm_shape[0], oh, ow, filter_shape[0]]
+ arg_list.append(
+ (
+ "st{}_pad{}_dilat{}_os{}".format(
+ "".join([str(x) for x in s]),
+ "".join([str(x) for x in p]),
+ "".join([str(x) for x in d]),
+ "x".join([str(x) for x in os]),
+ ),
+ [s, p, d, os],
+ )
+ )
+ n += 1
+
+ return arg_list
+
+ @staticmethod
+ def agPad(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+ rank = len(shapeList[0])
+
+ # Exhaustively test combinations of padding on each side of each dimension
+ # - the range of padding values is defined by pad_min and pad_max
+ # - for padding >9, the name format needs to be more distinctive
+ pad_min, pad_max = 0, 1
+ pad_values = [x for x in range(pad_min, pad_max + 1)]
+ if error_name == ErrorIf.PadSmallerZero:
+ pad_values = [x for x in range(-2, 0)]
+ axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
+ shape_pad_values = itertools.product(*([axis_pad_values] * rank))
+
+ if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
+ pad_const_int = testGen.getRandNumberDType(dtype)
+ pad_const_fp = 0
+ elif dtype == DType.FLOAT:
+ pad_const_int = 0
+ pad_const_fp = testGen.getRandNumberDType(dtype)
+ else:
+ return []
+
+ for paddings in shape_pad_values:
+ name = "pad"
+ for r in range(rank):
+ before, after = paddings[r]
+ name = f"{name}{before}{after}"
+ arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
+
+ return arg_list
+
+ @staticmethod
+ def agPooling(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ shape = shapeList[0]
+ if error_name != ErrorIf.WrongRank:
+ assert len(shape) == 4
+
+ # Generate comprehensive argument lists
+ p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
+ paddings = {x for x in itertools.product(*([p_vals] * 4))}
+ s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
+ strides = {x for x in itertools.product(*([s_vals] * 2))}
+ k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
+ kernels = {x for x in itertools.product(*([k_vals] * 2))}
+
+ if testGen.args.oversize:
+ # add some oversize argument values
+ bigStride = 7
+ strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
+ bigKernel = 6
+ kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
+ if max(shape) < 64:
+ # padding must be less than the kernel size
+ bigPadding = bigKernel - 1
+ paddings.update(
+ {x for x in itertools.product(*([[0, bigPadding]] * 4))}
+ )
+
+ # There are too many parameter combinations, so generate them sparsely,
+ # very sparse for negative tests
+ sparsity_factor = 2 if error_name else 500
+ sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
+
+ n = 0
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ for k in sorted(list(kernels)):
+ if error_name in [
+ ErrorIf.StrideSmallerOne,
+ ErrorIf.KernelSmallerOne,
+ ErrorIf.PadSmallerZero,
+ ErrorIf.PadLargerEqualKernel,
+ ]:
+ sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
+ testGen, error_name, s, p, k
+ )
+ if None not in [sNew, pNew, kNew] and n % sparsity == 0:
+ arg_list.append(
+ (
+ "st{}_kern{}_pad{}".format(
+ "".join([str(x) for x in sNew]),
+ "".join([str(x) for x in kNew]),
+ "".join([str(x) for x in pNew]),
+ ),
+ [sNew, pNew, kNew],
+ )
+ )
+ elif (
+ n % sparsity == 0
+ # padding must not exceed the kernel size
+ and p[0] < k[0]
+ and p[1] < k[0]
+ and p[2] < k[1]
+ and p[3] < k[1]
+ # the padded shape must exceed the kernel size
+ and (shape[1] + p[0] + p[1]) > k[0]
+ and (shape[2] + p[2] + p[3]) > k[1]
+ ):
+ arg_list.append(
+ (
+ "st{}_kern{}_pad{}".format(
+ "".join([str(x) for x in s]),
+ "".join([str(x) for x in k]),
+ "".join([str(x) for x in p]),
+ ),
+ [s, p, k],
+ )
+ )
+ n += 1
+
+ return arg_list
+
+ @staticmethod
+ def agCast(testGen, opName, shapeList, inDtype, error_name=None):
+ arg_list = []
+
+ # Enumerate the output types here
+ if error_name == ErrorIf.WrongOutputType:
+ dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
+ elif inDtype == DType.INT8:
+ dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ elif inDtype == DType.INT16:
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ elif inDtype == DType.INT32:
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ elif inDtype == DType.BOOL:
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ elif inDtype == DType.FLOAT:
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output type for incorrect input type
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ else:
+ raise Exception("Unexpected input dtype: {}".format(inDtype))
+
+ for dtype in dtypeList:
+ arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
+
+ return arg_list
+
+ @staticmethod
+ def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
+ arg_list = []
+
+ # Enumerate the output types here
+ for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ if (
+ dtype in [DType.UINT8, DType.INT8]
+ and error_name == ErrorIf.OutputZeroPointNotZero
+ ):
+ continue
+ if (
+ inDtype == DType.UINT8
+ and dtype != DType.INT8
+ and error_name != ErrorIf.WrongOutputType
+ ):
+ # The only output dtype for UINT8 is INT8, skip all other combinations
+ continue
+ if (
+ inDtype != DType.INT8
+ and dtype == DType.UINT8
+ and error_name != ErrorIf.WrongOutputType
+ ):
+ # The only input dtype for UINT8 is INT8, skip all other combinations
+ continue
+ if (
+ error_name == ErrorIf.WrongOutputType
+ and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype)
+ ):
+ continue
+
+ for scale32 in [False, True]:
+ if error_name == ErrorIf.ScaleTrue and not scale32:
+ continue
+ elif error_name == ErrorIf.ScaleNotTrue and scale32:
+ continue
+ for double_round in [False, True]:
+ if error_name == ErrorIf.ScaleNotTrue and not double_round:
+ continue
+ for per_channel in [False, True]:
+
+ if (
+ inDtype == DType.INT48
+ and scale32
+ and error_name != ErrorIf.ScaleTrue
+ ):
+ # Illegal condition. Must be scale32=False
+ continue
+ if (
+ double_round
+ and not scale32
+ and error_name != ErrorIf.ScaleNotTrue
+ ):
+ # Illegal condition. ERROR_IF(!scale32 && double_round)
+ continue
+
+ arg_list.append(
+ (
+ "out{}_sc{}_dr{}_pc{}".format(
+ DTypeNames[dtype],
+ int(scale32),
+ int(double_round),
+ int(per_channel),
+ ),
+ [dtype, scale32, double_round, per_channel],
+ )
+ )
+
+ return arg_list
+
+ @staticmethod
+ def agMul(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ if dtype is DType.INT32:
+ for p in range(testGen.args.num_rand_permutations):
+
+ shift = testGen.randInt(0, 32)
+
+ arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
+ else:
+ arg_list.append(("perm0_shift0", [0]))
+
+ return arg_list
+
+ @staticmethod
+ def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ arg_list.append(("roundTrue", [True]))
+ arg_list.append(("roundFalse", [False]))
+
+ return arg_list
+
+ # Helper function for reshape. Gets some factors of a larger number.
+ @staticmethod
+ def getFactors(val, start=1):
+ factors = []
+
+ for i in range(start, int(np.sqrt(val)) + 1):
+ if (val % i) == 0:
+ factors.append(i)
+
+ return factors
+
+ @staticmethod
+ def agReshape(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ origShape = shapeList[0]
+
+ totalElements = 1
+ for s in origShape:
+ totalElements *= s
+
+ # This code is NOT fast. Fortunately, the numbers are fairly small.
+ factors = TosaArgGen.getFactors(totalElements)
+
+ for p in range(testGen.args.num_rand_permutations):
+ newRank = testGen.randInt(1, 7)
+ if len(factors) < newRank:
+ continue
+
+ found = True
+ # escape_counter breaks while loop if it continues on for too long
+ escape_counter = 0
+ while found:
+ newShape = []
+ # Generate newShape ensuring it isn't a duplicate
+ remainingElements = totalElements
+ shuffledFactors = testGen.rng.permutation(factors)
+ for i in range(1, newRank):
+ # pick rank-1 factors
+ newShape.append(shuffledFactors[0])
+ remainingElements = remainingElements // shuffledFactors[0]
+ shuffledFactors = testGen.rng.permutation(
+ TosaArgGen.getFactors(remainingElements)
+ )
+ newShape.append(remainingElements)
+
+ # Toss in a -1 sometimes
+ minusOne = testGen.randInt(0, newRank * 4)
+ if minusOne < newRank:
+ newShape[minusOne] = -1
+
+ # Check for duplicates
+ found = False
+ for name, other_shape in arg_list:
+ if other_shape[0] == newShape:
+ found = True
+ break
+
+ escape_counter += 1
+ if escape_counter >= 100:
+ break
+
+ if not found:
+ arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
+
+ return arg_list
+
+ @staticmethod
+ def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+
+ if error_name == ErrorIf.IndexOutsideBounds:
+ incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
+ incorrect_small_index = range(-len(ifm_shape), 0)
+ permutations = [p for p in itertools.permutations(incorrect_large_index)]
+ permutations.extend(
+ [p for p in itertools.permutations(incorrect_small_index)]
+ )
+ elif error_name == ErrorIf.IndexUsedTwice:
+ # Create list with a duplicated index
+ perm_range = list(range(len(ifm_shape)))
+ index_choice = testGen.rng.choice(range(len(perm_range)))
+ perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
+ permutations = [p for p in itertools.permutations(perm_range)]
+
+ else:
+ # Get all permutations
+ permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
+
+ # Limit to possible permutations from shape dimension or argument setting
+ limit = min(len(permutations), testGen.args.num_rand_permutations)
+
+ # Get random permutation generator that uses all permutations
+ random_permutations = testGen.rng.permutation(permutations)
+
+ # Create list of required amount of permutations
+ arg_list = [
+ ("perm{}".format(p), [random_permutations[p].tolist()])
+ for p in range(limit)
+ ]
+ return arg_list
+
+ @staticmethod
+ def agSlice(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ rank = len(ifm_shape)
+
+ for p in range(testGen.args.num_rand_permutations):
+ start = []
+ size = []
+
+ valid = True
+
+ for i in range(rank):
+ if ifm_shape[i] > 1:
+ start.append(testGen.randInt(0, ifm_shape[i]))
+ size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
+
+ # Invalid slice size?
+ if size[i] == 0:
+ valid = False
+ else:
+ start.append(0)
+ size.append(1)
+
+ if valid:
+ # If ERROR_IF test required then incorrect start, size will be returned
+ start, size = TosaErrorIfArgGen.eiSliceErrorIf(
+ testGen, error_name, ifm_shape, start, size
+ )
+ arg_list.append(("perm{}".format(p), [start, size]))
+ return arg_list
+
+ @staticmethod
+ def agTile(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ rank = len(ifm_shape)
+
+ for p in range(testGen.args.num_rand_permutations):
+
+ # Pick a few random, but small multiple values
+ # because otherwise this has a tendency to generate
+ # enormous tensors
+ multiples = []
+ for i in range(rank):
+ if ifm_shape[i] > 1000:
+ # Multiple of 1 if ifm_shape dimension is large to reduce
+ # tensor size
+ multiples.append(1)
+ elif max(ifm_shape) > 1000:
+ multiples.append(2)
+ else:
+ multiples.append(testGen.randInt(1, 4))
+ arg_list.append(("perm{}".format(p), [multiples]))
+
+ return arg_list
+
+ @staticmethod
+ def agResize(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ ifm_shape = shapeList[0]
+ for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
+
+ # Exclude illegal {mode, type} configurations. Pick legal output types
+ if mode == ResizeMode.NEAREST and dtype == DType.INT8:
+ outputDTypeList = [DType.INT8]
+ elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
+ outputDTypeList = [DType.INT16]
+ elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
+ outputDTypeList = [DType.INT32]
+ elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
+ outputDTypeList = [DType.INT48]
+ elif dtype == DType.FLOAT:
+ outputDTypeList = [DType.FLOAT]
+ elif error_name == ErrorIf.WrongInputType:
+ # If an incorrect input type is used then we set a 'correct'
+ # output type to avoid other errors
+ outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
+ else:
+ continue
+
+ for outputDType in outputDTypeList:
+ for perm in range(testGen.args.num_rand_permutations):
+ # Randomly generate legal output dimensions and shift
+ # and then compute the stride and offset based on them
+ # A output_dim of 1 will cause offset to exceed allowed range
+ # so minimum value 2 produced below
+ output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
+ while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
+ output_dims[0] += 1
+ while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
+ output_dims[1] += 1
+
+ in_center_h = (ifm_shape[1] - 1) / 2.0
+ in_center_w = (ifm_shape[2] - 1) / 2.0
+ out_center_h = (output_dims[0] - 1) / 2.0
+ out_center_w = (output_dims[1] - 1) / 2.0
+
+ fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
+ fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
+ fp_offset_y = in_center_h - fp_stride_y * out_center_h
+ fp_offset_x = in_center_w - fp_stride_x * out_center_w
+
+ if outputDType == DType.FLOAT:
+ float_op = True
+ arg_str = (
+ "mode{}_shift{}_odim{}x{}_out{}"
+ "_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
+ )
+ shift = 0
+ stride = [0, 0]
+ offset = [0, 0]
+ stride_fp = [fp_stride_y, fp_stride_x]
+ offset_fp = [fp_offset_y, fp_offset_x]
+
+ else:
+ float_op = False
+ arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}"
+ shift = testGen.randInt(1, 12)
+ # Now search for a shift value (1 to 11) that will produce
+ # a valid and predictable resize operation
+ count = 0
+ while count < 12:
+ unit = float(1 << shift)
+ stride_y = int(round(fp_stride_y * unit))
+ stride_x = int(round(fp_stride_x * unit))
+ offset_y = int(round(fp_offset_y * unit))
+ offset_x = int(round(fp_offset_x * unit))
+
+ if (
+ stride_y <= 0
+ or stride_x <= 0
+ or stride_y >= (16 << shift)
+ or stride_x >= (16 << shift)
+ or offset_y >= (16 << shift)
+ or offset_x >= (16 << shift)
+ or offset_y <= (-16 << shift)
+ or offset_x <= (-16 << shift)
+ ):
+ # Change the shift value and check again
+ count += 1
+ shift = (shift % 11) + 1
+ continue
+
+ def RESIZE_REQUIRE_CALC(
+ length_in, length_out, stride, offset, shift
+ ):
+ # Perform the pseudo loop to look for out of bounds
+ for pos in range(0, length_out):
+ a = pos * stride + offset
+ ia = a >> shift
+ ia0 = max(ia, 0)
+ ia1 = min(ia + 1, length_in - 1)
+ if ia0 > ia1:
+ # Found a problem value
+ break
+ return ia0, ia1
+
+ iy0, iy1 = RESIZE_REQUIRE_CALC(
+ ifm_shape[1], output_dims[0], stride_y, offset_y, shift
+ )
+ ix0, ix1 = RESIZE_REQUIRE_CALC(
+ ifm_shape[2], output_dims[1], stride_x, offset_x, shift
+ )
+ if ix0 > ix1 or iy0 > iy1:
+ # Change the shift value and check again
+ count += 1
+ shift = (shift % 11) + 1
+ continue
+ break
+
+ if count >= 12:
+ # Couldn't find a good set of values for this test, skip it
+ continue
+
+ stride = [stride_y, stride_x]
+ offset = [offset_y, offset_x]
+
+ stride_fp = [0.0, 0.0]
+ offset_fp = [0.0, 0.0]
+
+ # Common for all data types
+ if error_name is not None:
+ (
+ shift,
+ stride,
+ stride_fp,
+ offset,
+ offset_fp,
+ outputDTypeNew,
+ ) = TosaErrorIfArgGen.eiResizeErrorIf(
+ testGen,
+ error_name,
+ mode,
+ dtype,
+ shapeList,
+ outputDType,
+ shift,
+ stride,
+ stride_fp,
+ offset,
+ offset_fp,
+ )
+ else:
+ outputDTypeNew = outputDType
+
+ arg_list.append(
+ (
+ arg_str.format(
+ "N" if mode == ResizeMode.NEAREST else "B",
+ shift,
+ output_dims[0],
+ output_dims[1],
+ testGen.typeStr(outputDTypeNew),
+ stride_fp[0] if float_op else stride[0],
+ stride_fp[1] if float_op else stride[1],
+ offset_fp[0] if float_op else offset[0],
+ offset_fp[1] if float_op else offset[1],
+ ),
+ [
+ mode,
+ stride,
+ offset,
+ shift,
+ stride_fp,
+ offset_fp,
+ output_dims,
+ dtype,
+ outputDTypeNew,
+ ],
+ )
+ )
+
+ return arg_list
+
+ @staticmethod
+ def agTable(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ if dtype == DType.INT8:
+ table = np.int32(
+ testGen.rng.integers(low=-128, high=128, size=[256])
+ ).tolist()
+ else: # INT16
+ table = np.int32(
+ testGen.rng.integers(low=-32768, high=32768, size=[513])
+ ).tolist()
+
+ arg_list.append(
+ (
+ "",
+ [table],
+ )
+ )
+ return arg_list
+
+ def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
+ # CondIf generates the condition values here.
+ # Convert to tensors in the build function, along with the
+ # then and else blocks
+ arg_list = []
+
+ for c in [False, True]:
+ arg_list.append(("cond{}".format(int(c)), [c]))
+
+ return arg_list
+
+ def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
+ # While loop: 0 iterations, 1, more than 1
+ arg_list = []
+
+ for iter in [0, 1, 4]:
+ arg_list.append(("iter{}".format(iter), [iter]))
+
+ return arg_list
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 7070205..caf63e3 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,5 +1,12 @@
# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import numpy as np
+from generator.tosa_utils import product
+from generator.tosa_utils import usableDTypes
+from generator.tosa_utils import valueToName
+from tosa.DType import DType
+from tosa.Op import Op
+from tosa.ResizeMode import ResizeMode
class ErrorIf(object):
@@ -58,3 +65,2078 @@ class ErrorIf(object):
InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
+
+
+class TosaErrorIfArgGen:
+ @staticmethod
+ def eiResizeErrorIf(
+ testGen,
+ error_name,
+ mode,
+ dtype,
+ shapeList,
+ outputDType,
+ shift,
+ stride,
+ stride_fp,
+ offset,
+ offset_fp,
+ ):
+
+ if outputDType == DType.FLOAT:
+ if error_name == ErrorIf.StrideSmallerEqualZero:
+ stride_fp = testGen.rng.random(size=[2]) - 2
+ elif error_name == ErrorIf.ShiftNotZero:
+ shift = testGen.rng.integers(1, 5)
+ elif error_name == ErrorIf.StrideLargerDimension:
+ shape = shapeList[0]
+ transform_height = testGen.rng.choice([False, True])
+ if transform_height:
+ stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
+ else:
+ stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
+ else:
+ if error_name == ErrorIf.StrideSmallerEqualZero:
+ stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
+ elif error_name == ErrorIf.ShiftSmallerOne:
+ shift = testGen.rng.integers(-3, 1)
+ if shift <= 0:
+ stride = [
+ (16 >> -shift) - 1,
+ (16 >> -shift) - 1,
+ ] # avoids other ERROR_IF checks
+ offset = [
+ (16 >> -shift) - 1,
+ (16 >> -shift) - 1,
+ ] # avoids other ERROR_IF checks
+ else:
+ stride = [
+ (16 << shift) - 1,
+ (16 << shift) - 1,
+ ] # avoids other ERROR_IF checks
+ offset = [
+ (16 << shift) - 1,
+ (16 << shift) - 1,
+ ] # avoids other ERROR_IF checks
+ elif error_name == ErrorIf.ShiftLargerEleven:
+ shift = np.int16(testGen.rng.integers(12, 15))
+ elif error_name == ErrorIf.StrideLargerDimension:
+ shape = shapeList[0]
+ transform_height = testGen.rng.choice([False, True])
+ if transform_height:
+ stride[0] = shape[1] + testGen.rng.integers(1, 10)
+ else:
+ stride[1] = shape[2] + testGen.rng.integers(1, 10)
+ elif error_name == ErrorIf.StrideLargerEqualMax:
+ stride = [(16 << shift) + 1, (16 << shift) + 1]
+ elif error_name == ErrorIf.OffsetLargerEqualMax:
+ offset = [(16 << shift) + 1, (16 << shift) + 1]
+ elif error_name == ErrorIf.OffsetSmallerEqualMin:
+ offset = [(-16 << shift) - 1, (-16 << shift) - 1]
+
+ if error_name == ErrorIf.WrongOutputType:
+ if mode == ResizeMode.NEAREST and dtype == DType.INT8:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ )
+ elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT32,
+ DType.INT48,
+ DType.FLOAT,
+ )
+ elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT48,
+ DType.FLOAT,
+ )
+ elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FLOAT,
+ )
+ elif dtype == DType.FLOAT:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ )
+ outputDType = testGen.rng.choice(a=incorrect_types)
+
+ return shift, stride, stride_fp, offset, offset_fp, outputDType
+
+ @staticmethod
+ def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
+ if (
+ error_name == ErrorIf.StrideSmallerOne
+ # padding must not exceed the kernel size
+ and pad[0] < kernel[0]
+ and pad[1] < kernel[0]
+ and pad[2] < kernel[1]
+ and pad[3] < kernel[1]
+ ):
+ wrongStride = (
+ testGen.rng.choice([0, -1, -2, -3]),
+ testGen.rng.choice([0, -1, -2, -3]),
+ )
+ return wrongStride, pad, kernel
+ elif error_name == ErrorIf.PadSmallerZero:
+ wrongPad = (
+ testGen.rng.choice([-1, -2, -3]),
+ testGen.rng.choice([-1, -2, -3]),
+ testGen.rng.choice([-1, -2, -3]),
+ testGen.rng.choice([-1, -2, -3]),
+ )
+ return stride, wrongPad, kernel
+ elif error_name == ErrorIf.KernelSmallerOne:
+ wrongKernel = (
+ testGen.rng.choice([0, -1, -2, -3]),
+ testGen.rng.choice([0, -1, -2, -3]),
+ )
+ return stride, pad, wrongKernel
+ elif error_name == ErrorIf.PadLargerEqualKernel:
+ wrongPad = (
+ testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
+ testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
+ testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
+ testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
+ )
+ return stride, wrongPad, kernel
+ else:
+ return None, None, None
+
+ @staticmethod
+ def eiRescaleWrongOutputType(input_dtype, output_dtype):
+ if input_dtype == DType.INT8:
+ if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ return True
+ if input_dtype in [DType.INT16, DType.INT32]:
+ if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
+ return True
+ elif input_dtype == DType.INT48:
+ if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
+ return True
+ elif input_dtype == DType.UINT8:
+ if output_dtype != DType.INT8:
+ return True
+ return False
+
+ @staticmethod
+ def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
+ # Mess up input/output tensors for ERROR_IF checks
+ if error_name == "WrongInputList":
+ add_input = testGen.rng.choice([True, False])
+ if add_input:
+ input_list.append("eiDummyInput")
+ else:
+ input_list = input_list[:-1]
+ elif error_name == "WrongOutputList":
+ add_output = testGen.rng.choice([True, False])
+ if add_output:
+ output_list.append("eiDummyOutput")
+ else:
+ output_list = []
+ return input_list, output_list
+
+ @staticmethod
+ def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
+ """Restrict the dimensions and overall size of a shape to
+ max_dim and max_items.
+ """
+ new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
+ while product(new_shape) > max_items:
+ new_shape = [max(d - 1, 1) for d in new_shape]
+ return new_shape
+
+ def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
+ if error_name == ErrorIf.StartSmallerZero:
+ newStart = []
+ for i in range(len(input_shape)):
+ newStart.append(testGen.rng.choice([-3, -2, -1]))
+ return newStart, size
+ elif error_name == ErrorIf.SizeSmallerEqualZero:
+ newSize = []
+ for i in range(len(input_shape)):
+ newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
+ return start, newSize
+ elif error_name == ErrorIf.StartSizeOutsideBounds:
+ newStart, newSize = [], []
+ for i in range(len(input_shape)):
+ newStart.append(input_shape[i] - 1)
+ newSize.append(testGen.rng.choice([2, 3, 4]))
+ return newStart, newSize
+ elif error_name == ErrorIf.InputSizeStartLengthMismatch:
+ remove = testGen.rng.choice([True, False])
+ if remove:
+ newStart = start[1:]
+ newSize = size[1:]
+ else:
+ newStart = start
+ newStart.append(1)
+ newSize = size
+ newSize.append(1)
+ return newStart, newSize
+ else:
+ return start, size
+
+ @staticmethod
+ def eiCastErrorIf(testGen, input_dtype):
+ if input_dtype in [DType.BOOL, DType.FLOAT]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
+ elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
+ outputDType = [DType.INT48]
+ else:
+ assert True, f"input_dtype ({input_dtype}) not supported"
+ return outputDType
+
+
+class TosaErrorValidator:
+ @staticmethod
+ def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
+ """Check ERROR_IF statements are caught and set the expected result.
+
+ Args:
+ serializer: the serializer to set the expected result in
+ validator_fcns: a sequence of validator functions to verify the result
+ error_name: the name of the ERROR_IF condition to check for
+ kwargs: keyword arguments for the validator functions
+ Returns:
+ True if the result matches the expected result; otherwise False
+ """
+ overall_result = True
+ for val_fcn in validator_fcns:
+ val_result = val_fcn(True, **kwargs)
+ validator_name = val_result["error_name"]
+ error_result = val_result["error_result"]
+ error_reason = val_result["error_reason"]
+
+ # expect an error IFF the error_name and validator_name match
+ expected_result = error_result == (error_name == validator_name)
+ overall_result &= expected_result
+
+ if expected_result and error_result:
+ serializer.setExpectedReturnCode(2, True, desc=error_reason)
+ elif error_result: # and not expected_result
+ print(
+ f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
+ f" Expected: {error_name}, Got: {validator_name}"
+ )
+ elif not expected_result: # and not error_result
+ print(
+ f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
+ f" Expected: {error_name}"
+ )
+
+ if not expected_result:
+ for k, v in sorted(kwargs.items()):
+ if k != "op":
+ if k.endswith("dtype"):
+ v = valueToName(DType, v)
+ print(f" {k} = {v}")
+
+ return overall_result
+
+ @staticmethod
+ def evWrongInputType(check=False, **kwargs):
+ error_result = False
+
+ # Find the unsupported input data types
+ op = kwargs["op"]
+ input_dtypes = op["types"]
+ allowed_input_dtypes = {
+ t[0] if isinstance(t, list) else t for t in input_dtypes
+ }
+ wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
+
+ if op["op"] == Op.CLAMP:
+ wrong_input_dtypes.remove(DType.INT48)
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ if input_dtype not in allowed_input_dtypes:
+ error_result = True
+
+ info_dict = {
+ "error_name": ErrorIf.WrongInputType,
+ "error_result": error_result,
+ "error_reason": "Input data type not supported for this operator",
+ "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
+ }
+ return info_dict
+
+ @staticmethod
+ def evWrongOutputType(check=False, **kwargs):
+ error_result = False
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
+ op = kwargs["op"]
+
+ if op["op"] == Op.RESIZE:
+ mode = kwargs["mode"]
+ if (
+ (
+ mode == ResizeMode.NEAREST
+ and input_dtype == DType.INT8
+ and output_dtype != DType.INT8
+ )
+ or (
+ mode == ResizeMode.NEAREST
+ and input_dtype == DType.INT16
+ and output_dtype != DType.INT16
+ )
+ or (
+ mode == ResizeMode.BILINEAR
+ and input_dtype == DType.INT8
+ and output_dtype != DType.INT32
+ )
+ or (
+ mode == ResizeMode.BILINEAR
+ and input_dtype == DType.INT16
+ and output_dtype != DType.INT48
+ )
+ or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ ):
+ error_result = True
+
+ elif op["op"] == Op.RESCALE:
+ if input_dtype == DType.INT8:
+ if output_dtype not in [
+ DType.UINT8,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ ]:
+ error_result = True
+ if input_dtype in [DType.INT16, DType.INT32]:
+ if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
+ error_result = True
+ elif input_dtype == DType.INT48:
+ if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
+ error_result = True
+ elif input_dtype == DType.UINT8:
+ if output_dtype != DType.INT8:
+ error_result = True
+
+ elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
+ if (
+ (input_dtype == DType.INT8 and output_dtype != DType.INT32)
+ or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
+ or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ ):
+ error_result = True
+
+ elif op["op"] == Op.ARGMAX:
+ if (
+ input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
+ and output_dtype != DType.INT32
+ ):
+ error_result = True
+
+ elif op["op"] == Op.MUL:
+ if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
+ error_result = True
+ elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
+ error_result = True
+
+ elif op["op"] == Op.TABLE:
+ if input_dtype == DType.INT8 and output_dtype != DType.INT8:
+ error_result = True
+ elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
+ error_result = True
+
+ elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
+ if output_dtype != DType.BOOL:
+ error_result = True
+
+ elif op["op"] == Op.CAST:
+ if (
+ (
+ input_dtype == DType.BOOL
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ )
+ or (
+ input_dtype == DType.INT8
+ and output_dtype
+ not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ )
+ or (
+ input_dtype == DType.INT16
+ and output_dtype
+ not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ )
+ or (
+ input_dtype == DType.INT32
+ and output_dtype
+ not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ )
+ or (
+ input_dtype == DType.FLOAT
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ )
+ ):
+ error_result = True
+
+ elif op["op"] in {
+ Op.CONV2D,
+ Op.CONV3D,
+ Op.DEPTHWISE_CONV2D,
+ Op.TRANSPOSE_CONV2D,
+ }:
+ if (
+ input_dtype == DType.INT8
+ and output_dtype != DType.INT32
+ or input_dtype == DType.INT16
+ and output_dtype != DType.INT48
+ or input_dtype == DType.FLOAT
+ and output_dtype != DType.FLOAT
+ ):
+ error_result = True
+ # invalid input types are ignored, to avoid reporting multiple errors
+
+ else:
+ if output_dtype != input_dtype:
+ error_result = True
+
+ info_dict = {
+ "error_name": ErrorIf.WrongOutputType,
+ "error_result": error_result,
+ "error_reason": (
+ "Output data type not supported for this configuration of operator"
+ ),
+ "param_reqs": {"rank": None, "dtype": None, "shape": None},
+ }
+ return info_dict
+
+ @staticmethod
+ def evWrongRank(check=False, **kwargs):
+ all_ranks = (1, 2, 3, 4, 5)
+
+ # Make a list of incorrect ranks
+ assert "op" in kwargs
+ op = kwargs["op"]
+ rmin, rmax = op["rank"]
+ rank_range = range(rmin, rmax + 1)
+ incorrect_ranks = list(set(all_ranks) - set(rank_range))
+ # Remove small incorrect ranks to avoid index errors
+ incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
+ # Set minimum incorrect rank to 3 to avoid index error
+ if op["op"] in [Op.RESIZE]:
+ incorrect_ranks = [3, 5]
+ elif op["op"] in [Op.TRANSPOSE]:
+ incorrect_ranks = [7, 8]
+ elif op["op"] in [Op.CONV3D]:
+ incorrect_ranks = [6, 7]
+
+ error_name = ErrorIf.WrongRank
+ param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Rank not supported for this operator"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+
+ if (
+ op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
+ and len(input_shape) != 4
+ ):
+ error_result = True
+ elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
+ error_result = True
+ elif op["op"] == Op.MATMUL and len(input_shape) != 3:
+ error_result = True
+ else:
+ if len(input_shape) not in rank_range:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evWrongInputList(check=False, **kwargs):
+ error_name = ErrorIf.WrongInputList
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Op input list does not match expected input"
+
+ if check:
+ op = kwargs["op"]
+ input_list = kwargs["input_list"]
+ num_operands = kwargs["num_operands"]
+ if op["op"] in [Op.SCATTER, Op.GATHER]:
+ # SCATTER/GATHER add an indices input tensor in their build functions
+ num_operands += 1
+ if len(input_list) != num_operands:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evWrongOutputList(check=False, **kwargs):
+ error_name = ErrorIf.WrongOutputList
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Op output list does not match expected output"
+
+ if check:
+ output_list = kwargs["output_list"]
+ # Note this will be incorrect if an operator returns more than one output
+ if len(output_list) != 1:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evMaxDimExceeded(check=False, **kwargs):
+ error_name = ErrorIf.MaxDimExceeded
+ param_reqs = {
+ "rank": [4, 4],
+ "dtype": [DType.INT8],
+ "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
+ }
+ error_result = False
+ error_reason = (
+ "At least one maximum dimension is greater than or equal to 16384"
+ )
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
+ if (
+ (input_shape[1] >= 16384)
+ or (input_shape[2] >= 16384)
+ or (output_shape[0] >= 16384)
+ or (output_shape[1] >= 16384)
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evBatchMismatch(check=False, **kwargs):
+ error_name = ErrorIf.BatchMismatch
+ param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input batch size not equal to output batch size"
+
+ assert "op" in kwargs
+ op = kwargs["op"]
+ rmin, rmax = op["rank"]
+ rank_range = range(rmin, rmax + 1)
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs[
+ "result_tensor"
+ ].shape # Note this is just (N, OH, OW, C)
+
+ if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evChannelMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ChannelMismatch
+ param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input channel size not equal to output channel size"
+
+ assert "op" in kwargs
+ op = kwargs["op"]
+ rmin, rmax = op["rank"]
+ rank_range = range(rmin, rmax + 1)
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs[
+ "result_tensor"
+ ].shape # Note this is just (N, OH, OW, C)
+ if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evStrideSmallerEqualZero(check=False, **kwargs):
+ error_name = ErrorIf.StrideSmallerEqualZero
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Stride value smaller than or equal zero"
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
+ if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
+ stride = kwargs["stride"] # Work around wrong input/output type tests
+ elif output_dtype == DType.FLOAT:
+ stride = kwargs["stride_fp"]
+ elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
+ stride = kwargs[
+ "stride_fp"
+ ] # Work around wrong input/output type tests
+ else:
+ stride = kwargs["stride"]
+
+ if min(stride) <= 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evStrideLargerEqualMax(check=False, **kwargs):
+ error_name = ErrorIf.StrideLargerEqualMax
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Stride value larger than or equal to maximum value"
+
+ if check:
+ shift = kwargs["shift"]
+ input_dtype = kwargs["input_dtype"]
+ stride = kwargs["stride"]
+ if input_dtype in [DType.INT8, DType.INT16]:
+ if shift >= 0 and (
+ stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
+ ):
+ error_result = True
+ elif shift < 0 and (
+ stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evStrideLargerDimension(check=False, **kwargs):
+ error_name = ErrorIf.StrideLargerDimension
+ param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
+ error_result = False
+ error_reason = "Stride value larger than or equal to H/W dimension"
+
+ if check:
+ shape = kwargs["input_shape"]
+ input_dtype = kwargs["input_dtype"]
+ stride = kwargs["stride_fp"]
+
+ if (
+ input_dtype == DType.FLOAT
+ and (stride[0] > shape[1])
+ or (stride[1] > shape[2])
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evOffsetSmallerEqualMin(check=False, **kwargs):
+ error_name = ErrorIf.OffsetSmallerEqualMin
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Offset value smaller than or equal to minimum value"
+
+ if check:
+ shift = kwargs["shift"]
+ output_dtype = kwargs["output_dtype"]
+ if output_dtype == DType.FLOAT:
+ offset = kwargs["offset_fp"]
+ else:
+ offset = kwargs["offset"]
+
+ if shift >= 0 and (
+ offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
+ ):
+ error_result = True
+ elif shift < 0 and (
+ offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evOffsetLargerEqualMax(check=False, **kwargs):
+ error_name = ErrorIf.OffsetLargerEqualMax
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Offset value larger than or equal to maximum value"
+
+ if check:
+ shift = kwargs["shift"]
+ output_dtype = kwargs["output_dtype"]
+ if output_dtype == DType.FLOAT:
+ offset = kwargs["offset_fp"]
+ else:
+ offset = kwargs["offset"]
+
+ if shift >= 0:
+ if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
+ error_result = True
+
+ if shift >= 0 and (
+ offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
+ ):
+ error_result = True
+ elif shift < 0 and (
+ offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evShiftNotZero(check=False, **kwargs):
+ error_name = ErrorIf.ShiftNotZero
+ param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
+ error_result = False
+ error_reason = "Shift value must be zero for float input"
+
+ if check:
+ shift = kwargs["shift"]
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
+ if (
+ input_dtype == DType.FLOAT
+ and output_dtype == DType.FLOAT
+ and shift != 0
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evShiftSmallerOne(check=False, **kwargs):
+ error_name = ErrorIf.ShiftSmallerOne
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Shift value smaller than one"
+
+ if check:
+ shift = kwargs["shift"]
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
+ if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evShiftLargerEleven(check=False, **kwargs):
+ error_name = ErrorIf.ShiftLargerEleven
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Shift value larger than eleven"
+
+ if check:
+ shift = kwargs["shift"]
+ if shift > 11:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evRankMismatch(check=False, **kwargs):
+ error_name = ErrorIf.RankMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input Rank does not match output rank"
+
+ if check:
+ input1_shape = kwargs["input1"].shape
+ input2_shape = kwargs["input2"].shape
+ # In case of SELECT op
+ input3_shape = (
+ kwargs["input3"].shape if "input3" in kwargs else input2_shape
+ )
+ output_shape = kwargs["result_tensor"].shape
+ if (
+ (len(input1_shape) != len(output_shape))
+ or (len(input2_shape) != len(output_shape))
+ or (len(input3_shape) != len(output_shape))
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evDimensionMismatch(check=False, **kwargs):
+ error_name = ErrorIf.DimensionMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input Dimensions do not match output"
+
+ if check:
+ input1_shape = kwargs["input1"].shape
+ input2_shape = kwargs["input2"].shape
+ # In case of SELECT op
+ input3_shape = (
+ kwargs["input3"].shape if "input3" in kwargs else input2_shape
+ )
+ output_shape = kwargs["result_tensor"].shape
+ for i in range(
+ min(len(input1_shape), len(input2_shape), len(input3_shape))
+ ):
+ if (
+ (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
+ or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
+ or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evInputZeroPointNotZero(check=False, **kwargs):
+ op = kwargs["op"]
+ error_result = False
+
+ # Quantizable types
+ qTypes = (DType.INT8, DType.UINT8)
+
+ # This does not apply to quantizable types
+ inputDtypes = [
+ dtype
+ for dtype in op["types"]
+ if (isinstance(dtype, list) and dtype[0] not in qTypes)
+ or (not isinstance(dtype, list) and dtype not in qTypes)
+ ]
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ if isinstance(kwargs["qinfo"], tuple):
+ qinfo = kwargs["qinfo"]
+ input_zero_point = qinfo[0]
+ else:
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ qinfo = kwargs["qinfo"].ints
+ input_zero_point = qinfo[0][1]
+
+ if op["op"] == Op.MATMUL:
+ qinfo = kwargs["qinfo"].ints
+ for dtype, zp in (
+ (kwargs["input_dtype"], qinfo[0][1]),
+ (kwargs["input2_dtype"], qinfo[1][1]),
+ ):
+ if dtype not in qTypes and zp != 0:
+ error_result = True
+ break
+ else:
+ error_result = input_dtype not in qTypes and input_zero_point != 0
+
+ info_dict = {
+ "error_name": ErrorIf.InputZeroPointNotZero,
+ "error_result": error_result,
+ "error_reason": "Input DType not INT8 and zero point not 0",
+ "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
+ }
+ return info_dict
+
+ @staticmethod
+ def evWeightZeroPointNotZero(check=False, **kwargs):
+ op = kwargs["op"]
+
+ # exclude inputs with INT8 weights
+ inputDtypes = [
+ t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
+ ]
+
+ error_name = ErrorIf.WeightZeroPointNotZero
+ param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
+ error_result = False
+ error_reason = "Weight DType not INT8 and zero point not 0"
+
+ if check:
+ weight_dtype = kwargs["weight_dtype"]
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
+ qinfo = kwargs["qinfo"].ints
+ weight_zero_point = qinfo[1][1]
+ if weight_dtype != DType.INT8 and weight_zero_point != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evOutputZeroPointNotZero(check=False, **kwargs):
+ op = kwargs["op"]
+ inputDtypes = op["types"].copy()
+ if DType.INT8 in inputDtypes:
+ inputDtypes.remove(DType.INT8)
+ if DType.UINT8 in inputDtypes:
+ inputDtypes.remove(DType.UINT8)
+
+ error_name = ErrorIf.OutputZeroPointNotZero
+ param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
+ error_result = False
+ error_reason = "Output DType not INT8 and zero point not 0"
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ output_dtype = kwargs["output_dtype"]
+ if isinstance(kwargs["qinfo"], tuple):
+ qinfo = kwargs["qinfo"]
+ output_zero_point = qinfo[1]
+ else:
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ qinfo = kwargs["qinfo"].ints
+ output_zero_point = qinfo[1][1]
+ if op["op"] == Op.AVG_POOL2D:
+ if input_dtype != DType.INT8 and output_zero_point != 0:
+ error_result = True
+ elif (
+ output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evAxisSmallerZero(check=False, **kwargs):
+ error_name = ErrorIf.AxisSmallerZero
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Axis smaller than zero"
+
+ if check:
+ axis = kwargs["axis"]
+ if axis < 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evAxisLargerRank(check=False, **kwargs):
+ error_name = ErrorIf.AxisLargerRank
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Axis larger than rank"
+
+ if check:
+ axis = kwargs["axis"]
+ shape = kwargs["input_shape"]
+ if axis > len(shape):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evShapeOfAxisNotOne(check=False, **kwargs):
+ error_name = ErrorIf.ShapeOfAxisNotOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "shape[axis] is not equal to 1"
+
+ if check:
+ axis = kwargs["axis"]
+ shape = kwargs["output_shape"]
+ if (0 <= axis < len(shape)) and shape[axis] != 1:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evPadSmallerZero(check=False, **kwargs):
+ error_name = ErrorIf.PadSmallerZero
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "At least one pad is smaller than zero"
+
+ if check:
+ op = kwargs["op"]
+ pad = kwargs["pad"]
+ if op["op"] == Op.PAD:
+ for padding in pad:
+ if min(padding) < 0:
+ error_result = True
+ else:
+ if min(pad) < 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evPadLargerEqualKernel(check=False, **kwargs):
+ error_name = ErrorIf.PadLargerEqualKernel
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "At least one pad is larger than kernel dimension"
+
+ if check:
+ pad = kwargs["pad"]
+ kernel = kwargs["kernel"]
+ if min(pad) > 0 and min(kernel) > 1:
+ if (
+ pad[0] >= kernel[0]
+ or pad[1] >= kernel[0]
+ or pad[2] >= kernel[1]
+ or pad[3] >= kernel[1]
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evPoolingOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.PoolingOutputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = (
+ "Mismatch between output shape provided and expected output shape"
+ )
+
+ if check:
+ pad = kwargs["pad"]
+ pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
+
+ kernel = kwargs["kernel"]
+ kernel_y, kernel_x = kernel[0], kernel[1]
+
+ input_shape = kwargs["input_shape"]
+ IH, IW = input_shape[1], input_shape[2]
+
+ output_shape = kwargs["output_shape"]
+ OH, OW = output_shape[1], output_shape[2]
+
+ stride = kwargs["stride"]
+ stride_y, stride_x = stride[0], stride[1]
+
+ # calculate correct height, width dimensions
+ if stride_x != 0 and stride_y != 0:
+ y_correct = (
+ IH + pad_top + pad_bottom + stride_y - kernel_y
+ ) // stride_y
+ x_correct = (
+ IW + pad_left + pad_right + stride_x - kernel_x
+ ) // stride_x
+
+ # ensure parameters are valid
+ params_valid = (
+ min(kernel) >= 1
+ and min(stride) >= 1
+ and min(pad) >= 0
+ and not (
+ pad[0] >= kernel[0]
+ or pad[1] >= kernel[0]
+ or pad[2] >= kernel[1]
+ or pad[3] >= kernel[1]
+ )
+ )
+
+ if params_valid and (OH != y_correct or OW != x_correct):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evArgmaxOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ArgmaxOutputShapeMismatch
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = (
+ "Mismatch between output shape provided and expected output shape"
+ )
+
+ if check:
+ output_shape = kwargs["output_shape"]
+ input_shape = kwargs["input_shape"]
+ axis = kwargs["axis"]
+
+ dimension_match = True
+ axis_shift = 0
+
+ # Check that rank is correct before trying to check dimensions
+ if (len(input_shape) - 1) == len(output_shape):
+ for i in range(len(input_shape)):
+ if i == axis:
+ axis_shift = 1
+ continue
+ if input_shape[i] != output_shape[i - axis_shift]:
+ dimension_match = False
+
+ if not dimension_match:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evArgmaxOutputRankMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ArgmaxOutputRankMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = (
+ "Mismatch between output shape provided and expected output shape"
+ )
+
+ if check:
+ output_shape = kwargs["output_shape"]
+ input_shape = kwargs["input_shape"]
+ axis = kwargs["axis"]
+ valid_params = axis >= 0 and axis < len(input_shape)
+
+ if valid_params and (len(input_shape) - 1) != len(output_shape):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evKernelSmallerOne(check=False, **kwargs):
+ error_name = ErrorIf.KernelSmallerOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "At least one kernel dimension is smaller than zero"
+
+ if check:
+ kernel = kwargs["kernel"]
+ if min(kernel) < 1:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evStrideSmallerOne(check=False, **kwargs):
+ error_name = ErrorIf.StrideSmallerOne
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "At least one stride dimension is smaller than zero"
+
+ if check:
+ stride = kwargs["stride"]
+ if min(stride) < 1:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evDilationSmallerOne(check=False, **kwargs):
+ error_result = check and min(kwargs["dilation"]) < 1
+ return {
+ "error_name": ErrorIf.DilationSmallerOne,
+ "error_reason": "At least one dilation is smaller than one",
+ "param_reqs": {"rank": None, "dtype": None, "shape": None},
+ "error_result": error_result,
+ }
+
+ @staticmethod
+ def evScaleTrue(check=False, **kwargs):
+ error_name = ErrorIf.ScaleTrue
+ param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
+ error_result = False
+ error_reason = "Scale set to true but input type is INT48"
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ scale32 = kwargs["scale32"]
+ if scale32 and input_dtype == DType.INT48:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evScaleNotTrue(check=False, **kwargs):
+ error_name = ErrorIf.ScaleNotTrue
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Scale set to false but double round set to true"
+
+ if check:
+ scale32 = kwargs["scale32"]
+ double_round = kwargs["double_round"]
+ if not scale32 and double_round:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evTensorSizeInputOutputMismatch(check=False, **kwargs):
+ error_name = ErrorIf.TensorSizeInputOutputMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input tensor size does not match output tensor size"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ input_size = np.prod(input_shape)
+ output_size = np.prod(output_shape)
+ if input_size != output_size:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evStartSmallerZero(check=False, **kwargs):
+ error_name = ErrorIf.StartSmallerZero
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Starting point smaller than zero"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ start = kwargs["start"]
+ rank = len(input_shape)
+ if len(start) == rank:
+ for index in range(rank):
+ if start[index] < 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evSizeSmallerEqualZero(check=False, **kwargs):
+ error_name = ErrorIf.SizeSmallerEqualZero
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Size smaller than or equal to zero"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ size = kwargs["size"]
+ rank = len(input_shape)
+ if len(size) == rank:
+ for index in range(rank):
+ if size[index] <= 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evStartSizeOutsideBounds(check=False, **kwargs):
+ error_name = ErrorIf.StartSizeOutsideBounds
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "starting point plus size larger than input dimension"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ start = kwargs["start"]
+ size = kwargs["size"]
+ rank = len(input_shape)
+ if len(start) == rank and len(size) == rank:
+ for index in range(rank):
+ if start[index] + size[index] > input_shape[index]:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evSizeOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.SizeOutputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Size does not match output dimension"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ size = kwargs["size"]
+ rank = len(input_shape)
+ if len(size) == rank:
+ for index in range(rank):
+ if size[index] != output_shape[index]:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evInputSizeStartLengthMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputSizeStartLengthMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "rank of input not equal to length of start or size"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ start = kwargs["start"]
+ size = kwargs["size"]
+ rank = len(input_shape)
+ if rank != len(start) or rank != len(size):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evIndexOutsideBounds(check=False, **kwargs):
+ error_name = ErrorIf.IndexOutsideBounds
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Index outside of allowed bounds"
+
+ if check:
+ input_shape = kwargs["input_shape"]
+ perms = kwargs["perms"]
+ rank = len(input_shape)
+
+ for index in perms:
+ if index < 0 or index > rank:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evIndexUsedTwice(check=False, **kwargs):
+ error_name = ErrorIf.IndexUsedTwice
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Index used multiple times"
+
+ if check:
+ perms = kwargs["perms"]
+
+ unique_indices = []
+ for index in perms:
+ if index in unique_indices:
+ error_result = True
+ else:
+ unique_indices.append(index)
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evMaxSmallerMin(check=False, **kwargs):
+ error_name = ErrorIf.MaxSmallerMin
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Max value smaller than min value"
+
+ if check:
+ max_val = kwargs["max_val"]
+ min_val = kwargs["min_val"]
+ if max_val < min_val:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evConcatInputRankMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ConcatInputRankMismatch
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input ranks are not identical"
+
+ if check:
+ inputs = kwargs["inputs"]
+ input_shape = kwargs["input_shape"]
+ for input in inputs:
+ if len(input.shape) != len(input_shape):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evConcatInputDimMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ConcatInputDimMismatch
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input dimensions differ on too many axes"
+
+ if check:
+ inputs = kwargs["inputs"]
+ input_shape = kwargs["input_shape"]
+ axis = kwargs["axis"]
+
+ # Ensure rank is valid before checking dims.
+ valid_rank = True
+ for input in inputs:
+ if len(input.shape) != len(input_shape):
+ valid_rank = False
+
+ if valid_rank:
+ for input in inputs:
+ for i, dim in enumerate(input.shape):
+ if dim != input_shape[i] and axis != i:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evConcatShapeSumMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ConcatShapeSumMismatch
+ param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Sum of dimensions on axis not equal to output dimension"
+
+ if check:
+ inputs = kwargs["inputs"]
+ input_shape = kwargs["input_shape"]
+ output_shape = kwargs["output_shape"]
+ axis = kwargs["axis"]
+
+ # Ensure rank is valid before checking dims.
+ valid_params = True
+ for input in inputs:
+ if len(input.shape) != len(input_shape):
+ valid_params = False
+ if axis < 0 or axis > len(input_shape):
+ valid_params = False
+
+ if valid_params:
+ axis_dim_sum = 0
+ for input in inputs:
+ axis_dim_sum += input.shape[axis]
+
+ if axis_dim_sum != output_shape[axis]:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evInputListThenGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.CondIfInputListThenGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list shape does not match then-graph shape"
+
+ if check:
+ a = kwargs["a"]
+ b = kwargs["b"]
+ basicBlocks = kwargs["basicBlocks"]
+ then_block = basicBlocks[1]
+ then_inputs = then_block.inputs
+ then_tens = then_block.tensors
+ if (a.shape != then_tens[then_inputs[0]].shape) or (
+ b.shape != then_tens[then_inputs[1]].shape
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evInputListElseGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.CondIfInputListElseGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list shape does not match else-graph shape"
+
+ if check:
+ a = kwargs["a"]
+ b = kwargs["b"]
+ basicBlocks = kwargs["basicBlocks"]
+ else_block = basicBlocks[2]
+ else_inputs = else_block.inputs
+ else_tens = else_block.tensors
+ if (a.shape != else_tens[else_inputs[0]].shape) or (
+ b.shape != else_tens[else_inputs[1]].shape
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evOutputListThenGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.CondIfOutputListThenGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Output list shape does not match then-graph shape"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ cond_block = basicBlocks[0]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ then_block = basicBlocks[1]
+ then_outputs = then_block.outputs
+ then_tens = then_block.tensors
+ if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evOutputListElseGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.CondIfOutputListElseGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Output list shape does not match else-graph shape"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ cond_block = basicBlocks[0]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ else_block = basicBlocks[2]
+ else_outputs = else_block.outputs
+ else_tens = else_block.tensors
+ if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evInputListOutputListMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputListOutputListMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list does not match output list"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ while_block = basicBlocks[0]
+ while_inputs = while_block.inputs
+ while_outputs = while_block.outputs
+ while_tens = while_block.tensors
+ if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evInputListCondGraphMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputListCondGraphMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list does not match cond graph"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ while_block = basicBlocks[0]
+ while_inputs = while_block.inputs
+ while_tens = while_block.tensors
+ cond_block = basicBlocks[1]
+ cond_inputs = cond_block.inputs
+ cond_tens = cond_block.tensors
+ if (
+ while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
+ ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evInputListBodyGraphInputMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputListBodyGraphInputMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list does not match body graph input"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ while_block = basicBlocks[0]
+ while_inputs = while_block.inputs
+ while_tens = while_block.tensors
+ body_block = basicBlocks[2]
+ body_outputs = body_block.inputs
+ body_tens = body_block.tensors
+ if (
+ while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
+ ) or (
+ while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
+ ):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
+ error_name = ErrorIf.InputListBodyGraphOutputMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input list does not match body graph output"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ while_block = basicBlocks[0]
+ while_inputs = while_block.inputs
+ while_tens = while_block.tensors
+ body_block = basicBlocks[2]
+ body_outputs = body_block.outputs
+ body_tens = body_block.tensors
+ if (
+ while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
+ ) or (
+ while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
+ ):
+ error_result = True
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
+ error_name = ErrorIf.CondGraphOutputNotMatchingBool
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Cond graph output is not a match list of booleans"
+
+ if check:
+ basicBlocks = kwargs["basicBlocks"]
+ cond_block = basicBlocks[1]
+ cond_outputs = cond_block.outputs
+ cond_tens = cond_block.tensors
+ if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+
+class TosaInvalidValidator:
+ @staticmethod
+ def ivWrongDataTypeOrModeResize(**kwargs):
+ input_dtype = kwargs["input_dtype"]
+ args = kwargs["args"]
+ mode = args[0]
+ output_dtype = args[8]
+
+ if mode == ResizeMode.BILINEAR:
+ # Invalid output data type / Invalid input datatype
+ return (
+ not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
+ or not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
+ or not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
+ or (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
+ )
+ elif mode == ResizeMode.NEAREST:
+ # Invalid output data type / Invalid input datatype
+ return (input_dtype != output_dtype) or (
+ input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]
+ )
+ else:
+ # Invalid resize mode
+ return True
+
+ @staticmethod
+ def ivBadStride(**kwargs):
+ input_dtype = kwargs["input_dtype"]
+ args = kwargs["args"]
+ stride_x = args[1][0]
+ stride_y = args[1][1]
+ stride_fp_x = args[4][0]
+ stride_fp_y = args[4][1]
+
+ if input_dtype == DType.FLOAT:
+ if stride_fp_x <= 0 or stride_fp_y <= 0:
+ # Negative or zero stride
+ return True
+ else:
+ if stride_x <= 0 or stride_y <= 0:
+ # Negative or zero stride
+ return True
+ return False
+
+ @staticmethod
+ def ivHeightWidthInvalid(**kwargs):
+ opName = kwargs["opName"]
+
+ inputShapes = kwargs["shapeList"]
+ input_shape = inputShapes[0]
+
+ args = kwargs["args"]
+ strides = args[0]
+ padding = args[1]
+
+ if opName.endswith("pool2d"):
+ # avg_pool2d, max_pool2d
+ kernel_shape = args[2]
+ h = (
+ input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
+ ) // strides[0]
+ w = (
+ input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
+ ) // strides[1]
+ # return True if any dimension is < 1
+ return h < 1 or w < 1
+
+ if opName.startswith("transpose_conv2d"):
+ # transpose_conv2d
+ dilations = args[2]
+ output_shape = args[3]
+ filter_shape = inputShapes[1]
+ kernel_shape = filter_shape[1:-1]
+
+ def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
+ """Calculate the transpose_conv2d output size for a dimension.
+
+ Based on the keras function deconv_output_length, in
+ https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
+
+ Args:
+ in_size: the input size - int
+ stride: the stride - int
+ kernel_size: the kernel size - int
+ dilation: the kernel dilation - int
+ out_pad: the output padding - int
+ in_pad: the input padding - int
+
+ Returns:
+ the output size
+ """
+ dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ return (
+ (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
+ )
+
+ for pad_h, pad_w in (
+ (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
+ (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
+ (0, 0), # VALID padding
+ ):
+ h = get_out_size(
+ input_shape[1],
+ strides[0],
+ kernel_shape[0],
+ dilations[0],
+ padding[0],
+ pad_h,
+ )
+ w = get_out_size(
+ input_shape[2],
+ strides[1],
+ kernel_shape[1],
+ dilations[1],
+ padding[1],
+ pad_w,
+ )
+ if output_shape[1] == h and output_shape[2] == w:
+ return False
+
+ # output shape does not match the expected shape for any padding option
+ return True
+
+ if "conv2d" in opName or "conv3d" in opName:
+ # conv2d, conv3d, depthwise_conv2d
+ dilations = args[2]
+ filter_shape = inputShapes[1]
+ kernel_shape = (
+ filter_shape[0:2]
+ if opName.startswith("depthwise_conv2d")
+ else filter_shape[1:-1]
+ )
+
+ for i in range(len(kernel_shape)):
+ dim = (
+ input_shape[i + 1]
+ - kernel_shape[i]
+ - (kernel_shape[i] - 1) * (dilations[i] - 1)
+ + padding[i * 2 + 0]
+ + padding[i * 2 + 1]
+ ) // strides[i] + 1
+ # return True if any dimension is < 1
+ if dim < 1:
+ return True
+ return False
+
+ assert False, f"Unrecognized Op: {opName}"
+
+ @staticmethod
+ def ivNonPositiveOutputShape(**kwargs):
+ args = kwargs["args"]
+ output_shape = args[3]
+ if output_shape[1] <= 0 or output_shape[2] <= 0:
+ # Negative output shape
+ return True
+ return False
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 5ae5ed2..38365d0 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,3544 +1,21 @@
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
-import itertools
-import math
import os
from copy import deepcopy
import numpy as np
import serializer.tosa_serializer as ts
+from generator.tosa_arg_gen import TosaArgGen
+from generator.tosa_arg_gen import TosaQuantGen
+from generator.tosa_arg_gen import TosaTensorGen
+from generator.tosa_arg_gen import TosaTensorValuesGen
from generator.tosa_error_if import ErrorIf
-from serializer.tosa_serializer import DTypeNames
+from generator.tosa_error_if import TosaErrorIfArgGen
+from generator.tosa_error_if import TosaErrorValidator
+from generator.tosa_error_if import TosaInvalidValidator
+from generator.tosa_utils import usableDTypes
from tosa.DType import DType
from tosa.Op import Op
-from tosa.ResizeMode import ResizeMode
-
-# DTypeNames, DType, Op and ResizeMode are convenience variables to the
-# flatc-generated types that should be enums, but aren't
-
-
-def valueToName(item, value):
- """Get the name of an attribute with the given value.
-
- This convenience function is needed to print meaningful names for
- the values of the tosa.Op.Op and tosa.DType.DType classes.
- This would not be necessary if they were subclasses of Enum, or
- IntEnum, which, sadly, they are not.
-
- Args:
- item: The class, or object, to find the value in
- value: The value to find
-
- Example, to get the name of a DType value:
-
- name = valueToName(DType, DType.INT8) # returns 'INT8'
- name = valueToName(DType, 4) # returns 'INT8'
-
- Returns:
- The name of the first attribute found with a matching value,
-
- Raises:
- ValueError if the value is not found
- """
- for attr in dir(item):
- if getattr(item, attr) == value:
- return attr
- raise ValueError(f"value ({value}) not found")
-
-
-def allDTypes(*, excludes=None):
- """Get a set of all DType values, optionally excluding some values.
-
- This convenience function is needed to provide a sequence of DType values.
- This would be much easier if DType was a subclass of Enum, or IntEnum,
- as we could then iterate over the values directly, instead of using
- dir() to find the attributes and then check if they are what we want.
-
- Args:
- excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL])
-
- Returns:
- A set of DType values
- """
- excludes = () if not excludes else excludes
- return {
- getattr(DType, t)
- for t in dir(DType)
- if not callable(getattr(DType, t))
- and not t.startswith("__")
- and getattr(DType, t) not in excludes
- }
-
-
-def usableDTypes(*, excludes=None):
- """Get a set of usable DType values, optionally excluding some values.
-
- Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
- specified by the caller, as the serializer lib does not support them.
- If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.
-
- Args:
- excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])
-
- Returns:
- A set of DType values
- """
- omit = {DType.UNKNOWN, DType.UINT8}
- omit.update(excludes if excludes else ())
- return allDTypes(excludes=omit)
-
-
-def product(shape):
- value = 1
- for n in shape:
- value *= n
- return value
-
-
-class TosaQuantGen:
- """QuantizedInfo random generator helper functions.
-
- Specify with 'qgen': in the operator defintion.
- """
-
- def __init__(self):
- pass
-
- @staticmethod
- def getQinfo(testGen, dtype, error_name=None):
-
- if dtype == DType.INT8:
- return testGen.randInt(-128, 128)
- elif dtype == DType.UINT8:
- return testGen.randInt(0, 256)
- elif error_name in [
- ErrorIf.InputZeroPointNotZero,
- ErrorIf.WeightZeroPointNotZero,
- ErrorIf.OutputZeroPointNotZero,
- ]:
- zero_point = testGen.randInt(-128, 128)
- if zero_point == 0:
- zero_point = 1
- return zero_point
- return 0
-
- @staticmethod
- def qgUnary(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
- if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
- elif error_name == ErrorIf.OutputZeroPointNotZero:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- )
- else:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
- return qinfo
-
- @staticmethod
- def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
- if isinstance(dtype_or_dtypeList, list):
- # a list of [input, weights, accumulator] dtypes
- dtypeList = dtype_or_dtypeList
- else:
- # an int, [input, weights, accumulator] dtypes are the same
- dtypeList = [dtype_or_dtypeList] * 3
-
- if error_name == ErrorIf.InputZeroPointNotZero:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
- elif error_name == ErrorIf.WeightZeroPointNotZero:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
- else:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
-
- qinfo.ConvQuantInfo(input_zp, weights_zp)
- return qinfo
-
- @staticmethod
- def qgMatmul(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
- if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- )
- else:
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
- return qinfo
-
- @staticmethod
- def qgPad(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
- if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
- else:
- qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
- return qinfo
-
- @staticmethod
- def computeMultiplierAndShift(scaleFp, scale32):
- # Derived from computeMultiplierAndShiftTosaScale32
- # Provide a floating-point scaling factor and the scale32 parameter
- # to compute the multiplier and shift
-
- if scale32:
- scaleBits = 31
- else:
- scaleBits = 15
-
- m, shift = math.frexp(scaleFp)
-
- if scaleFp < 0.0:
- m = -m
-
- multiplier = round(m * (1 << scaleBits))
- assert multiplier <= (1 << scaleBits)
-
- if multiplier == (1 << scaleBits):
- multiplier = multiplier // 2
- shift = shift + 1
-
- shift = (-shift) + scaleBits
- # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
- # scaleFp, scaleBits, m, multiplier, shift))
-
- # Adjust multiplier such that shift is in allowed value range.
- if shift == 0:
- multiplier = multiplier // 4
- shift = shift + 2
- elif shift == 1:
- multiplier = multiplier // 2
- shift = shift + 1
- elif shift == 63:
- multiplier = multiplier * 2
- shift = shift - 1
-
- assert multiplier <= (1 << scaleBits)
- assert shift >= 2 and shift <= 62
-
- return multiplier, shift
-
-
-class TosaTensorGen:
- """Tensor generators create a shape list for the placeholder and const tensor
- data operands for the operator.
-
- The actual random data is generated separately for each test.
- """
-
- def __init__(self):
- pass
-
- @staticmethod
- def tgBasic(testGen, opName, rank, error_name=None):
- pl, const = opName["operands"]
- shape = testGen.makeShape(rank)
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name:
- shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
-
- shape_list = []
- for i in range(pl + const):
- shape_list.append(shape.copy())
-
- if error_name == ErrorIf.RankMismatch:
- if rank == 1 and i != 1:
- shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
- elif i != 1:
- shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
-
- return shape_list
-
- @staticmethod
- def tgNHWC(testGen, opName, rank, error_name=None):
- pl, const = opName["operands"]
-
- if error_name != ErrorIf.WrongRank:
- assert rank == 4
-
- shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name and error_name != ErrorIf.MaxDimExceeded:
- shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
-
- shape_list = []
- for i in range(pl + const):
- shape_list.append(shape.copy())
-
- return shape_list
-
- @staticmethod
- def tgScatter(testGen, opName, rank, error_name=None):
- pl, const = opName["operands"]
-
- assert pl == 2
- assert const == 0
- if error_name != ErrorIf.WrongRank:
- assert rank == 3
-
- values_in_shape = testGen.makeShape(rank)
-
- # ignore max batch size if target shape is set
- if testGen.args.max_batch_size and not testGen.args.target_shapes:
- values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
-
- W = testGen.randInt(
- testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
- )
- # Constrict W if one dimension is too large to keep tensor size reasonable
- if max(values_in_shape) > 5000:
- W = testGen.randInt(0, 16)
-
- input_shape = [values_in_shape[0], W, values_in_shape[2]]
-
- shape_list = []
- shape_list.append(values_in_shape.copy())
- shape_list.append(input_shape.copy())
-
- return shape_list
-
- @staticmethod
- def tgBroadcastFuzz(testGen, op, rank, error_name=None):
- shape = testGen.makeShape(rank)
-
- pl, const = op["operands"]
-
- shape_list = []
-
- # Choose one of the inputs to broadcast
- # Note: Simplifies OutputShaper code if we don't change first shape for errors
- bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
- for i in range(pl + const):
- shape_bcast = shape.copy()
-
- # If the chosen input, pick a random index to broadcast
- if i == bcast_idx:
- fuzz_idx = testGen.randInt(0, rank)
- if error_name == ErrorIf.DimensionMismatch:
- shape_bcast[fuzz_idx] += 1
- elif error_name == ErrorIf.RankMismatch:
- # Add one rank to the shape (or more for rank of 1)
- extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
- shape_bcast = np.concatenate(
- (shape_bcast, testGen.makeShape(extra_ranks))
- )
- if rank != 1:
- # Either keep the extra rank, or remove it
- new_len = testGen.rng.choice([-2, len(shape_bcast)])
- shape_bcast = shape_bcast[:new_len]
- else:
- shape_bcast[fuzz_idx] = 1
-
- shape_list.append(shape_bcast)
-
- return shape_list
-
- @staticmethod
- def tgConv2D(testGen, op, rank, error_name=None):
- pl, const = op["operands"]
-
- if error_name != ErrorIf.WrongRank:
- assert rank == 4
-
- # IFM dimensions are NHWC
- ifm_shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name:
- ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
- ifm_shape, max_dim=24, max_items=10000
- )
-
- # Get the filter height/width from the operator parameters
- filter_hw = op["filter"]
-
- # Generate a random OFM depth
- ofm_depth = testGen.makeShape(1)[0]
-
- # The filter dimensions are OHWI
- filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
-
- # The bias is OC
- bias_shape = np.asarray([ofm_depth])
-
- return [ifm_shape, filter_shape, bias_shape]
-
- @staticmethod
- def tgConv3D(testGen, op, rank, error_name=None):
- pl, const = op["operands"]
-
- if error_name != ErrorIf.WrongRank:
- assert rank == 5
-
- # IFM dimensions are NDHWC
- ifm_shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name:
- ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
- ifm_shape, max_dim=24, max_items=10000
- )
-
- # Get the filter depth/height/width from the operator parameters
- filter_dhw = op["filter"]
-
- # Generate a random OFM channel
- ofm_channel = testGen.makeShape(1)[0]
-
- # The filter dimensions are ODHWI
- filter_shape = np.asarray(
- [ofm_channel, filter_dhw[0], filter_dhw[1], filter_dhw[2], ifm_shape[4]]
- )
-
- # The bias is OC
- bias_shape = np.asarray([ofm_channel])
-
- return [ifm_shape, filter_shape, bias_shape]
-
- @staticmethod
- def tgTransposeConv2D(testGen, op, rank, error_name=None):
- pl, const = op["operands"]
-
- if error_name != ErrorIf.WrongRank:
- assert rank == 4
-
- # IFM dimensions are NHWC
- ifm_shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name:
- ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
- ifm_shape, max_dim=24, max_items=10000
- )
-
- # Get the filter height/width from the operator parameters
- filter_hw = op["filter"]
-
- # Generate a random OFM depth
- ofm_depth = testGen.makeShape(1)[0]
-
- # The filter dimensions are OHWI
- filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
-
- # The bias is OC
- bias_shape = np.asarray([ofm_depth])
-
- return [ifm_shape, filter_shape, bias_shape]
-
- @staticmethod
- def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
- pl, const = op["operands"]
-
- if error_name != ErrorIf.WrongRank:
- assert rank == 4
- assert pl == 1 and const == 2
-
- # IFM dimensions are NHWC
- ifm_shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name:
- ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(
- ifm_shape, max_dim=24, max_items=10000
- )
-
- # Get the filter height/width from the operator parameters
- # Filter is KH, HW, C, M
- filter_hw = op["filter"]
-
- # Generate a random OFM depth, but don't let it get too big because
- # the output depth is M * C
- filter_m = (
- testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
- ) + 1
-
- # The filter dimensions are HWCM
- filter_shape = np.asarray([filter_hw[0], filter_hw[1], ifm_shape[3], filter_m])
-
- # The bias is M * C
- bias_shape = np.asarray([ifm_shape[3] * filter_m])
-
- return [ifm_shape, filter_shape, bias_shape]
-
- @staticmethod
- def tgFullyConnected(testGen, op, rank, error_name=None):
- pl, const = op["operands"]
-
- if error_name != ErrorIf.WrongRank:
- assert rank == 2
-
- input_shape = testGen.makeShape(rank)
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name:
- input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
-
- filter_oc = testGen.rng.integers(
- low=testGen.args.tensor_shape_range[0],
- high=testGen.args.tensor_shape_range[1],
- size=1,
- )[0]
- filter_shape = np.asarray([filter_oc, input_shape[1]])
-
- bias_shape = np.asarray([filter_oc])
-
- return [input_shape, filter_shape, bias_shape]
-
- @staticmethod
- def tgMatmul(testGen, op, rank, error_name=None):
- pl, const = op["operands"]
-
- if error_name != ErrorIf.WrongRank:
- assert rank == 3
- assert pl == 2 and const == 0
-
- a_shape = testGen.makeShape(rank)
-
- # Constrict the overall size of the shape when creating ERROR_IF tests
- if error_name:
- a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
-
- # Get a random number for b_oc even if target shape is defined
- b_oc = np.int32(
- testGen.rng.integers(
- low=testGen.args.tensor_shape_range[0],
- high=testGen.args.tensor_shape_range[1],
- size=1,
- )
- )[0]
- # If N or H is large let b_oc be 1 to reduce output tensor size
- if max(a_shape) > 1000:
- b_oc = 1
-
- b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
- return [a_shape, b_shape]
-
- @staticmethod
- def tgConcat(testGen, opName, rank, error_name=None):
- pl, const = opName["operands"]
- shape = testGen.makeShape(rank)
-
- # Create extra tensors to concat.
- # Take into account value of pl when getting maximum number of concats
- num_tensors = testGen.randInt(0, 4)
- shape_list = []
- for i in range(pl + const + num_tensors):
- if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
- remove = testGen.rng.choice([True, False])
- wrongShape = shape.copy()
-
- if remove and len(shape) > 1:
- wrongShape = wrongShape[1:]
- else:
- wrongShape = list(wrongShape)
- wrongShape.append(testGen.rng.integers(1, 10))
-
- shape_list.append(wrongShape)
- else:
- shape_list.append(shape.copy())
-
- return shape_list
-
- @staticmethod
- def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
- if error_name in [
- ErrorIf.AxisSmallerZero,
- ErrorIf.AxisLargerRank,
- ErrorIf.ConcatInputRankMismatch,
- ]:
- return shapeList
-
- # Split concat shape along axis to allow for multiple const inputs
- # without making too many large tensors
- if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
- # If axis can't be split we still need to invalidate other dimensions
- if error_name == ErrorIf.ConcatInputDimMismatch:
- for shape in shapeList[1:]:
- # Negative test shapeLists are created individually for each test,
- # so no need to copy the shape before altering it.
- shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
- return shapeList
-
- # Create copy of shape we are going to split (so we don't alter shapeList)
- shape = shapeList[0].copy()
- # Add original shape as first input
- new_shapeList = [shape.copy()]
- length_on_axis = shape[axis]
- remaining_length = length_on_axis
- for i in range(len(shapeList) - 2):
- # Calculate split on axis and remaining value
- split_shape_val = int(shape[axis] / 2)
- remaining_length = remaining_length - split_shape_val
-
- # Append new shape, and set remaining shape
- shape[axis] = split_shape_val
- new_shapeList.append(shape.copy())
-
- # invalidate dimensions
- if error_name == ErrorIf.ConcatInputDimMismatch:
- shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
- else:
- shape[axis] = remaining_length
-
- if i == len(shapeList) - 3:
- new_shapeList.append(shape.copy())
-
- return new_shapeList
-
-
-class TosaArgGen:
- """Argument generators create exhaustive or random lists of attributes for
- operators that take attributes or other parameters.
-
- The return value is a list of (descriptive_name, [arglist]) tuples where
- the descriptive_name is appended to the test name and the arglist is expanded
- as arguments to the operator build function.
- """
-
- def __init__(self):
- pass
-
- @staticmethod
- def agNone(testGen, opName, shapeList, dtype, error_name=None):
- """A trivial argument generator for operators that don't take any
- non-tensor arguments"""
- return [("", [])]
-
- @staticmethod
- def agAxis(testGen, opName, shapeList, dtype, error_name=None):
- """Build the axis argument for operators that take a single axis"""
- axes = []
- shape = shapeList[0]
-
- if error_name == ErrorIf.AxisSmallerZero:
- small_axis = testGen.rng.integers(-5, 0)
- axes.append(("axis{}".format(small_axis), [small_axis]))
- elif error_name == ErrorIf.AxisLargerRank:
- large_axis = testGen.rng.integers(len(shape) + 1, len(shape) + 10)
- axes.append(("axis{}".format(large_axis), [large_axis]))
- else:
- for a in range(0, len(shape)):
- axes.append(("axis{}".format(a), [a]))
-
- return axes
-
- @staticmethod
- def agConv(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- ifm_shape = shapeList[0]
- filter_shape = shapeList[1]
- # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
- k = [int(x) for x in opName.split("_")[-1].split("x")]
-
- # Check the rank
- rank = 5 if opName.startswith("conv3d") else 4
- if error_name != ErrorIf.WrongRank:
- assert len(ifm_shape) == rank
- assert len(filter_shape) == rank
-
- # kernel rank omits batch and channels
- k_rank = rank - 2
- assert len(k) == k_rank
-
- # Generate comprehensive argument lists
- # - except for named errors, which use specific invalid value(s)
- if error_name == ErrorIf.PadSmallerZero:
- p_vals = [testGen.rng.choice(range(-5, 0))]
- else:
- p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
- paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
- if error_name == ErrorIf.StrideSmallerOne:
- # Can't use stride=0, as it is used to derive output shape, as a divisor
- s_vals = [testGen.rng.choice(range(-5, 0))]
- else:
- s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
- strides = {x for x in itertools.product(*([s_vals] * k_rank))}
- if error_name == ErrorIf.DilationSmallerOne:
- d_vals = [testGen.rng.choice(range(-5, 1))]
- else:
- d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
- dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
-
- if not error_name and testGen.args.oversize:
- # add some oversize argument values
- if max(ifm_shape) < 64:
- bigPadding = 9
- paddings.update(
- {x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))}
- )
- bigStride = 8
- strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
- bigDilation = 7
- dilations.update(
- {x for x in itertools.product(*([[1, bigDilation]] * k_rank))}
- )
-
- # There are too many parameter combinations, so generate them sparsely,
- # very sparse for negative tests
- sparsity_factor = 2 if error_name else 100
- sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
- # If there are only a small number of tests, just select them all
- if sparsity < 13:
- sparsity = 1
- # To get a variety of parameter combinations sparsity should not be a
- # multiple of 2, 3 or 5
- while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
- sparsity += 1
-
- n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- for d in sorted(list(dilations)):
- if (
- n % sparsity == 0
- # padding must not exceed the kernel size ?
- # and p[0] < k[0] and p[1] < k[0]
- # and p[2] < k[1] and p[3] < k[1]
- # and (k_rank < 3 or (p[4] < k[2] and p[5] < k[2]))
- # the padded shape must exceed the kernel size
- and (ifm_shape[1] + p[0] + p[1]) > k[0]
- and (ifm_shape[2] + p[2] + p[3]) > k[1]
- and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > k[2]))
- # the padded shape must exceed the dilation
- and (ifm_shape[1] + p[0] + p[1]) > d[0]
- and (ifm_shape[2] + p[2] + p[3]) > d[1]
- and (k_rank < 3 or ((ifm_shape[3] + p[4] + p[5]) > d[2]))
- ):
- arg_list.append(
- (
- "st{}_pad{}_dilat{}".format(
- "".join([str(x) for x in s]),
- "".join([str(x) for x in p]),
- "".join([str(x) for x in d]),
- ),
- [s, p, d],
- )
- )
- n += 1
-
- return arg_list
-
- @staticmethod
- def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- ifm_shape = shapeList[0]
- filter_shape = shapeList[1]
-
- # Must be rank 4
- if error_name != ErrorIf.WrongRank:
- assert len(ifm_shape) == 4
- assert len(filter_shape) == 4
-
- # Generate comprehensive argument lists
- # - except for named errors, which use specific invalid value(s)
- if error_name == ErrorIf.PadSmallerZero:
- p_vals = [testGen.rng.choice(range(-5, 0))]
- else:
- p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
- paddings = {x for x in itertools.product(*([p_vals] * 2))}
- if error_name == ErrorIf.StrideSmallerOne:
- # Can't use stride=0, as it is used to derive output shape, as a divisor
- s_vals = [testGen.rng.choice(range(-5, 0))]
- else:
- s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
- strides = {x for x in itertools.product(*([s_vals] * 2))}
- if error_name == ErrorIf.DilationSmallerOne:
- d_vals = [testGen.rng.choice(range(-5, 1))]
- else:
- d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
- dilations = {x for x in itertools.product(*([d_vals] * 2))}
-
- if not error_name:
- # add some oversize argument values
- if max(ifm_shape) < 64:
- bigPadding = 9
- paddings.update(
- {x for x in itertools.product(*([[0, bigPadding]] * 2))}
- )
- bigStride = 8
- strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
- bigDilation = 7
- dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
-
- # There are too many parameter combinations, so generate them sparsely,
- # very sparse for negative tests
- sparsity_factor = 2 if error_name else 100
- sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
- # If there are only a small number of tests, just select them all
- if sparsity < 13:
- sparsity = 1
- # To get a variety of parameter combinations sparsity should not be a
- # multiple of 2, 3 or 5
- while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
- sparsity += 1
-
- n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- for d in sorted(list(dilations)):
- if n % sparsity == 0:
- # Determine the output shape
- oh = (
- ifm_shape[1]
- - filter_shape[1]
- - (filter_shape[1] - 1) * (d[0] - 1)
- + 2 * p[0]
- ) // s[0] + 1
- ow = (
- ifm_shape[2]
- - filter_shape[2]
- - (filter_shape[2] - 1) * (d[1] - 1)
- + 2 * p[1]
- ) // s[1] + 1
- os = [ifm_shape[0], oh, ow, filter_shape[0]]
- arg_list.append(
- (
- "st{}_pad{}_dilat{}_os{}".format(
- "".join([str(x) for x in s]),
- "".join([str(x) for x in p]),
- "".join([str(x) for x in d]),
- "x".join([str(x) for x in os]),
- ),
- [s, p, d, os],
- )
- )
- n += 1
-
- return arg_list
-
- @staticmethod
- def agPad(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
- rank = len(shapeList[0])
-
- # Exhaustively test combinations of padding on each side of each dimension
- # - the range of padding values is defined by pad_min and pad_max
- # - for padding >9, the name format needs to be more distinctive
- pad_min, pad_max = 0, 1
- pad_values = [x for x in range(pad_min, pad_max + 1)]
- if error_name == ErrorIf.PadSmallerZero:
- pad_values = [x for x in range(-2, 0)]
- axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
- shape_pad_values = itertools.product(*([axis_pad_values] * rank))
-
- if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
- pad_const_int = testGen.getRandNumberDType(dtype)
- pad_const_fp = 0
- elif dtype == DType.FLOAT:
- pad_const_int = 0
- pad_const_fp = testGen.getRandNumberDType(dtype)
- else:
- return []
-
- for paddings in shape_pad_values:
- name = "pad"
- for r in range(rank):
- before, after = paddings[r]
- name = f"{name}{before}{after}"
- arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
-
- return arg_list
-
- @staticmethod
- def agPooling(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- shape = shapeList[0]
- if error_name != ErrorIf.WrongRank:
- assert len(shape) == 4
-
- # Generate comprehensive argument lists
- p_vals = [x for x in range(0, testGen.args.max_pooling_padding + 1)]
- paddings = {x for x in itertools.product(*([p_vals] * 4))}
- s_vals = [x for x in range(1, testGen.args.max_pooling_stride + 1)]
- strides = {x for x in itertools.product(*([s_vals] * 2))}
- k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
- kernels = {x for x in itertools.product(*([k_vals] * 2))}
-
- if testGen.args.oversize:
- # add some oversize argument values
- bigStride = 7
- strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
- bigKernel = 6
- kernels.update({x for x in itertools.product(*([[2, bigKernel]] * 2))})
- if max(shape) < 64:
- # padding must be less than the kernel size
- bigPadding = bigKernel - 1
- paddings.update(
- {x for x in itertools.product(*([[0, bigPadding]] * 4))}
- )
-
- # There are too many parameter combinations, so generate them sparsely,
- # very sparse for negative tests
- sparsity_factor = 2 if error_name else 500
- sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
-
- n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- for k in sorted(list(kernels)):
- if error_name in [
- ErrorIf.StrideSmallerOne,
- ErrorIf.KernelSmallerOne,
- ErrorIf.PadSmallerZero,
- ErrorIf.PadLargerEqualKernel,
- ]:
- sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
- testGen, error_name, s, p, k
- )
- if None not in [sNew, pNew, kNew] and n % sparsity == 0:
- arg_list.append(
- (
- "st{}_kern{}_pad{}".format(
- "".join([str(x) for x in sNew]),
- "".join([str(x) for x in kNew]),
- "".join([str(x) for x in pNew]),
- ),
- [sNew, pNew, kNew],
- )
- )
- elif (
- n % sparsity == 0
- # padding must not exceed the kernel size
- and p[0] < k[0]
- and p[1] < k[0]
- and p[2] < k[1]
- and p[3] < k[1]
- # the padded shape must exceed the kernel size
- and (shape[1] + p[0] + p[1]) > k[0]
- and (shape[2] + p[2] + p[3]) > k[1]
- ):
- arg_list.append(
- (
- "st{}_kern{}_pad{}".format(
- "".join([str(x) for x in s]),
- "".join([str(x) for x in k]),
- "".join([str(x) for x in p]),
- ),
- [s, p, k],
- )
- )
- n += 1
-
- return arg_list
-
- @staticmethod
- def agCast(testGen, opName, shapeList, inDtype, error_name=None):
- arg_list = []
-
- # Enumerate the output types here
- if error_name == ErrorIf.WrongOutputType:
- dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
- elif inDtype == DType.INT8:
- dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
- elif inDtype == DType.INT16:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
- elif inDtype == DType.INT32:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
- elif inDtype == DType.BOOL:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
- elif inDtype == DType.FLOAT:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32]
- elif error_name == ErrorIf.WrongInputType:
- # Pick some potentially correct output type for incorrect input type
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
- else:
- raise Exception("Unexpected input dtype: {}".format(inDtype))
-
- for dtype in dtypeList:
- arg_list.append(("out{}".format(DTypeNames[dtype]), [dtype]))
-
- return arg_list
-
- @staticmethod
- def agRescale(testGen, opName, shapeList, inDtype, error_name=None):
- arg_list = []
-
- # Enumerate the output types here
- for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
- if (
- dtype in [DType.UINT8, DType.INT8]
- and error_name == ErrorIf.OutputZeroPointNotZero
- ):
- continue
- if (
- inDtype == DType.UINT8
- and dtype != DType.INT8
- and error_name != ErrorIf.WrongOutputType
- ):
- # The only output dtype for UINT8 is INT8, skip all other combinations
- continue
- if (
- inDtype != DType.INT8
- and dtype == DType.UINT8
- and error_name != ErrorIf.WrongOutputType
- ):
- # The only input dtype for UINT8 is INT8, skip all other combinations
- continue
- if (
- error_name == ErrorIf.WrongOutputType
- and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype)
- ):
- continue
-
- for scale32 in [False, True]:
- if error_name == ErrorIf.ScaleTrue and not scale32:
- continue
- elif error_name == ErrorIf.ScaleNotTrue and scale32:
- continue
- for double_round in [False, True]:
- if error_name == ErrorIf.ScaleNotTrue and not double_round:
- continue
- for per_channel in [False, True]:
-
- if (
- inDtype == DType.INT48
- and scale32
- and error_name != ErrorIf.ScaleTrue
- ):
- # Illegal condition. Must be scale32=False
- continue
- if (
- double_round
- and not scale32
- and error_name != ErrorIf.ScaleNotTrue
- ):
- # Illegal condition. ERROR_IF(!scale32 && double_round)
- continue
-
- arg_list.append(
- (
- "out{}_sc{}_dr{}_pc{}".format(
- DTypeNames[dtype],
- int(scale32),
- int(double_round),
- int(per_channel),
- ),
- [dtype, scale32, double_round, per_channel],
- )
- )
-
- return arg_list
-
- @staticmethod
- def agMul(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- if dtype is DType.INT32:
- for p in range(testGen.args.num_rand_permutations):
-
- shift = testGen.randInt(0, 32)
-
- arg_list.append(("perm{}_shift{}".format(p, shift), [shift]))
- else:
- arg_list.append(("perm0_shift0", [0]))
-
- return arg_list
-
- @staticmethod
- def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- arg_list.append(("roundTrue", [True]))
- arg_list.append(("roundFalse", [False]))
-
- return arg_list
-
- # Helper function for reshape. Gets some factors of a larger number.
- @staticmethod
- def getFactors(val, start=1):
- factors = []
-
- for i in range(start, int(np.sqrt(val)) + 1):
- if (val % i) == 0:
- factors.append(i)
-
- return factors
-
- @staticmethod
- def agReshape(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- origShape = shapeList[0]
-
- totalElements = 1
- for s in origShape:
- totalElements *= s
-
- # This code is NOT fast. Fortunately, the numbers are fairly small.
- factors = TosaArgGen.getFactors(totalElements)
-
- for p in range(testGen.args.num_rand_permutations):
- newRank = testGen.randInt(1, 7)
- if len(factors) < newRank:
- continue
-
- found = True
- # escape_counter breaks while loop if it continues on for too long
- escape_counter = 0
- while found:
- newShape = []
- # Generate newShape ensuring it isn't a duplicate
- remainingElements = totalElements
- shuffledFactors = testGen.rng.permutation(factors)
- for i in range(1, newRank):
- # pick rank-1 factors
- newShape.append(shuffledFactors[0])
- remainingElements = remainingElements // shuffledFactors[0]
- shuffledFactors = testGen.rng.permutation(
- TosaArgGen.getFactors(remainingElements)
- )
- newShape.append(remainingElements)
-
- # Toss in a -1 sometimes
- minusOne = testGen.randInt(0, newRank * 4)
- if minusOne < newRank:
- newShape[minusOne] = -1
-
- # Check for duplicates
- found = False
- for name, other_shape in arg_list:
- if other_shape[0] == newShape:
- found = True
- break
-
- escape_counter += 1
- if escape_counter >= 100:
- break
-
- if not found:
- arg_list.append(("perm{}_rank{}".format(p, newRank), [newShape]))
-
- return arg_list
-
- @staticmethod
- def agTranspose(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- ifm_shape = shapeList[0]
-
- if error_name == ErrorIf.IndexOutsideBounds:
- incorrect_large_index = range(len(ifm_shape) + 1, 2 * len(ifm_shape) + 1)
- incorrect_small_index = range(-len(ifm_shape), 0)
- permutations = [p for p in itertools.permutations(incorrect_large_index)]
- permutations.extend(
- [p for p in itertools.permutations(incorrect_small_index)]
- )
- elif error_name == ErrorIf.IndexUsedTwice:
- # Create list with a duplicated index
- perm_range = list(range(len(ifm_shape)))
- index_choice = testGen.rng.choice(range(len(perm_range)))
- perm_range[(index_choice + 1) % len(perm_range)] = perm_range[index_choice]
- permutations = [p for p in itertools.permutations(perm_range)]
-
- else:
- # Get all permutations
- permutations = [p for p in itertools.permutations(range(len(ifm_shape)))]
-
- # Limit to possible permutations from shape dimension or argument setting
- limit = min(len(permutations), testGen.args.num_rand_permutations)
-
- # Get random permutation generator that uses all permutations
- random_permutations = testGen.rng.permutation(permutations)
-
- # Create list of required amount of permutations
- arg_list = [
- ("perm{}".format(p), [random_permutations[p].tolist()])
- for p in range(limit)
- ]
- return arg_list
-
- @staticmethod
- def agSlice(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- ifm_shape = shapeList[0]
- rank = len(ifm_shape)
-
- for p in range(testGen.args.num_rand_permutations):
- start = []
- size = []
-
- valid = True
-
- for i in range(rank):
- if ifm_shape[i] > 1:
- start.append(testGen.randInt(0, ifm_shape[i]))
- size.append(testGen.randInt(0, ifm_shape[i] - start[i]))
-
- # Invalid slice size?
- if size[i] == 0:
- valid = False
- else:
- start.append(0)
- size.append(1)
-
- if valid:
- # If ERROR_IF test required then incorrect start, size will be returned
- start, size = TosaErrorIfArgGen.eiSliceErrorIf(
- testGen, error_name, ifm_shape, start, size
- )
- arg_list.append(("perm{}".format(p), [start, size]))
- return arg_list
-
- @staticmethod
- def agTile(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- ifm_shape = shapeList[0]
- rank = len(ifm_shape)
-
- for p in range(testGen.args.num_rand_permutations):
-
- # Pick a few random, but small multiple values
- # because otherwise this has a tendency to generate
- # enormous tensors
- multiples = []
- for i in range(rank):
- if ifm_shape[i] > 1000:
- # Multiple of 1 if ifm_shape dimension is large to reduce
- # tensor size
- multiples.append(1)
- elif max(ifm_shape) > 1000:
- multiples.append(2)
- else:
- multiples.append(testGen.randInt(1, 4))
- arg_list.append(("perm{}".format(p), [multiples]))
-
- return arg_list
-
- @staticmethod
- def agResize(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- ifm_shape = shapeList[0]
- for mode in [ResizeMode.NEAREST, ResizeMode.BILINEAR]:
-
- # Exclude illegal {mode, type} configurations. Pick legal output types
- if mode == ResizeMode.NEAREST and dtype == DType.INT8:
- outputDTypeList = [DType.INT8]
- elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
- outputDTypeList = [DType.INT16]
- elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
- outputDTypeList = [DType.INT32]
- elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
- outputDTypeList = [DType.INT48]
- elif dtype == DType.FLOAT:
- outputDTypeList = [DType.FLOAT]
- elif error_name == ErrorIf.WrongInputType:
- # If an incorrect input type is used then we set a 'correct'
- # output type to avoid other errors
- outputDTypeList = [DType.INT8, DType.INT16, DType.INT32]
- else:
- continue
-
- for outputDType in outputDTypeList:
- for perm in range(testGen.args.num_rand_permutations):
- # Randomly generate legal output dimensions and shift
- # and then compute the stride and offset based on them
- # A output_dim of 1 will cause offset to exceed allowed range
- # so minimum value 2 produced below
- output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
- while (float(ifm_shape[1]) / float(output_dims[0])) >= 16:
- output_dims[0] += 1
- while (float(ifm_shape[2]) / float(output_dims[1])) >= 16:
- output_dims[1] += 1
-
- in_center_h = (ifm_shape[1] - 1) / 2.0
- in_center_w = (ifm_shape[2] - 1) / 2.0
- out_center_h = (output_dims[0] - 1) / 2.0
- out_center_w = (output_dims[1] - 1) / 2.0
-
- fp_stride_y = float(ifm_shape[1]) / float(output_dims[0])
- fp_stride_x = float(ifm_shape[2]) / float(output_dims[1])
- fp_offset_y = in_center_h - fp_stride_y * out_center_h
- fp_offset_x = in_center_w - fp_stride_x * out_center_w
-
- if outputDType == DType.FLOAT:
- float_op = True
- arg_str = (
- "mode{}_shift{}_odim{}x{}_out{}"
- "_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}"
- )
- shift = 0
- stride = [0, 0]
- offset = [0, 0]
- stride_fp = [fp_stride_y, fp_stride_x]
- offset_fp = [fp_offset_y, fp_offset_x]
-
- else:
- float_op = False
- arg_str = "mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}"
- shift = testGen.randInt(1, 12)
- # Now search for a shift value (1 to 11) that will produce
- # a valid and predictable resize operation
- count = 0
- while count < 12:
- unit = float(1 << shift)
- stride_y = int(round(fp_stride_y * unit))
- stride_x = int(round(fp_stride_x * unit))
- offset_y = int(round(fp_offset_y * unit))
- offset_x = int(round(fp_offset_x * unit))
-
- if (
- stride_y <= 0
- or stride_x <= 0
- or stride_y >= (16 << shift)
- or stride_x >= (16 << shift)
- or offset_y >= (16 << shift)
- or offset_x >= (16 << shift)
- or offset_y <= (-16 << shift)
- or offset_x <= (-16 << shift)
- ):
- # Change the shift value and check again
- count += 1
- shift = (shift % 11) + 1
- continue
-
- def RESIZE_REQUIRE_CALC(
- length_in, length_out, stride, offset, shift
- ):
- # Perform the pseudo loop to look for out of bounds
- for pos in range(0, length_out):
- a = pos * stride + offset
- ia = a >> shift
- ia0 = max(ia, 0)
- ia1 = min(ia + 1, length_in - 1)
- if ia0 > ia1:
- # Found a problem value
- break
- return ia0, ia1
-
- iy0, iy1 = RESIZE_REQUIRE_CALC(
- ifm_shape[1], output_dims[0], stride_y, offset_y, shift
- )
- ix0, ix1 = RESIZE_REQUIRE_CALC(
- ifm_shape[2], output_dims[1], stride_x, offset_x, shift
- )
- if ix0 > ix1 or iy0 > iy1:
- # Change the shift value and check again
- count += 1
- shift = (shift % 11) + 1
- continue
- break
-
- if count >= 12:
- # Couldn't find a good set of values for this test, skip it
- continue
-
- stride = [stride_y, stride_x]
- offset = [offset_y, offset_x]
-
- stride_fp = [0.0, 0.0]
- offset_fp = [0.0, 0.0]
-
- # Common for all data types
- if error_name is not None:
- (
- shift,
- stride,
- stride_fp,
- offset,
- offset_fp,
- outputDTypeNew,
- ) = TosaErrorIfArgGen.eiResizeErrorIf(
- testGen,
- error_name,
- mode,
- dtype,
- shapeList,
- outputDType,
- shift,
- stride,
- stride_fp,
- offset,
- offset_fp,
- )
- else:
- outputDTypeNew = outputDType
-
- arg_list.append(
- (
- arg_str.format(
- "N" if mode == ResizeMode.NEAREST else "B",
- shift,
- output_dims[0],
- output_dims[1],
- testGen.typeStr(outputDTypeNew),
- stride_fp[0] if float_op else stride[0],
- stride_fp[1] if float_op else stride[1],
- offset_fp[0] if float_op else offset[0],
- offset_fp[1] if float_op else offset[1],
- ),
- [
- mode,
- stride,
- offset,
- shift,
- stride_fp,
- offset_fp,
- output_dims,
- dtype,
- outputDTypeNew,
- ],
- )
- )
-
- return arg_list
-
- @staticmethod
- def agTable(testGen, opName, shapeList, dtype, error_name=None):
- arg_list = []
-
- if dtype == DType.INT8:
- table = np.int32(
- testGen.rng.integers(low=-128, high=128, size=[256])
- ).tolist()
- else: # INT16
- table = np.int32(
- testGen.rng.integers(low=-32768, high=32768, size=[513])
- ).tolist()
-
- arg_list.append(
- (
- "",
- [table],
- )
- )
- return arg_list
-
- def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
- # CondIf generates the condition values here.
- # Convert to tensors in the build function, along with the
- # then and else blocks
- arg_list = []
-
- for c in [False, True]:
- arg_list.append(("cond{}".format(int(c)), [c]))
-
- return arg_list
-
- def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
- # While loop: 0 iterations, 1, more than 1
- arg_list = []
-
- for iter in [0, 1, 4]:
- arg_list.append(("iter{}".format(iter), [iter]))
-
- return arg_list
-
-
-class TosaErrorIfArgGen:
- @staticmethod
- def eiResizeErrorIf(
- testGen,
- error_name,
- mode,
- dtype,
- shapeList,
- outputDType,
- shift,
- stride,
- stride_fp,
- offset,
- offset_fp,
- ):
-
- if outputDType == DType.FLOAT:
- if error_name == ErrorIf.StrideSmallerEqualZero:
- stride_fp = testGen.rng.random(size=[2]) - 2
- elif error_name == ErrorIf.ShiftNotZero:
- shift = testGen.rng.integers(1, 5)
- elif error_name == ErrorIf.StrideLargerDimension:
- shape = shapeList[0]
- transform_height = testGen.rng.choice([False, True])
- if transform_height:
- stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
- else:
- stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
- else:
- if error_name == ErrorIf.StrideSmallerEqualZero:
- stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
- elif error_name == ErrorIf.ShiftSmallerOne:
- shift = testGen.rng.integers(-3, 1)
- if shift <= 0:
- stride = [
- (16 >> -shift) - 1,
- (16 >> -shift) - 1,
- ] # avoids other ERROR_IF checks
- offset = [
- (16 >> -shift) - 1,
- (16 >> -shift) - 1,
- ] # avoids other ERROR_IF checks
- else:
- stride = [
- (16 << shift) - 1,
- (16 << shift) - 1,
- ] # avoids other ERROR_IF checks
- offset = [
- (16 << shift) - 1,
- (16 << shift) - 1,
- ] # avoids other ERROR_IF checks
- elif error_name == ErrorIf.ShiftLargerEleven:
- shift = np.int16(testGen.rng.integers(12, 15))
- elif error_name == ErrorIf.StrideLargerDimension:
- shape = shapeList[0]
- transform_height = testGen.rng.choice([False, True])
- if transform_height:
- stride[0] = shape[1] + testGen.rng.integers(1, 10)
- else:
- stride[1] = shape[2] + testGen.rng.integers(1, 10)
- elif error_name == ErrorIf.StrideLargerEqualMax:
- stride = [(16 << shift) + 1, (16 << shift) + 1]
- elif error_name == ErrorIf.OffsetLargerEqualMax:
- offset = [(16 << shift) + 1, (16 << shift) + 1]
- elif error_name == ErrorIf.OffsetSmallerEqualMin:
- offset = [(-16 << shift) - 1, (-16 << shift) - 1]
-
- if error_name == ErrorIf.WrongOutputType:
- if mode == ResizeMode.NEAREST and dtype == DType.INT8:
- incorrect_types = (
- DType.INT4,
- DType.INT16,
- DType.INT32,
- DType.INT48,
- DType.FLOAT,
- )
- elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
- incorrect_types = (
- DType.INT4,
- DType.INT8,
- DType.INT32,
- DType.INT48,
- DType.FLOAT,
- )
- elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
- incorrect_types = (
- DType.INT4,
- DType.INT8,
- DType.INT16,
- DType.INT48,
- DType.FLOAT,
- )
- elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
- incorrect_types = (
- DType.INT4,
- DType.INT8,
- DType.INT16,
- DType.INT32,
- DType.FLOAT,
- )
- elif dtype == DType.FLOAT:
- incorrect_types = (
- DType.INT4,
- DType.INT8,
- DType.INT16,
- DType.INT32,
- DType.INT48,
- )
- outputDType = testGen.rng.choice(a=incorrect_types)
-
- return shift, stride, stride_fp, offset, offset_fp, outputDType
-
- @staticmethod
- def eiPoolingErrorIf(testGen, error_name, stride, pad, kernel):
- if (
- error_name == ErrorIf.StrideSmallerOne
- # padding must not exceed the kernel size
- and pad[0] < kernel[0]
- and pad[1] < kernel[0]
- and pad[2] < kernel[1]
- and pad[3] < kernel[1]
- ):
- wrongStride = (
- testGen.rng.choice([0, -1, -2, -3]),
- testGen.rng.choice([0, -1, -2, -3]),
- )
- return wrongStride, pad, kernel
- elif error_name == ErrorIf.PadSmallerZero:
- wrongPad = (
- testGen.rng.choice([-1, -2, -3]),
- testGen.rng.choice([-1, -2, -3]),
- testGen.rng.choice([-1, -2, -3]),
- testGen.rng.choice([-1, -2, -3]),
- )
- return stride, wrongPad, kernel
- elif error_name == ErrorIf.KernelSmallerOne:
- wrongKernel = (
- testGen.rng.choice([0, -1, -2, -3]),
- testGen.rng.choice([0, -1, -2, -3]),
- )
- return stride, pad, wrongKernel
- elif error_name == ErrorIf.PadLargerEqualKernel:
- wrongPad = (
- testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
- testGen.rng.choice([kernel[0], kernel[0] + 1, kernel[0] + 2]),
- testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
- testGen.rng.choice([kernel[1], kernel[1] + 1, kernel[1] + 2]),
- )
- return stride, wrongPad, kernel
- else:
- return None, None, None
-
- @staticmethod
- def eiRescaleWrongOutputType(input_dtype, output_dtype):
- if input_dtype == DType.INT8:
- if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
- return True
- if input_dtype in [DType.INT16, DType.INT32]:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- return True
- elif input_dtype == DType.INT48:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- return True
- elif input_dtype == DType.UINT8:
- if output_dtype != DType.INT8:
- return True
- return False
-
- @staticmethod
- def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
- # Mess up input/output tensors for ERROR_IF checks
- if error_name == "WrongInputList":
- add_input = testGen.rng.choice([True, False])
- if add_input:
- input_list.append("eiDummyInput")
- else:
- input_list = input_list[:-1]
- elif error_name == "WrongOutputList":
- add_output = testGen.rng.choice([True, False])
- if add_output:
- output_list.append("eiDummyOutput")
- else:
- output_list = []
- return input_list, output_list
-
- @staticmethod
- def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
- """Restrict the dimensions and overall size of a shape to
- max_dim and max_items.
- """
- new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
- while product(new_shape) > max_items:
- new_shape = [max(d - 1, 1) for d in new_shape]
- return new_shape
-
- def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
- if error_name == ErrorIf.StartSmallerZero:
- newStart = []
- for i in range(len(input_shape)):
- newStart.append(testGen.rng.choice([-3, -2, -1]))
- return newStart, size
- elif error_name == ErrorIf.SizeSmallerEqualZero:
- newSize = []
- for i in range(len(input_shape)):
- newSize.append(testGen.rng.choice([-3, -2, -1, 0]))
- return start, newSize
- elif error_name == ErrorIf.StartSizeOutsideBounds:
- newStart, newSize = [], []
- for i in range(len(input_shape)):
- newStart.append(input_shape[i] - 1)
- newSize.append(testGen.rng.choice([2, 3, 4]))
- return newStart, newSize
- elif error_name == ErrorIf.InputSizeStartLengthMismatch:
- remove = testGen.rng.choice([True, False])
- if remove:
- newStart = start[1:]
- newSize = size[1:]
- else:
- newStart = start
- newStart.append(1)
- newSize = size
- newSize.append(1)
- return newStart, newSize
- else:
- return start, size
-
- @staticmethod
- def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FLOAT]:
- outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
- elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
- outputDType = [DType.INT48]
- else:
- assert True, f"input_dtype ({input_dtype}) not supported"
- return outputDType
-
-
-class TosaErrorValidator:
- @staticmethod
- def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
- """Check ERROR_IF statements are caught and set the expected result.
-
- Args:
- serializer: the serializer to set the expected result in
- validator_fcns: a sequence of validator functions to verify the result
- error_name: the name of the ERROR_IF condition to check for
- kwargs: keyword arguments for the validator functions
- Returns:
- True if the result matches the expected result; otherwise False
- """
- overall_result = True
- for val_fcn in validator_fcns:
- val_result = val_fcn(True, **kwargs)
- validator_name = val_result["error_name"]
- error_result = val_result["error_result"]
- error_reason = val_result["error_reason"]
-
- # expect an error IFF the error_name and validator_name match
- expected_result = error_result == (error_name == validator_name)
- overall_result &= expected_result
-
- if expected_result and error_result:
- serializer.setExpectedReturnCode(2, True, desc=error_reason)
- elif error_result: # and not expected_result
- print(
- f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
- f" Expected: {error_name}, Got: {validator_name}"
- )
- elif not expected_result: # and not error_result
- print(
- f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
- f" Expected: {error_name}"
- )
-
- if not expected_result:
- for k, v in sorted(kwargs.items()):
- if k != "op":
- if k.endswith("dtype"):
- v = valueToName(DType, v)
- print(f" {k} = {v}")
-
- return overall_result
-
- @staticmethod
- def evWrongInputType(check=False, **kwargs):
- error_result = False
-
- # Find the unsupported input data types
- op = kwargs["op"]
- input_dtypes = op["types"]
- allowed_input_dtypes = {
- t[0] if isinstance(t, list) else t for t in input_dtypes
- }
- wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
-
- if op["op"] == Op.CLAMP:
- wrong_input_dtypes.remove(DType.INT48)
-
- if check:
- input_dtype = kwargs["input_dtype"]
- if input_dtype not in allowed_input_dtypes:
- error_result = True
-
- info_dict = {
- "error_name": ErrorIf.WrongInputType,
- "error_result": error_result,
- "error_reason": "Input data type not supported for this operator",
- "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None},
- }
- return info_dict
-
- @staticmethod
- def evWrongOutputType(check=False, **kwargs):
- error_result = False
-
- if check:
- input_dtype = kwargs["input_dtype"]
- output_dtype = kwargs["output_dtype"]
- op = kwargs["op"]
-
- if op["op"] == Op.RESIZE:
- mode = kwargs["mode"]
- if (
- (
- mode == ResizeMode.NEAREST
- and input_dtype == DType.INT8
- and output_dtype != DType.INT8
- )
- or (
- mode == ResizeMode.NEAREST
- and input_dtype == DType.INT16
- and output_dtype != DType.INT16
- )
- or (
- mode == ResizeMode.BILINEAR
- and input_dtype == DType.INT8
- and output_dtype != DType.INT32
- )
- or (
- mode == ResizeMode.BILINEAR
- and input_dtype == DType.INT16
- and output_dtype != DType.INT48
- )
- or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
- ):
- error_result = True
-
- elif op["op"] == Op.RESCALE:
- if input_dtype == DType.INT8:
- if output_dtype not in [
- DType.UINT8,
- DType.INT8,
- DType.INT16,
- DType.INT32,
- ]:
- error_result = True
- if input_dtype in [DType.INT16, DType.INT32]:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- error_result = True
- elif input_dtype == DType.INT48:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- error_result = True
- elif input_dtype == DType.UINT8:
- if output_dtype != DType.INT8:
- error_result = True
-
- elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
- if (
- (input_dtype == DType.INT8 and output_dtype != DType.INT32)
- or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
- or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
- ):
- error_result = True
-
- elif op["op"] == Op.ARGMAX:
- if (
- input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
- and output_dtype != DType.INT32
- ):
- error_result = True
-
- elif op["op"] == Op.MUL:
- if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
- error_result = True
- elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
- error_result = True
-
- elif op["op"] == Op.TABLE:
- if input_dtype == DType.INT8 and output_dtype != DType.INT8:
- error_result = True
- elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
- error_result = True
-
- elif op["op"] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
- if output_dtype != DType.BOOL:
- error_result = True
-
- elif op["op"] == Op.CAST:
- if (
- (
- input_dtype == DType.BOOL
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
- )
- or (
- input_dtype == DType.INT8
- and output_dtype
- not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
- )
- or (
- input_dtype == DType.INT16
- and output_dtype
- not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
- )
- or (
- input_dtype == DType.INT32
- and output_dtype
- not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
- )
- or (
- input_dtype == DType.FLOAT
- and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
- )
- ):
- error_result = True
-
- elif op["op"] in {
- Op.CONV2D,
- Op.CONV3D,
- Op.DEPTHWISE_CONV2D,
- Op.TRANSPOSE_CONV2D,
- }:
- if (
- input_dtype == DType.INT8
- and output_dtype != DType.INT32
- or input_dtype == DType.INT16
- and output_dtype != DType.INT48
- or input_dtype == DType.FLOAT
- and output_dtype != DType.FLOAT
- ):
- error_result = True
- # invalid input types are ignored, to avoid reporting multiple errors
-
- else:
- if output_dtype != input_dtype:
- error_result = True
-
- info_dict = {
- "error_name": ErrorIf.WrongOutputType,
- "error_result": error_result,
- "error_reason": (
- "Output data type not supported for this configuration of operator"
- ),
- "param_reqs": {"rank": None, "dtype": None, "shape": None},
- }
- return info_dict
-
- @staticmethod
- def evWrongRank(check=False, **kwargs):
- all_ranks = (1, 2, 3, 4, 5)
-
- # Make a list of incorrect ranks
- assert "op" in kwargs
- op = kwargs["op"]
- rmin, rmax = op["rank"]
- rank_range = range(rmin, rmax + 1)
- incorrect_ranks = list(set(all_ranks) - set(rank_range))
- # Remove small incorrect ranks to avoid index errors
- incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
- # Set minimum incorrect rank to 3 to avoid index error
- if op["op"] in [Op.RESIZE]:
- incorrect_ranks = [3, 5]
- elif op["op"] in [Op.TRANSPOSE]:
- incorrect_ranks = [7, 8]
- elif op["op"] in [Op.CONV3D]:
- incorrect_ranks = [6, 7]
-
- error_name = ErrorIf.WrongRank
- param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Rank not supported for this operator"
-
- if check:
- input_shape = kwargs["input_shape"]
-
- if (
- op["op"] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D]
- and len(input_shape) != 4
- ):
- error_result = True
- elif op["op"] == Op.FULLY_CONNECTED and len(input_shape) != 2:
- error_result = True
- elif op["op"] == Op.MATMUL and len(input_shape) != 3:
- error_result = True
- else:
- if len(input_shape) not in rank_range:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evWrongInputList(check=False, **kwargs):
- error_name = ErrorIf.WrongInputList
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Op input list does not match expected input"
-
- if check:
- op = kwargs["op"]
- input_list = kwargs["input_list"]
- num_operands = kwargs["num_operands"]
- if op["op"] in [Op.SCATTER, Op.GATHER]:
- # SCATTER/GATHER add an indices input tensor in their build functions
- num_operands += 1
- if len(input_list) != num_operands:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evWrongOutputList(check=False, **kwargs):
- error_name = ErrorIf.WrongOutputList
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Op output list does not match expected output"
-
- if check:
- output_list = kwargs["output_list"]
- # Note this will be incorrect if an operator returns more than one output
- if len(output_list) != 1:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evMaxDimExceeded(check=False, **kwargs):
- error_name = ErrorIf.MaxDimExceeded
- param_reqs = {
- "rank": [4, 4],
- "dtype": [DType.INT8],
- "shape": [[1, 16584, 5, 1], [1, 2, 16499, 4]],
- }
- error_result = False
- error_reason = (
- "At least one maximum dimension is greater than or equal to 16384"
- )
-
- if check:
- input_shape = kwargs["input_shape"]
- output_shape = kwargs["output_shape"] # Note this is just (OH, OW)
- if (
- (input_shape[1] >= 16384)
- or (input_shape[2] >= 16384)
- or (output_shape[0] >= 16384)
- or (output_shape[1] >= 16384)
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evBatchMismatch(check=False, **kwargs):
- error_name = ErrorIf.BatchMismatch
- param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input batch size not equal to output batch size"
-
- assert "op" in kwargs
- op = kwargs["op"]
- rmin, rmax = op["rank"]
- rank_range = range(rmin, rmax + 1)
-
- if check:
- input_shape = kwargs["input_shape"]
- output_shape = kwargs[
- "result_tensor"
- ].shape # Note this is just (N, OH, OW, C)
-
- if (len(input_shape) in rank_range) and (input_shape[0] != output_shape[0]):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evChannelMismatch(check=False, **kwargs):
- error_name = ErrorIf.ChannelMismatch
- param_reqs = {"rank": [4, 4], "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input channel size not equal to output channel size"
-
- assert "op" in kwargs
- op = kwargs["op"]
- rmin, rmax = op["rank"]
- rank_range = range(rmin, rmax + 1)
-
- if check:
- input_shape = kwargs["input_shape"]
- output_shape = kwargs[
- "result_tensor"
- ].shape # Note this is just (N, OH, OW, C)
- if (len(input_shape) in rank_range) and (input_shape[3] != output_shape[3]):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evStrideSmallerEqualZero(check=False, **kwargs):
- error_name = ErrorIf.StrideSmallerEqualZero
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Stride value smaller than or equal zero"
-
- if check:
- input_dtype = kwargs["input_dtype"]
- output_dtype = kwargs["output_dtype"]
- if input_dtype != DType.FLOAT and output_dtype == DType.FLOAT:
- stride = kwargs["stride"] # Work around wrong input/output type tests
- elif output_dtype == DType.FLOAT:
- stride = kwargs["stride_fp"]
- elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
- stride = kwargs[
- "stride_fp"
- ] # Work around wrong input/output type tests
- else:
- stride = kwargs["stride"]
-
- if min(stride) <= 0:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evStrideLargerEqualMax(check=False, **kwargs):
- error_name = ErrorIf.StrideLargerEqualMax
- param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
- error_result = False
- error_reason = "Stride value larger than or equal to maximum value"
-
- if check:
- shift = kwargs["shift"]
- input_dtype = kwargs["input_dtype"]
- stride = kwargs["stride"]
- if input_dtype in [DType.INT8, DType.INT16]:
- if shift >= 0 and (
- stride[0] >= (16 << shift) or stride[1] >= (16 << shift)
- ):
- error_result = True
- elif shift < 0 and (
- stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evStrideLargerDimension(check=False, **kwargs):
- error_name = ErrorIf.StrideLargerDimension
- param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
- error_result = False
- error_reason = "Stride value larger than or equal to H/W dimension"
-
- if check:
- shape = kwargs["input_shape"]
- input_dtype = kwargs["input_dtype"]
- stride = kwargs["stride_fp"]
-
- if (
- input_dtype == DType.FLOAT
- and (stride[0] > shape[1])
- or (stride[1] > shape[2])
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evOffsetSmallerEqualMin(check=False, **kwargs):
- error_name = ErrorIf.OffsetSmallerEqualMin
- param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
- error_result = False
- error_reason = "Offset value smaller than or equal to minimum value"
-
- if check:
- shift = kwargs["shift"]
- output_dtype = kwargs["output_dtype"]
- if output_dtype == DType.FLOAT:
- offset = kwargs["offset_fp"]
- else:
- offset = kwargs["offset"]
-
- if shift >= 0 and (
- offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)
- ):
- error_result = True
- elif shift < 0 and (
- offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evOffsetLargerEqualMax(check=False, **kwargs):
- error_name = ErrorIf.OffsetLargerEqualMax
- param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
- error_result = False
- error_reason = "Offset value larger than or equal to maximum value"
-
- if check:
- shift = kwargs["shift"]
- output_dtype = kwargs["output_dtype"]
- if output_dtype == DType.FLOAT:
- offset = kwargs["offset_fp"]
- else:
- offset = kwargs["offset"]
-
- if shift >= 0:
- if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
- error_result = True
-
- if shift >= 0 and (
- offset[0] >= (16 << shift) or offset[1] >= (16 << shift)
- ):
- error_result = True
- elif shift < 0 and (
- offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evShiftNotZero(check=False, **kwargs):
- error_name = ErrorIf.ShiftNotZero
- param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
- error_result = False
- error_reason = "Shift value must be zero for float input"
-
- if check:
- shift = kwargs["shift"]
- input_dtype = kwargs["input_dtype"]
- output_dtype = kwargs["output_dtype"]
- if (
- input_dtype == DType.FLOAT
- and output_dtype == DType.FLOAT
- and shift != 0
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evShiftSmallerOne(check=False, **kwargs):
- error_name = ErrorIf.ShiftSmallerOne
- param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
- error_result = False
- error_reason = "Shift value smaller than one"
-
- if check:
- shift = kwargs["shift"]
- input_dtype = kwargs["input_dtype"]
- output_dtype = kwargs["output_dtype"]
- if shift < 1 and input_dtype != DType.FLOAT and output_dtype != DType.FLOAT:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evShiftLargerEleven(check=False, **kwargs):
- error_name = ErrorIf.ShiftLargerEleven
- param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
- error_result = False
- error_reason = "Shift value larger than eleven"
-
- if check:
- shift = kwargs["shift"]
- if shift > 11:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evRankMismatch(check=False, **kwargs):
- error_name = ErrorIf.RankMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input Rank does not match output rank"
-
- if check:
- input1_shape = kwargs["input1"].shape
- input2_shape = kwargs["input2"].shape
- # In case of SELECT op
- input3_shape = (
- kwargs["input3"].shape if "input3" in kwargs else input2_shape
- )
- output_shape = kwargs["result_tensor"].shape
- if (
- (len(input1_shape) != len(output_shape))
- or (len(input2_shape) != len(output_shape))
- or (len(input3_shape) != len(output_shape))
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evDimensionMismatch(check=False, **kwargs):
- error_name = ErrorIf.DimensionMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input Dimensions do not match output"
-
- if check:
- input1_shape = kwargs["input1"].shape
- input2_shape = kwargs["input2"].shape
- # In case of SELECT op
- input3_shape = (
- kwargs["input3"].shape if "input3" in kwargs else input2_shape
- )
- output_shape = kwargs["result_tensor"].shape
- for i in range(
- min(len(input1_shape), len(input2_shape), len(input3_shape))
- ):
- if (
- (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
- or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
- or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evInputZeroPointNotZero(check=False, **kwargs):
- op = kwargs["op"]
- error_result = False
-
- # Quantizable types
- qTypes = (DType.INT8, DType.UINT8)
-
- # This does not apply to quantizable types
- inputDtypes = [
- dtype
- for dtype in op["types"]
- if (isinstance(dtype, list) and dtype[0] not in qTypes)
- or (not isinstance(dtype, list) and dtype not in qTypes)
- ]
-
- if check:
- input_dtype = kwargs["input_dtype"]
- if isinstance(kwargs["qinfo"], tuple):
- qinfo = kwargs["qinfo"]
- input_zero_point = qinfo[0]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs["qinfo"].ints
- input_zero_point = qinfo[0][1]
-
- if op["op"] == Op.MATMUL:
- qinfo = kwargs["qinfo"].ints
- for dtype, zp in (
- (kwargs["input_dtype"], qinfo[0][1]),
- (kwargs["input2_dtype"], qinfo[1][1]),
- ):
- if dtype not in qTypes and zp != 0:
- error_result = True
- break
- else:
- error_result = input_dtype not in qTypes and input_zero_point != 0
-
- info_dict = {
- "error_name": ErrorIf.InputZeroPointNotZero,
- "error_result": error_result,
- "error_reason": "Input DType not INT8 and zero point not 0",
- "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None},
- }
- return info_dict
-
- @staticmethod
- def evWeightZeroPointNotZero(check=False, **kwargs):
- op = kwargs["op"]
-
- # exclude inputs with INT8 weights
- inputDtypes = [
- t for t in op["types"] if not isinstance(t, list) or t[1] != DType.INT8
- ]
-
- error_name = ErrorIf.WeightZeroPointNotZero
- param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
- error_result = False
- error_reason = "Weight DType not INT8 and zero point not 0"
-
- if check:
- weight_dtype = kwargs["weight_dtype"]
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
- qinfo = kwargs["qinfo"].ints
- weight_zero_point = qinfo[1][1]
- if weight_dtype != DType.INT8 and weight_zero_point != 0:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evOutputZeroPointNotZero(check=False, **kwargs):
- op = kwargs["op"]
- inputDtypes = op["types"].copy()
- if DType.INT8 in inputDtypes:
- inputDtypes.remove(DType.INT8)
- if DType.UINT8 in inputDtypes:
- inputDtypes.remove(DType.UINT8)
-
- error_name = ErrorIf.OutputZeroPointNotZero
- param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
- error_result = False
- error_reason = "Output DType not INT8 and zero point not 0"
-
- if check:
- input_dtype = kwargs["input_dtype"]
- output_dtype = kwargs["output_dtype"]
- if isinstance(kwargs["qinfo"], tuple):
- qinfo = kwargs["qinfo"]
- output_zero_point = qinfo[1]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs["qinfo"].ints
- output_zero_point = qinfo[1][1]
- if op["op"] == Op.AVG_POOL2D:
- if input_dtype != DType.INT8 and output_zero_point != 0:
- error_result = True
- elif (
- output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evAxisSmallerZero(check=False, **kwargs):
- error_name = ErrorIf.AxisSmallerZero
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Axis smaller than zero"
-
- if check:
- axis = kwargs["axis"]
- if axis < 0:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evAxisLargerRank(check=False, **kwargs):
- error_name = ErrorIf.AxisLargerRank
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Axis larger than rank"
-
- if check:
- axis = kwargs["axis"]
- shape = kwargs["input_shape"]
- if axis > len(shape):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evShapeOfAxisNotOne(check=False, **kwargs):
- error_name = ErrorIf.ShapeOfAxisNotOne
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "shape[axis] is not equal to 1"
-
- if check:
- axis = kwargs["axis"]
- shape = kwargs["output_shape"]
- if (0 <= axis < len(shape)) and shape[axis] != 1:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evPadSmallerZero(check=False, **kwargs):
- error_name = ErrorIf.PadSmallerZero
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "At least one pad is smaller than zero"
-
- if check:
- op = kwargs["op"]
- pad = kwargs["pad"]
- if op["op"] == Op.PAD:
- for padding in pad:
- if min(padding) < 0:
- error_result = True
- else:
- if min(pad) < 0:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evPadLargerEqualKernel(check=False, **kwargs):
- error_name = ErrorIf.PadLargerEqualKernel
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "At least one pad is larger than kernel dimension"
-
- if check:
- pad = kwargs["pad"]
- kernel = kwargs["kernel"]
- if min(pad) > 0 and min(kernel) > 1:
- if (
- pad[0] >= kernel[0]
- or pad[1] >= kernel[0]
- or pad[2] >= kernel[1]
- or pad[3] >= kernel[1]
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evPoolingOutputShapeMismatch(check=False, **kwargs):
- error_name = ErrorIf.PoolingOutputShapeMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = (
- "Mismatch between output shape provided and expected output shape"
- )
-
- if check:
- pad = kwargs["pad"]
- pad_top, pad_bottom, pad_left, pad_right = pad[0], pad[1], pad[2], pad[3]
-
- kernel = kwargs["kernel"]
- kernel_y, kernel_x = kernel[0], kernel[1]
-
- input_shape = kwargs["input_shape"]
- IH, IW = input_shape[1], input_shape[2]
-
- output_shape = kwargs["output_shape"]
- OH, OW = output_shape[1], output_shape[2]
-
- stride = kwargs["stride"]
- stride_y, stride_x = stride[0], stride[1]
-
- # calculate correct height, width dimensions
- if stride_x != 0 and stride_y != 0:
- y_correct = (
- IH + pad_top + pad_bottom + stride_y - kernel_y
- ) // stride_y
- x_correct = (
- IW + pad_left + pad_right + stride_x - kernel_x
- ) // stride_x
-
- # ensure parameters are valid
- params_valid = (
- min(kernel) >= 1
- and min(stride) >= 1
- and min(pad) >= 0
- and not (
- pad[0] >= kernel[0]
- or pad[1] >= kernel[0]
- or pad[2] >= kernel[1]
- or pad[3] >= kernel[1]
- )
- )
-
- if params_valid and (OH != y_correct or OW != x_correct):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evArgmaxOutputShapeMismatch(check=False, **kwargs):
- error_name = ErrorIf.ArgmaxOutputShapeMismatch
- param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
- error_result = False
- error_reason = (
- "Mismatch between output shape provided and expected output shape"
- )
-
- if check:
- output_shape = kwargs["output_shape"]
- input_shape = kwargs["input_shape"]
- axis = kwargs["axis"]
-
- dimension_match = True
- axis_shift = 0
-
- # Check that rank is correct before trying to check dimensions
- if (len(input_shape) - 1) == len(output_shape):
- for i in range(len(input_shape)):
- if i == axis:
- axis_shift = 1
- continue
- if input_shape[i] != output_shape[i - axis_shift]:
- dimension_match = False
-
- if not dimension_match:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evArgmaxOutputRankMismatch(check=False, **kwargs):
- error_name = ErrorIf.ArgmaxOutputRankMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = (
- "Mismatch between output shape provided and expected output shape"
- )
-
- if check:
- output_shape = kwargs["output_shape"]
- input_shape = kwargs["input_shape"]
- axis = kwargs["axis"]
- valid_params = axis >= 0 and axis < len(input_shape)
-
- if valid_params and (len(input_shape) - 1) != len(output_shape):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evKernelSmallerOne(check=False, **kwargs):
- error_name = ErrorIf.KernelSmallerOne
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "At least one kernel dimension is smaller than zero"
-
- if check:
- kernel = kwargs["kernel"]
- if min(kernel) < 1:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evStrideSmallerOne(check=False, **kwargs):
- error_name = ErrorIf.StrideSmallerOne
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "At least one stride dimension is smaller than zero"
-
- if check:
- stride = kwargs["stride"]
- if min(stride) < 1:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evDilationSmallerOne(check=False, **kwargs):
- error_result = check and min(kwargs["dilation"]) < 1
- return {
- "error_name": ErrorIf.DilationSmallerOne,
- "error_reason": "At least one dilation is smaller than one",
- "param_reqs": {"rank": None, "dtype": None, "shape": None},
- "error_result": error_result,
- }
-
- @staticmethod
- def evScaleTrue(check=False, **kwargs):
- error_name = ErrorIf.ScaleTrue
- param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
- error_result = False
- error_reason = "Scale set to true but input type is INT48"
-
- if check:
- input_dtype = kwargs["input_dtype"]
- scale32 = kwargs["scale32"]
- if scale32 and input_dtype == DType.INT48:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evScaleNotTrue(check=False, **kwargs):
- error_name = ErrorIf.ScaleNotTrue
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Scale set to false but double round set to true"
-
- if check:
- scale32 = kwargs["scale32"]
- double_round = kwargs["double_round"]
- if not scale32 and double_round:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evTensorSizeInputOutputMismatch(check=False, **kwargs):
- error_name = ErrorIf.TensorSizeInputOutputMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input tensor size does not match output tensor size"
-
- if check:
- input_shape = kwargs["input_shape"]
- output_shape = kwargs["output_shape"]
- input_size = np.prod(input_shape)
- output_size = np.prod(output_shape)
- if input_size != output_size:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evStartSmallerZero(check=False, **kwargs):
- error_name = ErrorIf.StartSmallerZero
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Starting point smaller than zero"
-
- if check:
- input_shape = kwargs["input_shape"]
- start = kwargs["start"]
- rank = len(input_shape)
- if len(start) == rank:
- for index in range(rank):
- if start[index] < 0:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evSizeSmallerEqualZero(check=False, **kwargs):
- error_name = ErrorIf.SizeSmallerEqualZero
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Size smaller than or equal to zero"
-
- if check:
- input_shape = kwargs["input_shape"]
- size = kwargs["size"]
- rank = len(input_shape)
- if len(size) == rank:
- for index in range(rank):
- if size[index] <= 0:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evStartSizeOutsideBounds(check=False, **kwargs):
- error_name = ErrorIf.StartSizeOutsideBounds
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "starting point plus size larger than input dimension"
-
- if check:
- input_shape = kwargs["input_shape"]
- start = kwargs["start"]
- size = kwargs["size"]
- rank = len(input_shape)
- if len(start) == rank and len(size) == rank:
- for index in range(rank):
- if start[index] + size[index] > input_shape[index]:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evSizeOutputShapeMismatch(check=False, **kwargs):
- error_name = ErrorIf.SizeOutputShapeMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Size does not match output dimension"
-
- if check:
- input_shape = kwargs["input_shape"]
- output_shape = kwargs["output_shape"]
- size = kwargs["size"]
- rank = len(input_shape)
- if len(size) == rank:
- for index in range(rank):
- if size[index] != output_shape[index]:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evInputSizeStartLengthMismatch(check=False, **kwargs):
- error_name = ErrorIf.InputSizeStartLengthMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "rank of input not equal to length of start or size"
-
- if check:
- input_shape = kwargs["input_shape"]
- start = kwargs["start"]
- size = kwargs["size"]
- rank = len(input_shape)
- if rank != len(start) or rank != len(size):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evIndexOutsideBounds(check=False, **kwargs):
- error_name = ErrorIf.IndexOutsideBounds
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Index outside of allowed bounds"
-
- if check:
- input_shape = kwargs["input_shape"]
- perms = kwargs["perms"]
- rank = len(input_shape)
-
- for index in perms:
- if index < 0 or index > rank:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evIndexUsedTwice(check=False, **kwargs):
- error_name = ErrorIf.IndexUsedTwice
- param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
- error_result = False
- error_reason = "Index used multiple times"
-
- if check:
- perms = kwargs["perms"]
-
- unique_indices = []
- for index in perms:
- if index in unique_indices:
- error_result = True
- else:
- unique_indices.append(index)
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evMaxSmallerMin(check=False, **kwargs):
- error_name = ErrorIf.MaxSmallerMin
- param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
- error_result = False
- error_reason = "Max value smaller than min value"
-
- if check:
- max_val = kwargs["max_val"]
- min_val = kwargs["min_val"]
- if max_val < min_val:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evConcatInputRankMismatch(check=False, **kwargs):
- error_name = ErrorIf.ConcatInputRankMismatch
- param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input ranks are not identical"
-
- if check:
- inputs = kwargs["inputs"]
- input_shape = kwargs["input_shape"]
- for input in inputs:
- if len(input.shape) != len(input_shape):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evConcatInputDimMismatch(check=False, **kwargs):
- error_name = ErrorIf.ConcatInputDimMismatch
- param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input dimensions differ on too many axes"
-
- if check:
- inputs = kwargs["inputs"]
- input_shape = kwargs["input_shape"]
- axis = kwargs["axis"]
-
- # Ensure rank is valid before checking dims.
- valid_rank = True
- for input in inputs:
- if len(input.shape) != len(input_shape):
- valid_rank = False
-
- if valid_rank:
- for input in inputs:
- for i, dim in enumerate(input.shape):
- if dim != input_shape[i] and axis != i:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evConcatShapeSumMismatch(check=False, **kwargs):
- error_name = ErrorIf.ConcatShapeSumMismatch
- param_reqs = {"rank": [2, 4], "dtype": None, "shape": None}
- error_result = False
- error_reason = "Sum of dimensions on axis not equal to output dimension"
-
- if check:
- inputs = kwargs["inputs"]
- input_shape = kwargs["input_shape"]
- output_shape = kwargs["output_shape"]
- axis = kwargs["axis"]
-
- # Ensure rank is valid before checking dims.
- valid_params = True
- for input in inputs:
- if len(input.shape) != len(input_shape):
- valid_params = False
- if axis < 0 or axis > len(input_shape):
- valid_params = False
-
- if valid_params:
- axis_dim_sum = 0
- for input in inputs:
- axis_dim_sum += input.shape[axis]
-
- if axis_dim_sum != output_shape[axis]:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evInputListThenGraphMismatch(check=False, **kwargs):
- error_name = ErrorIf.CondIfInputListThenGraphMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input list shape does not match then-graph shape"
-
- if check:
- a = kwargs["a"]
- b = kwargs["b"]
- basicBlocks = kwargs["basicBlocks"]
- then_block = basicBlocks[1]
- then_inputs = then_block.inputs
- then_tens = then_block.tensors
- if (a.shape != then_tens[then_inputs[0]].shape) or (
- b.shape != then_tens[then_inputs[1]].shape
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evInputListElseGraphMismatch(check=False, **kwargs):
- error_name = ErrorIf.CondIfInputListElseGraphMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input list shape does not match else-graph shape"
-
- if check:
- a = kwargs["a"]
- b = kwargs["b"]
- basicBlocks = kwargs["basicBlocks"]
- else_block = basicBlocks[2]
- else_inputs = else_block.inputs
- else_tens = else_block.tensors
- if (a.shape != else_tens[else_inputs[0]].shape) or (
- b.shape != else_tens[else_inputs[1]].shape
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evOutputListThenGraphMismatch(check=False, **kwargs):
- error_name = ErrorIf.CondIfOutputListThenGraphMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Output list shape does not match then-graph shape"
-
- if check:
- basicBlocks = kwargs["basicBlocks"]
- cond_block = basicBlocks[0]
- cond_outputs = cond_block.outputs
- cond_tens = cond_block.tensors
- then_block = basicBlocks[1]
- then_outputs = then_block.outputs
- then_tens = then_block.tensors
- if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evOutputListElseGraphMismatch(check=False, **kwargs):
- error_name = ErrorIf.CondIfOutputListElseGraphMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Output list shape does not match else-graph shape"
-
- if check:
- basicBlocks = kwargs["basicBlocks"]
- cond_block = basicBlocks[0]
- cond_outputs = cond_block.outputs
- cond_tens = cond_block.tensors
- else_block = basicBlocks[2]
- else_outputs = else_block.outputs
- else_tens = else_block.tensors
- if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evInputListOutputListMismatch(check=False, **kwargs):
- error_name = ErrorIf.InputListOutputListMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input list does not match output list"
-
- if check:
- basicBlocks = kwargs["basicBlocks"]
- while_block = basicBlocks[0]
- while_inputs = while_block.inputs
- while_outputs = while_block.outputs
- while_tens = while_block.tensors
- if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evInputListCondGraphMismatch(check=False, **kwargs):
- error_name = ErrorIf.InputListCondGraphMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input list does not match cond graph"
-
- if check:
- basicBlocks = kwargs["basicBlocks"]
- while_block = basicBlocks[0]
- while_inputs = while_block.inputs
- while_tens = while_block.tensors
- cond_block = basicBlocks[1]
- cond_inputs = cond_block.inputs
- cond_tens = cond_block.tensors
- if (
- while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape
- ) or (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evInputListBodyGraphInputMismatch(check=False, **kwargs):
- error_name = ErrorIf.InputListBodyGraphInputMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input list does not match body graph input"
-
- if check:
- basicBlocks = kwargs["basicBlocks"]
- while_block = basicBlocks[0]
- while_inputs = while_block.inputs
- while_tens = while_block.tensors
- body_block = basicBlocks[2]
- body_outputs = body_block.inputs
- body_tens = body_block.tensors
- if (
- while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
- ) or (
- while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
- ):
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
- error_name = ErrorIf.InputListBodyGraphOutputMismatch
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Input list does not match body graph output"
-
- if check:
- basicBlocks = kwargs["basicBlocks"]
- while_block = basicBlocks[0]
- while_inputs = while_block.inputs
- while_tens = while_block.tensors
- body_block = basicBlocks[2]
- body_outputs = body_block.outputs
- body_tens = body_block.tensors
- if (
- while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape
- ) or (
- while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape
- ):
- error_result = True
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
- @staticmethod
- def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
- error_name = ErrorIf.CondGraphOutputNotMatchingBool
- param_reqs = {"rank": None, "dtype": None, "shape": None}
- error_result = False
- error_reason = "Cond graph output is not a match list of booleans"
-
- if check:
- basicBlocks = kwargs["basicBlocks"]
- cond_block = basicBlocks[1]
- cond_outputs = cond_block.outputs
- cond_tens = cond_block.tensors
- if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
- error_result = True
-
- info_dict = {
- "error_name": error_name,
- "error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs,
- }
- return info_dict
-
-
-class TosaInvalidValidator:
- @staticmethod
- def ivWrongDataTypeOrModeResize(**kwargs):
- input_dtype = kwargs["input_dtype"]
- args = kwargs["args"]
- mode = args[0]
- output_dtype = args[8]
-
- if mode == ResizeMode.BILINEAR:
- # Invalid output data type / Invalid input datatype
- return (
- not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
- or not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
- or not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
- or (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
- )
- elif mode == ResizeMode.NEAREST:
- # Invalid output data type / Invalid input datatype
- return (input_dtype != output_dtype) or (
- input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT]
- )
- else:
- # Invalid resize mode
- return True
-
- @staticmethod
- def ivBadStride(**kwargs):
- input_dtype = kwargs["input_dtype"]
- args = kwargs["args"]
- stride_x = args[1][0]
- stride_y = args[1][1]
- stride_fp_x = args[4][0]
- stride_fp_y = args[4][1]
-
- if input_dtype == DType.FLOAT:
- if stride_fp_x <= 0 or stride_fp_y <= 0:
- # Negative or zero stride
- return True
- else:
- if stride_x <= 0 or stride_y <= 0:
- # Negative or zero stride
- return True
- return False
-
- @staticmethod
- def ivHeightWidthInvalid(**kwargs):
- opName = kwargs["opName"]
-
- inputShapes = kwargs["shapeList"]
- input_shape = inputShapes[0]
-
- args = kwargs["args"]
- strides = args[0]
- padding = args[1]
-
- if opName.endswith("pool2d"):
- # avg_pool2d, max_pool2d
- kernel_shape = args[2]
- h = (
- input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]
- ) // strides[0]
- w = (
- input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]
- ) // strides[1]
- # return True if any dimension is < 1
- return h < 1 or w < 1
-
- if opName.startswith("transpose_conv2d"):
- # transpose_conv2d
- dilations = args[2]
- output_shape = args[3]
- filter_shape = inputShapes[1]
- kernel_shape = filter_shape[1:-1]
-
- def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
- """Calculate the transpose_conv2d output size for a dimension.
-
- Based on the keras function deconv_output_length, in
- https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
-
- Args:
- in_size: the input size - int
- stride: the stride - int
- kernel_size: the kernel size - int
- dilation: the kernel dilation - int
- out_pad: the output padding - int
- in_pad: the input padding - int
-
- Returns:
- the output size
- """
- dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- return (
- (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
- )
-
- for pad_h, pad_w in (
- (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
- (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
- (0, 0), # VALID padding
- ):
- h = get_out_size(
- input_shape[1],
- strides[0],
- kernel_shape[0],
- dilations[0],
- padding[0],
- pad_h,
- )
- w = get_out_size(
- input_shape[2],
- strides[1],
- kernel_shape[1],
- dilations[1],
- padding[1],
- pad_w,
- )
- if output_shape[1] == h and output_shape[2] == w:
- return False
-
- # output shape does not match the expected shape for any padding option
- return True
-
- if "conv2d" in opName or "conv3d" in opName:
- # conv2d, conv3d, depthwise_conv2d
- dilations = args[2]
- filter_shape = inputShapes[1]
- kernel_shape = (
- filter_shape[0:2]
- if opName.startswith("depthwise_conv2d")
- else filter_shape[1:-1]
- )
-
- for i in range(len(kernel_shape)):
- dim = (
- input_shape[i + 1]
- - kernel_shape[i]
- - (kernel_shape[i] - 1) * (dilations[i] - 1)
- + padding[i * 2 + 0]
- + padding[i * 2 + 1]
- ) // strides[i] + 1
- # return True if any dimension is < 1
- if dim < 1:
- return True
- return False
-
- assert False, f"Unrecognized Op: {opName}"
-
- @staticmethod
- def ivNonPositiveOutputShape(**kwargs):
- args = kwargs["args"]
- output_shape = args[3]
- if output_shape[1] <= 0 or output_shape[2] <= 0:
- # Negative output shape
- return True
- return False
class TosaTestGen:
@@ -5567,7 +2044,7 @@ class TosaTestGen:
# Initialize a new random number generator
self.rng = np.random.default_rng(self.random_seed)
- build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
+ build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
# Test list consists of a tuple of:
# (opName, testNameStr, dtype, shapeList, argumentsList)
@@ -5668,7 +2145,7 @@ class TosaTestGen:
# Create a serializer
self.createSerializer(opName, testStr)
- build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
+ build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
if "error_if_validators" in op:
error_if_validators = op["error_if_validators"]
else:
@@ -5709,9 +2186,7 @@ class TosaTestGen:
else:
qinfo = None
- tens = self.generate_tensors(
- op, dtypeList, shapeList, testArgs, qinfo, error_name
- )
+ tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name)
try:
if error_if_validators is None:
@@ -5750,344 +2225,6 @@ class TosaTestGen:
# The test is not valid
print(f"Invalid ERROR_IF test created: {opName} {testStr}")
- def generate_tensors(
- self, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
- ):
- pCount, cCount = op["operands"]
-
- tens = []
- if op["op"] == Op.NEGATE and dtypeList[0] != DType.FLOAT and error_name is None:
- assert (
- pCount == 1 and cCount == 0
- ), "Op.NEGATE must have 1 placeholders, 0 consts"
- # Must create tensors with values within negatable ranges
- if dtypeList[0] == DType.INT8:
- # Must be within int8, adjustable by input_zp and then negatable
- # and be within int8
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- max_val = min(127, 127 + qinfo.ints[0][1])
- min_val = max(-127, -127 + qinfo.ints[0][1])
- elif dtypeList[0] == DType.INT16:
- max_val = 32767
- min_val = -max_val
- else:
- assert (
- dtypeList[0] == DType.INT32
- ), "Op.NEGATE found with unsupported input type"
- max_val = (1 << 31) - 1
- min_val = -max_val
- arr = np.int32(
- self.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
- )
- placeholders = []
- placeholders.append(
- self.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
- )
- tens.extend(placeholders)
- elif (
- (op["op"] == Op.ADD or op["op"] == Op.SUB)
- and dtypeList[0] == DType.INT32
- and error_name is None
- ):
- # Make sure the operation does not cause value saturation - where
- # the number wraps due to limited number of bits to store the answer
- assert (
- pCount == 2 and cCount == 0
- ), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
- placeholders = []
- add = op["op"] == Op.ADD
- a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
- b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
- if add:
- res_arr = np.add(a_arr, b_arr, dtype=np.int64)
- else:
- res_arr = np.subtract(a_arr, b_arr, dtype=np.int64)
-
- # Work out the saturation limits
- max_i32 = (1 << 31) - 1
- min_i32 = -(1 << 31)
- max_arr = np.full(shapeList[1], max_i32)
- min_arr = np.full(shapeList[1], min_i32)
-
- # Find how much values exceed the maximum/minimums
- sat_max_arr = np.maximum(res_arr - max_arr, 0)
- sat_min_arr = np.minimum(res_arr - min_arr, 0)
-
- if not add:
- # Swap saturation values and negate values as we need to perform opposite operations
- sat_max_arr, sat_min_arr = -sat_min_arr, -sat_max_arr
-
- # Create new array of unsaturated values by clipping values as needed
- b_unsat_arr = b_arr
- if (sat_max_arr != 0).any():
- # Clip values that cause saturation
- b_unsat_arr = np.subtract(b_unsat_arr, sat_max_arr, dtype=np.int32)
- # Reduce axes in unsaturated tensor to match original tensor
- for axis, dim in enumerate(b_arr.shape):
- if dim != b_unsat_arr.shape[axis]:
- assert (
- dim == 1
- ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
- b_unsat_arr = np.amin(b_unsat_arr, axis=axis, keepdims=True)
-
- if (sat_min_arr != 0).any():
- # Clip values that cause saturation
- b_unsat_arr = np.subtract(b_unsat_arr, sat_min_arr, dtype=np.int32)
- # Reduce axes in unsaturated tensor to match original tensor
- for axis, dim in enumerate(b_arr.shape):
- if dim != b_unsat_arr.shape[axis]:
- assert (
- dim == 1
- ), "Op.ADD / SUB dimension must be 1 or matching to be broadcastable"
- b_unsat_arr = np.amax(b_unsat_arr, axis=axis, keepdims=True)
-
- placeholders.append(
- self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
- )
- placeholders.append(
- self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_unsat_arr)
- )
-
- tens.extend(placeholders)
- elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] in (
- DType.INT32,
- DType.INT16,
- DType.INT8,
- ):
- # Limit input tensors with cond_if_binary or while_loop to stop
- # saturation of add/sub ops with int32 and keep all logical shift
- # values between 0 to 31 for int16 or int8
- pRemain = pCount
- placeholders = []
- for idx, shape in enumerate(shapeList[:]):
- if dtypeList[0] == DType.INT32:
- arr = self.getRandTensor(shapeList[idx], DType.INT16)
- else:
- arr = np.int32(
- self.rng.integers(low=0, high=32, size=shapeList[idx])
- )
- if pRemain > 0:
- placeholders.append(
- self.ser.addPlaceholder(shape, dtypeList[idx], arr)
- )
- pRemain -= 1
- else:
- placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
-
- tens.extend(placeholders)
- elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
- # Force value of operand[1] to be within [0, num_bits]
- assert (
- pCount == 2 and cCount == 0
- ), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
-
- placeholders = []
- for idx, shape in enumerate(shapeList[:]):
- if idx == 1:
- if dtypeList[idx] == DType.INT8:
- arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
- elif dtypeList[idx] == DType.INT16:
- arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
- elif dtypeList[idx] == DType.INT32:
- arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
- elif error_name == ErrorIf.WrongInputType:
- arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
- else:
- raise Exception("OpArithmeticRightShift: invalid input dtype")
- else:
- arr = self.getRandTensor(shape, dtypeList[idx])
- placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
-
- tens.extend(placeholders)
- elif op["op"] == Op.SELECT:
- # Set datatype of condition tensor to boolean
- dtypeList[0] = DType.BOOL
- tens.extend(
- self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
- )
- tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
- elif op["op"] == Op.INTDIV and error_name is None:
- assert (
- pCount == 2 and cCount == 0
- ), "Op.INTDIV must have 2 placeholders, 0 consts"
-
- placeholders = []
-
- # Two invalid cases for Op.INTDIV:
- # 1. divisor == 0
- # 2. dividend == -(1<<31) and divisor == -1
- while True:
- dividend_arr = self.getRandTensor(shapeList[0], dtypeList[0])
- divisor_arr = self.getRandTensor(shapeList[1], dtypeList[1])
-
- if (divisor_arr == 0).any():
- continue
-
- if (dividend_arr == -(2**31)).any() and (divisor_arr == -1).any():
- continue
-
- break
-
- placeholders.append(
- self.ser.addPlaceholder(shapeList[0], dtypeList[0], dividend_arr)
- )
- placeholders.append(
- self.ser.addPlaceholder(shapeList[1], dtypeList[1], divisor_arr)
- )
-
- tens.extend(placeholders)
- elif op["op"] == Op.MUL and error_name is None:
- assert (
- pCount == 2 and cCount == 0
- ), "Op.MUL must have 2 placeholders, 0 consts"
-
- if dtypeList[0] == DType.FLOAT:
- tens.extend(self.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
- else:
- placeholders = []
-
- # Make sure multiply result in int32 range
- shift = testArgs[0]
- if dtypeList[0] == DType.INT8:
- num_bits = 8
- elif dtypeList[0] == DType.INT16:
- num_bits = 16
- elif dtypeList[0] == DType.INT32:
- num_bits = 32
- elif error_name == ErrorIf.WrongInputType:
- num_bits = 8
- else:
- raise Exception("OpMul: invalid input dtype")
-
- for idx, shape in enumerate(shapeList[:]):
- low = -(2 ** (num_bits - 1))
- high = (2 ** (num_bits - 1)) - 1
-
- a_arr = np.int32(
- self.rng.integers(low=low, high=high, size=shapeList[0])
- )
- b_arr = np.int32(
- self.rng.integers(low=low, high=high, size=shapeList[1])
- )
-
- i = 0
- while True:
-
- a_arr_64 = a_arr.astype(np.int64)
- b_arr_64 = b_arr.astype(np.int64)
-
- if shift > 0:
- rounding = 1 << (shift - 1)
- result_arr = ((a_arr_64 * b_arr_64) + rounding) >> shift
- else:
- result_arr = a_arr_64 * b_arr_64
-
- if (result_arr > -(2**31)).all() and (
- result_arr <= ((2**31) - 1)
- ).all():
- break
-
- i = i + 1
- a_arr = a_arr // 2
- b_arr = b_arr // 2
-
- placeholders.append(
- self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
- )
- placeholders.append(
- self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
- )
-
- tens.extend(placeholders)
- elif op["op"] == Op.CONCAT:
- count = len(shapeList) - self.args.num_const_inputs_concat
- if count < 1:
- count = 1
- if self.args.num_const_inputs_concat == 0:
- count = len(shapeList)
-
- # Ensure axis is an int
- testArgs[0] = int(testArgs[0])
-
- shapeList = TosaTensorGen.tgConcatConstInput(
- self, shapeList, testArgs[0], error_name
- )
-
- tens.extend(
- self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
- )
- tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
- elif op["op"] == Op.LOGICAL_LEFT_SHIFT or op["op"] == Op.LOGICAL_RIGHT_SHIFT:
- assert (
- pCount == 2 and cCount == 0
- ), "Op.LOGICAL_LEFT_SHIFT or Op.LOGICAL_RIGHT_SHIFT must have 2 placeholders, 0 consts"
- values_arr = self.getRandTensor(shapeList[0], dtypeList[0])
- shift_arr = np.int32(self.rng.integers(low=0, high=32, size=shapeList[1]))
- placeholders = []
- placeholders.append(
- self.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
- )
- placeholders.append(
- self.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
- )
- tens.extend(placeholders)
- elif op["op"] == Op.EQUAL and error_name is None:
- assert (
- pCount == 2 and cCount == 0
- ), "Op.EQUAL must have 2 placeholders, 0 consts"
- a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
- b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
- # Using random numbers means that it will be very unlikely that
- # there are any matching (equal) values, therefore force that
- # there are twice the number of matching values as the tensor rank
- for num in range(0, len(shapeList[0]) * 2):
- a_index = []
- b_index = []
- # Choose an index in each axis for the whole shape
- for axis in range(0, len(shapeList[0])):
- # Index can be up to the largest dimension in both shapes
- index = np.int32(
- self.rng.integers(
- 0, max(shapeList[0][axis], shapeList[1][axis])
- )
- )
- # Reduce the index down to a shape's dim for broadcasting
- a_index.append(min(shapeList[0][axis] - 1, index))
- b_index.append(min(shapeList[1][axis] - 1, index))
-
- a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
-
- placeholders = []
- placeholders.append(
- self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
- )
- placeholders.append(
- self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
- )
- tens.extend(placeholders)
- elif op["op"] == Op.REDUCE_SUM and dtypeList[0] == DType.INT32:
- assert (
- pCount == 1 and cCount == 0
- ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
- # Limit values so that the sum cannot exceed the range of an int32 during
- # summation of any axis
- range_val = int((1 << 31) / max(shapeList[0]))
- values_arr = np.int32(
- self.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
- )
- placeholders = []
- placeholders.append(
- self.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
- )
- tens.extend(placeholders)
- else:
- tens.extend(
- self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
- )
- tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
-
- return tens
-
def createDynamicOpLists(self):
# Dynamically create op lists for convolutions with a list of kernel sizes
@@ -6149,7 +2286,7 @@ class TosaTestGen:
)
try:
- fcn, tgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
+ fcn, tgen, tvgen, arggen = self.TOSA_OP_LIST[op]["build_fcn"]
except (KeyError, ValueError, TypeError):
raise Exception(
"Op {} is missing a valid build_fcn tuple in TOSA_OP_LIST".format(
@@ -6211,7 +2348,12 @@ class TosaTestGen:
"op": Op.ARGMAX,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_argmax,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_NARROW_INT_FP,
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
@@ -6229,7 +2371,12 @@ class TosaTestGen:
"op": Op.AVG_POOL2D,
"operands": (1, 0),
"rank": (4, 4),
- "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
+ "build_fcn": (
+ build_pool2d,
+ TosaTensorGen.tgNHWC,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agPooling,
+ ),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_NARROW_INT_FP,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
@@ -6253,7 +2400,12 @@ class TosaTestGen:
"op": Op.CONV2D,
"operands": (1, 2),
"rank": (4, 4),
- "build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
+ "build_fcn": (
+ build_conv2d,
+ TosaTensorGen.tgConv2D,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agConv,
+ ),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
@@ -6276,7 +2428,12 @@ class TosaTestGen:
"op": Op.CONV3D,
"operands": (1, 2),
"rank": (5, 5),
- "build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
+ "build_fcn": (
+ build_conv3d,
+ TosaTensorGen.tgConv3D,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agConv,
+ ),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
@@ -6303,6 +2460,7 @@ class TosaTestGen:
"build_fcn": (
build_depthwise_conv2d,
TosaTensorGen.tgDepthwiseConv2D,
+ TosaTensorValuesGen.tvgDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
@@ -6326,7 +2484,12 @@ class TosaTestGen:
"op": Op.FULLY_CONNECTED,
"operands": (1, 2),
"rank": (2, 2),
- "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
+ "build_fcn": (
+ build_fully_connected,
+ TosaTensorGen.tgFullyConnected,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
"error_if_validators": (
@@ -6343,7 +2506,12 @@ class TosaTestGen:
"op": Op.MATMUL,
"operands": (2, 0),
"rank": (3, 3),
- "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
+ "build_fcn": (
+ build_matmul,
+ TosaTensorGen.tgMatmul,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"qgen": TosaQuantGen.qgMatmul,
"types": TYPE_NARROW_INT_FP,
"error_if_validators": (
@@ -6359,7 +2527,12 @@ class TosaTestGen:
"op": Op.MAX_POOL2D,
"operands": (1, 0),
"rank": (4, 4),
- "build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
+ "build_fcn": (
+ build_pool2d,
+ TosaTensorGen.tgNHWC,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agPooling,
+ ),
"types": TYPE_NARROW_INT_FP,
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
@@ -6383,6 +2556,7 @@ class TosaTestGen:
"build_fcn": (
build_transpose_conv2d,
TosaTensorGen.tgTransposeConv2D,
+ TosaTensorValuesGen.tvgDefault,
TosaArgGen.agTransposeConv2D,
),
"qgen": TosaQuantGen.qgConv,
@@ -6409,7 +2583,12 @@ class TosaTestGen:
"clamp": {
"op": Op.CLAMP,
"operands": (1, 0),
- "build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_clamp,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_NARROW_INT_FP,
"error_if_validators": (
TosaErrorValidator.evMaxSmallerMin,
@@ -6422,7 +2601,12 @@ class TosaTestGen:
"sigmoid": {
"op": Op.SIGMOID,
"operands": (1, 0),
- "build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_sigmoid,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6434,7 +2618,12 @@ class TosaTestGen:
"tanh": {
"op": Op.TANH,
"operands": (1, 0),
- "build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_tanh,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6447,7 +2636,12 @@ class TosaTestGen:
"add": {
"op": Op.ADD,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgAddSub,
+ None,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6464,6 +2658,7 @@ class TosaTestGen:
"build_fcn": (
build_arithmetic_right_shift,
TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgArithmeticRightShift,
TosaArgGen.agArithmeticRightShift,
),
"types": TYPE_INT,
@@ -6479,7 +2674,12 @@ class TosaTestGen:
"bitwise_and": {
"op": Op.BITWISE_AND,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6493,7 +2693,12 @@ class TosaTestGen:
"bitwise_or": {
"op": Op.BITWISE_OR,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6507,7 +2712,12 @@ class TosaTestGen:
"bitwise_xor": {
"op": Op.BITWISE_XOR,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6521,7 +2731,12 @@ class TosaTestGen:
"intdiv": {
"op": Op.INTDIV,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgIntDiv,
+ None,
+ ),
"types": [DType.INT32],
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6535,7 +2750,12 @@ class TosaTestGen:
"logical_and": {
"op": Op.LOGICAL_AND,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6549,7 +2769,12 @@ class TosaTestGen:
"logical_left_shift": {
"op": Op.LOGICAL_LEFT_SHIFT,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgLogicalShift,
+ None,
+ ),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6563,7 +2788,12 @@ class TosaTestGen:
"logical_right_shift": {
"op": Op.LOGICAL_RIGHT_SHIFT,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgLogicalShift,
+ None,
+ ),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6577,7 +2807,12 @@ class TosaTestGen:
"logical_or": {
"op": Op.LOGICAL_OR,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6591,7 +2826,12 @@ class TosaTestGen:
"logical_xor": {
"op": Op.LOGICAL_XOR,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6605,7 +2845,12 @@ class TosaTestGen:
"maximum": {
"op": Op.MAXIMUM,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6619,7 +2864,12 @@ class TosaTestGen:
"minimum": {
"op": Op.MINIMUM,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6633,7 +2883,12 @@ class TosaTestGen:
"mul": {
"op": Op.MUL,
"operands": (2, 0),
- "build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
+ "build_fcn": (
+ build_mul,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgMul,
+ TosaArgGen.agMul,
+ ),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6647,7 +2902,12 @@ class TosaTestGen:
"pow": {
"op": Op.POW,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6661,7 +2921,12 @@ class TosaTestGen:
"sub": {
"op": Op.SUB,
"operands": (2, 0),
- "build_fcn": (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_binary_broadcast,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgAddSub,
+ None,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6678,7 +2943,12 @@ class TosaTestGen:
# but create the table tensor in the build function, as it may be
# a different type from the input
"operands": (1, 0),
- "build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
+ "build_fcn": (
+ build_table,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agTable,
+ ),
"types": [DType.INT8, DType.INT16],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6691,7 +2961,12 @@ class TosaTestGen:
"abs": {
"op": Op.ABS,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6703,7 +2978,12 @@ class TosaTestGen:
"bitwise_not": {
"op": Op.BITWISE_NOT,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_INT,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6715,7 +2995,12 @@ class TosaTestGen:
"ceil": {
"op": Op.CEIL,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6727,7 +3012,12 @@ class TosaTestGen:
"clz": {
"op": Op.CLZ,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": [DType.INT32],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6739,7 +3029,12 @@ class TosaTestGen:
"exp": {
"op": Op.EXP,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6751,7 +3046,12 @@ class TosaTestGen:
"floor": {
"op": Op.FLOOR,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6763,7 +3063,12 @@ class TosaTestGen:
"log": {
"op": Op.LOG,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6775,7 +3080,12 @@ class TosaTestGen:
"logical_not": {
"op": Op.LOGICAL_NOT,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6787,7 +3097,12 @@ class TosaTestGen:
"negate": {
"op": Op.NEGATE,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgNegate,
+ None,
+ ),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_INT_FP,
"error_if_validators": (
@@ -6802,7 +3117,12 @@ class TosaTestGen:
"reciprocal": {
"op": Op.RECIPROCAL,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6814,7 +3134,12 @@ class TosaTestGen:
"rsqrt": {
"op": Op.RSQRT,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -6827,7 +3152,12 @@ class TosaTestGen:
"select": {
"op": Op.SELECT,
"operands": (3, 0),
- "build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_select,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgSelect,
+ None,
+ ),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6842,7 +3172,12 @@ class TosaTestGen:
"equal": {
"op": Op.EQUAL,
"operands": (2, 0),
- "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_comparison,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgEqual,
+ None,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6856,7 +3191,12 @@ class TosaTestGen:
"greater_equal": {
"op": Op.GREATER_EQUAL,
"operands": (2, 0),
- "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_comparison,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6870,7 +3210,12 @@ class TosaTestGen:
"greater": {
"op": Op.GREATER,
"operands": (2, 0),
- "build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
+ "build_fcn": (
+ build_comparison,
+ TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evRankMismatch,
@@ -6886,7 +3231,12 @@ class TosaTestGen:
"op": Op.REDUCE_ALL,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_reduce,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
@@ -6903,7 +3253,12 @@ class TosaTestGen:
"op": Op.REDUCE_ANY,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_reduce,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_BOOL,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
@@ -6920,7 +3275,12 @@ class TosaTestGen:
"op": Op.REDUCE_MAX,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_reduce,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
@@ -6937,7 +3297,12 @@ class TosaTestGen:
"op": Op.REDUCE_MIN,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_reduce,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
@@ -6954,7 +3319,12 @@ class TosaTestGen:
"op": Op.REDUCE_PRODUCT,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_reduce,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_FP,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
@@ -6971,7 +3341,12 @@ class TosaTestGen:
"op": Op.REDUCE_SUM,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_reduce, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_reduce,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgReduceSum,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_FI32,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
@@ -6988,7 +3363,12 @@ class TosaTestGen:
"concat": {
"op": Op.CONCAT,
"operands": (2, 0),
- "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_concat,
+ TosaTensorGen.tgConcat,
+ TosaTensorValuesGen.tvgConcat,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
@@ -7005,7 +3385,12 @@ class TosaTestGen:
"op": Op.PAD,
"operands": (1, 0),
"rank": (1, 5),
- "build_fcn": (build_pad, TosaTensorGen.tgBasic, TosaArgGen.agPad),
+ "build_fcn": (
+ build_pad,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agPad,
+ ),
"qgen": TosaQuantGen.qgPad,
"types": TYPE_FIB,
"error_if_validators": (
@@ -7019,7 +3404,12 @@ class TosaTestGen:
"reshape": {
"op": Op.RESHAPE,
"operands": (1, 0),
- "build_fcn": (build_reshape, TosaTensorGen.tgBasic, TosaArgGen.agReshape),
+ "build_fcn": (
+ build_reshape,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agReshape,
+ ),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evTensorSizeInputOutputMismatch,
@@ -7032,7 +3422,12 @@ class TosaTestGen:
"reverse": {
"op": Op.REVERSE,
"operands": (1, 0),
- "build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (
+ build_reverse,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agAxis,
+ ),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
@@ -7047,7 +3442,12 @@ class TosaTestGen:
"op": Op.SLICE,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_slice, TosaTensorGen.tgBasic, TosaArgGen.agSlice),
+ "build_fcn": (
+ build_slice,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agSlice,
+ ),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evStartSmallerZero,
@@ -7065,7 +3465,12 @@ class TosaTestGen:
"tile": {
"op": Op.TILE,
"operands": (1, 0),
- "build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
+ "build_fcn": (
+ build_tile,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agTile,
+ ),
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -7081,6 +3486,7 @@ class TosaTestGen:
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
TosaArgGen.agTranspose,
),
"types": TYPE_FIB,
@@ -7097,13 +3503,23 @@ class TosaTestGen:
"const": {
"op": Op.CONST,
"operands": (0, 1),
- "build_fcn": (build_const, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_const,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FIB,
},
"identity": {
"op": Op.IDENTITY,
"operands": (1, 0),
- "build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_FIB,
},
# Scatter/Gather
@@ -7112,7 +3528,12 @@ class TosaTestGen:
# Only specify 'values' tensor here. 'indices' is generated in op building stage
"operands": (1, 0),
"rank": (3, 3),
- "build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
+ "build_fcn": (
+ build_gather,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -7128,7 +3549,12 @@ class TosaTestGen:
# 'indices' and 'input' are generated in op building stage
"operands": (2, 0),
"rank": (3, 3),
- "build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
+ "build_fcn": (
+ build_scatter,
+ TosaTensorGen.tgScatter,
+ TosaTensorValuesGen.tvgDefault,
+ None,
+ ),
"types": TYPE_INT_FP,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -7143,7 +3569,12 @@ class TosaTestGen:
"op": Op.RESIZE,
"operands": (1, 0),
"rank": (4, 4),
- "build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
+ "build_fcn": (
+ build_resize,
+ TosaTensorGen.tgNHWC,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agResize,
+ ),
"types": [DType.INT8, DType.INT16, DType.FLOAT],
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
@@ -7172,7 +3603,12 @@ class TosaTestGen:
"cast": {
"op": Op.CAST,
"operands": (1, 0),
- "build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
+ "build_fcn": (
+ build_cast,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agCast,
+ ),
"types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -7185,7 +3621,12 @@ class TosaTestGen:
"op": Op.RESCALE,
"operands": (1, 0),
"rank": (1, 4),
- "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
+ "build_fcn": (
+ build_rescale,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agRescale,
+ ),
"types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
@@ -7211,6 +3652,7 @@ class TosaTestGen:
"build_fcn": (
build_cond_if_const,
TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agCondIf,
),
"types": [DType.BOOL],
@@ -7225,6 +3667,7 @@ class TosaTestGen:
"build_fcn": (
build_cond_if_binary,
TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agCondIf,
),
"types": TYPE_INT_FP,
@@ -7242,6 +3685,7 @@ class TosaTestGen:
"build_fcn": (
build_while_loop,
TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgCondIfWhileLoop,
TosaArgGen.agWhileLoop,
),
"types": [DType.INT32],
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
new file mode 100644
index 0000000..ca115a2
--- /dev/null
+++ b/verif/generator/tosa_utils.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2021-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+from tosa.DType import DType
+
+
+def valueToName(item, value):
+ """Get the name of an attribute with the given value.
+
+ This convenience function is needed to print meaningful names for
+ the values of the tosa.Op.Op and tosa.DType.DType classes.
+ This would not be necessary if they were subclasses of Enum, or
+ IntEnum, which, sadly, they are not.
+
+ Args:
+ item: The class, or object, to find the value in
+ value: The value to find
+
+ Example, to get the name of a DType value:
+
+ name = valueToName(DType, DType.INT8) # returns 'INT8'
+ name = valueToName(DType, 4) # returns 'INT8'
+
+ Returns:
+ The name of the first attribute found with a matching value,
+
+ Raises:
+ ValueError if the value is not found
+ """
+ for attr in dir(item):
+ if getattr(item, attr) == value:
+ return attr
+ raise ValueError(f"value ({value}) not found")
+
+
+def allDTypes(*, excludes=None):
+ """Get a set of all DType values, optionally excluding some values.
+
+ This convenience function is needed to provide a sequence of DType values.
+ This would be much easier if DType was a subclass of Enum, or IntEnum,
+ as we could then iterate over the values directly, instead of using
+ dir() to find the attributes and then check if they are what we want.
+
+ Args:
+ excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL])
+
+ Returns:
+ A set of DType values
+ """
+ excludes = () if not excludes else excludes
+ return {
+ getattr(DType, t)
+ for t in dir(DType)
+ if not callable(getattr(DType, t))
+ and not t.startswith("__")
+ and getattr(DType, t) not in excludes
+ }
+
+
+def usableDTypes(*, excludes=None):
+ """Get a set of usable DType values, optionally excluding some values.
+
+ Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
+ specified by the caller, as the serializer lib does not support them.
+ If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.
+
+ Args:
+ excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])
+
+ Returns:
+ A set of DType values
+ """
+ omit = {DType.UNKNOWN, DType.UINT8}
+ omit.update(excludes if excludes else ())
+ return allDTypes(excludes=omit)
+
+
+def product(shape):
+ value = 1
+ for n in shape:
+ value *= n
+ return value