Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--  verif/generator/tosa_test_gen.py  351
1 file changed, 229 insertions(+), 122 deletions(-)
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 3014c81..d15f785 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,8 +1,12 @@
# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import json
import os
from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+import generator.tosa_utils as gtu
import numpy as np
import serializer.tosa_serializer as ts
from generator.tosa_arg_gen import TosaArgGen
@@ -13,15 +17,15 @@ from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
-from generator.tosa_utils import DTYPE_ATTRIBUTES
-from generator.tosa_utils import get_rank_mismatch_shape
-from generator.tosa_utils import get_wrong_output_type
-from generator.tosa_utils import MAX_RESIZE_DIMENSION
-from generator.tosa_utils import usableDTypes
-from generator.tosa_utils import vect_f32_to_bf16
+from schemavalidation.schemavalidation import TestDescSchemaValidator
from tosa.DType import DType
from tosa.Op import Op
+TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Limited
+// SPDX-License-Identifier: Apache-2.0
+// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
+"""
+
class TosaTestGen:
# Maximum rank of tensor supported by test generator.
@@ -31,6 +35,10 @@ class TosaTestGen:
TOSA_8K_LEVEL_MAX_KERNEL = 8192
TOSA_8K_LEVEL_MAX_STRIDE = 8192
+ # Main compliance dot product statistical test range
+ TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
+ TOSA_MI_DOT_PRODUCT_MIN = 1000
+
def __init__(self, args):
self.args = args
self.basePath = args.output_dir
@@ -45,6 +53,8 @@ class TosaTestGen:
# Work out floating point range
self.random_fp_low = min(args.tensor_fp_value_range)
self.random_fp_high = max(args.tensor_fp_value_range)
+ # JSON schema validation
+ self.descSchemaValidator = TestDescSchemaValidator()
def createSerializer(self, opName, testPath):
self.testPath = os.path.join(opName, testPath)
@@ -53,81 +63,131 @@ class TosaTestGen:
os.makedirs(fullPath, exist_ok=True)
# Embed const data in the flatbuffer
constMode = ts.ConstMode.EMBED
- if self.args.dump_consts:
+ if self.args.lazy_data_gen:
+ # Lazy data generation - write constants to separate files
+ constMode = ts.ConstMode.INPUTS
+ elif self.args.dump_consts:
constMode = ts.ConstMode.EMBED_DUMP
self.ser = ts.TosaSerializer(fullPath, constMode)
def getSerializer(self):
return self.ser
- def serialize(self, testName):
- with open(
- os.path.join(self.basePath, self.testPath, "{}.tosa".format(testName)), "wb"
- ) as fd:
+ def serialize(self, testName, metaData=None):
+ path = Path(self.basePath) / self.testPath
+
+ # Write out TOSA flatbuffer binary
+ path_fb = path / f"{testName}.tosa"
+ with path_fb.open("wb") as fd:
fd.write(self.ser.serialize())
- with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
- fd.write(self.ser.writeJson("{}.tosa".format(testName)))
+ # Get JSON descriptor from serializer
+ desc = json.loads(self.ser.writeJson(f"{testName}.tosa"))
+
+ if metaData:
+ # Add extra meta data to desc.json
+ desc["meta"] = metaData
+
+ # Validate desc.json before we output it
+ self.descSchemaValidator.validate_config(desc)
+
+ if metaData:
+ if self.args.lazy_data_gen and "data_gen" in metaData:
+ # Output datagen meta data as CPP data
+ path_md = path / f"{testName}_meta_data_gen.cpp"
+ with path_md.open("w") as fd:
+ fd.write(TOSA_AUTOGENERATED_HEADER)
+ fd.write("// Test meta data for data generation setup\n\n")
+ fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
+ json.dump(metaData["data_gen"], fd)
+ fd.write(')";\n\n')
+ if "compliance" in metaData:
+ # Output compliance meta data as CPP data
+ path_md = path / f"{testName}_meta_compliance.cpp"
+ with path_md.open("w") as fd:
+ fd.write(TOSA_AUTOGENERATED_HEADER)
+ fd.write("// Test meta data for compliance validation\n\n")
+ fd.write(f'const char* json_tvf_config_{path.stem} = R"(')
+ json.dump(metaData["compliance"], fd)
+ fd.write(')";\n\n')
+
+ # Write desc.json
+ path_desc = path / "desc.json"
+ with path_desc.open("w") as fd:
+ json.dump(desc, fd, indent=1)
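
A minimal usage sketch of the extended serialize() path above, assuming a constructed TosaTestGen instance `ttg` with a graph already built; the metaData payloads here are illustrative placeholders, not real generator configs:

    meta = {
        "data_gen": {"version": "0.1", "tensors": {"input-0": {}}},     # placeholder
        "compliance": {"version": "0.1", "tensors": {"result-0": {}}},  # placeholder
    }
    ttg.serialize("test", metaData=meta)
    # Alongside test.tosa and desc.json, this writes with --lazy-data-gen:
    #   test_meta_data_gen.cpp   - const char* json_tdg_config_<dir> wrapping the data_gen JSON
    # and whenever "compliance" is present in the meta data:
    #   test_meta_compliance.cpp - const char* json_tvf_config_<dir> wrapping the compliance JSON
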
def resetRNG(self, seed=None):
if seed is None:
seed = self.random_seed + 1
self.rng = np.random.default_rng(seed)
- def getRandTensor(self, shape, dtype):
- if dtype == DType.BOOL:
- return np.bool_(self.rng.choice(a=[False, True], size=shape))
- # TOSA specific INT4 weight range from -7 to 7
+ def getDTypeRange(self, dtype, high_inclusive=False):
+ # Returns dtype value range boundaries (low, high)
+ # The high boundary is excluded in the range
+ # unless high_inclusive is True
+
+ if dtype in (DType.FP32, DType.FP16, DType.BF16):
+ return (self.random_fp_low, self.random_fp_high)
+ elif dtype == DType.BOOL:
+ rng = (0, 2)
+ elif dtype == DType.UINT8:
+ rng = (0, 256)
+ elif dtype == DType.UINT16:
+ rng = (0, 65536)
elif dtype == DType.INT4:
- return np.int32(self.rng.integers(low=-7, high=8, size=shape))
+ # TOSA specific INT4 weight range from -7 to 7
+ rng = (-7, 8)
elif dtype == DType.INT8:
- return np.int32(self.rng.integers(low=-128, high=128, size=shape))
- elif dtype == DType.UINT8:
- return np.int32(self.rng.integers(low=0, high=256, size=shape))
+ rng = (-128, 128)
elif dtype == DType.INT16:
- return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
- elif dtype == DType.UINT16:
- return np.int32(self.rng.integers(low=0, high=65536, size=shape))
- elif (
- dtype == DType.INT32 or dtype == DType.SHAPE
- ): # restricting too large value for SHAPE
- return np.int32(
- self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
- )
+ rng = (-32768, 32768)
+ elif dtype in (DType.INT32, DType.SHAPE):
+ # restrict range for SHAPE to avoid overly large values
+ rng = (-(1 << 31), (1 << 31))
elif dtype == DType.INT48:
- return np.int64(
- self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
- )
- elif dtype == DType.FP16:
- return np.float16(
- self.rng.uniform(
- low=self.random_fp_low, high=self.random_fp_high, size=shape
- )
- )
- elif dtype == DType.BF16:
- f32_tensor = np.float32(
- self.rng.uniform(
- low=self.random_fp_low, high=self.random_fp_high, size=shape
- )
- )
- # Floor the last 16 bits of each f32 value
- return np.float32(vect_f32_to_bf16(f32_tensor))
- elif dtype == DType.FP32:
- return np.float32(
- self.rng.uniform(
- low=self.random_fp_low, high=self.random_fp_high, size=shape
- )
- )
+ rng = (-(1 << 47), (1 << 47))
+ else:
+ raise Exception("Unknown dtype: {}".format(dtype))
+
+ if not high_inclusive:
+ # Exclusive high: low <= range < high
+ return rng
else:
- raise Exception("Unrecognized Dtype: {}".format(dtype))
+ # Inclusive range: low <= range <= high
+ return (rng[0], rng[1] - 1)
+
+ def getRandTensor(self, shape, dtype):
+ low, high = self.getDTypeRange(dtype)
+
+ if dtype == DType.BOOL:
+ return np.bool_(self.rng.choice(a=[False, True], size=shape))
+ elif dtype == DType.INT48:
+ return np.int64(self.rng.integers(low=low, high=high, size=shape))
+ elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ f_tensor = self.rng.uniform(low=low, high=high, size=shape)
+
+ if dtype == DType.FP16:
+ return np.float16(f_tensor)
+ else:
+ f32_tensor = np.float32(f_tensor)
+ if dtype == DType.BF16:
+ # Floor the last 16 bits of each f32 value
+ return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+ else:
+ return f32_tensor
+ else:
+ # All other integer types
+ return np.int32(self.rng.integers(low=low, high=high, size=shape))
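
A short sketch of how the consolidated getDTypeRange()/getRandTensor() pair behaves, assuming a TosaTestGen instance `ttg`; the ranges follow the table encoded above:

    low, high = ttg.getDTypeRange(DType.INT8)                           # (-128, 128), high exclusive
    low_i, high_i = ttg.getDTypeRange(DType.INT8, high_inclusive=True)  # (-128, 127)

    t_i8 = ttg.getRandTensor([2, 3], DType.INT8)   # np.int32 values in [-128, 127]
    t_i48 = ttg.getRandTensor([4], DType.INT48)    # np.int64 to hold the 48-bit range
    t_bf16 = ttg.getRandTensor([4], DType.BF16)    # np.float32 with the low 16 bits floored

Note that the floating-point dtypes return (random_fp_low, random_fp_high) directly, so high_inclusive has no effect on them.
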
def buildPlaceholderTensors(self, shape_list, dtype_list):
placeholders = []
assert len(shape_list) == len(dtype_list)
+ arr = None
for idx, shape in enumerate(shape_list):
- arr = self.getRandTensor(shape, dtype_list[idx])
+ if not self.args.lazy_data_gen:
+ arr = self.getRandTensor(shape, dtype_list[idx])
placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
return placeholders
@@ -137,8 +197,10 @@ class TosaTestGen:
assert len(shape_list) == len(dtype_list)
+ arr = None
for idx, shape in enumerate(shape_list):
- arr = self.getRandTensor(shape, dtype_list[idx])
+ if not self.args.lazy_data_gen:
+ arr = self.getRandTensor(shape, dtype_list[idx])
consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
return consts
@@ -161,38 +223,20 @@ class TosaTestGen:
return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
def getRandNumberDType(self, dtype):
+ low, high = self.getDTypeRange(dtype)
+
if dtype == DType.FP32:
- return np.float32(
- self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
- )
+ return np.float32(self.rng.uniform(low=low, high=high))
elif dtype == DType.FP16:
- return np.float16(
- self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
- )
+ return np.float16(self.rng.uniform(low=low, high=high))
elif dtype == DType.BF16:
- rand_f32 = np.float32(
- self.rng.uniform(low=self.random_fp_low, high=self.random_fp_high)
- )
- return vect_f32_to_bf16(rand_f32)
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_bf16(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
- # TOSA specific INT4 weight range from -7 to 7
- elif dtype == DType.INT4:
- low, high = (-7, 8)
- elif dtype == DType.INT8:
- low, high = (-128, 128)
- elif dtype == DType.INT16:
- low, high = (-32768, 32768)
- elif (
- dtype == DType.INT32 or dtype == DType.SHAPE
- ): # restricting too large value for SHAPE
- low, high = (-(1 << 31), (1 << 31))
elif dtype == DType.INT48:
- low, high = (-(1 << 47), (1 << 47))
# Special size
return np.int64(self.rng.integers(low, high, size=1))[0]
- else:
- raise Exception("Unknown dtype: {}".format(dtype))
return np.int32(self.rng.integers(low, high, size=1))[0]
@@ -212,8 +256,8 @@ class TosaTestGen:
# Limit types to the first 2 as the 3rd is the accumulator
return "x".join(strs[:2])
else:
- if dtype in DTYPE_ATTRIBUTES:
- return DTYPE_ATTRIBUTES[dtype]["str"]
+ if dtype in gtu.DTYPE_ATTRIBUTES:
+ return gtu.DTYPE_ATTRIBUTES[dtype]["str"]
else:
raise Exception(
"Unknown dtype, cannot convert to string: {}".format(dtype)
@@ -221,8 +265,8 @@ class TosaTestGen:
def typeWidth(self, dtype):
"""Get the datatype width for data types"""
- if dtype in DTYPE_ATTRIBUTES:
- return DTYPE_ATTRIBUTES[dtype]["width"]
+ if dtype in gtu.DTYPE_ATTRIBUTES:
+ return gtu.DTYPE_ATTRIBUTES[dtype]["width"]
else:
raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
@@ -237,11 +281,44 @@ class TosaTestGen:
low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
)
- # Argument generators
- # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
- # Where the string descriptor is used to generate the test name and
- # The build_fcn_arg_list is expanded and passed to the operator test
- # build function
+ def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
+ if errorName:
+ # No compliance for error tests
+ return None
+ # Create compliance meta data for expected output tensor
+ compliance_tens = {"mode": None}
+ if argsDict["dg_type"] == gtu.DataGenType.DOT_PRODUCT:
+ mode = gtu.ComplianceMode.DOT_PRODUCT
+ compliance_tens["dot_product_info"] = {
+ "s": argsDict["s"],
+ "ks": argsDict["ks"],
+ "data_type": gtu.DTYPE_ATTRIBUTES[outputTensor.dtype]["json"],
+ }
+ elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
+ mode = gtu.ComplianceMode.FP_SPECIAL
+ elif "compliance" in op and "ulp" in op["compliance"]:
+ mode = gtu.ComplianceMode.ULP
+ compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]}
+ elif op["op"] == Op.REDUCE_PRODUCT:
+ mode = gtu.ComplianceMode.REDUCE_PRODUCT
+ else:
+ mode = gtu.ComplianceMode.EXACT
+ compliance_tens["mode"] = gtu.ComplianceMode(mode).name
+
+ return compliance_tens
+
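
For illustration, the dictionary built above for a dot-product compliance test serializes to JSON along these lines (the "s", "ks" and "data_type" values are assumed for the example):

    {
      "mode": "DOT_PRODUCT",
      "dot_product_info": {"s": 0, "ks": 64, "data_type": "fp32"}
    }
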
+ # Build Op functions
+ # Create the output tensor (calling OutputShaper as needed)
+ # Do final tweaks to attributes (if necessary for errorIf)
+ # Add Op into graph
+ # Return resulting tensor information or BuildInfo
+
+ class BuildInfo:
+ """Enhanced build information containing result tensor and associated compliance dict."""
+
+ def __init__(self, resultTensor, complianceDict):
+ self.resultTensor = resultTensor
+ self.complianceDict = complianceDict
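
A sketch of the intended consumption pattern, mirroring the test-generation code further down: migrated build functions return a BuildInfo, while legacy ones still return the bare result tensor:

    result = build_fcn(self, op, *tens, *testArgs)
    if isinstance(result, TosaTestGen.BuildInfo):
        tensor, compliance = result.resultTensor, result.complianceDict
    else:
        tensor, compliance = result, None  # legacy build function
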
def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
@@ -975,15 +1052,16 @@ class TosaTestGen:
return result_tens
def build_matmul(
- self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
+ self, op, a, b, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
- result_tens = OutputShaper.matmulOp(
+ accum_dtype = args_dict["acc_type"]
+ result_tensor = OutputShaper.matmulOp(
self.ser, self.rng, a, b, accum_dtype, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -999,10 +1077,10 @@ class TosaTestGen:
input_dtype=a.dtype,
input2_shape=b.shape,
input2_dtype=b.dtype,
- output_shape=result_tens.shape,
- output_dtype=result_tens.dtype,
+ output_shape=result_tensor.shape,
+ output_dtype=result_tensor.dtype,
qinfo=qinfo,
- result_tensors=[result_tens],
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1014,7 +1092,12 @@ class TosaTestGen:
attr.MatMulAttribute(qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
@@ -1895,7 +1978,7 @@ class TosaTestGen:
def _get_condition_tensor(self, op, cond, error_name):
if error_name == ErrorIf.CondIfCondNotMatchingBool:
- cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
+ cond_type = gtu.get_wrong_output_type(op, self.rng, DType.BOOL)
else:
cond_type = DType.BOOL
if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
@@ -2357,7 +2440,7 @@ class TosaTestGen:
# Initialize a new random number generator
self.rng = np.random.default_rng(self.random_seed)
- build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
+ _, tgen_fcn, _, agen_fcn = op["build_fcn"]
# Test list consists of a tuple of:
# (opName, testNameStr, dtype, shapeList, argumentsList)
@@ -2461,7 +2544,7 @@ class TosaTestGen:
# Create a serializer
self.createSerializer(opName, testStr)
- build_fcn, tgen_fcn, tvgen_fcn, agen_fcn = op["build_fcn"]
+ build_fcn, _, tvgen_fcn, _ = op["build_fcn"]
if "error_if_validators" in op:
error_if_validators = op["error_if_validators"]
else:
@@ -2495,24 +2578,37 @@ class TosaTestGen:
qgen = None
# Build the random tensor operands and the test
- tens = []
if qgen is not None:
qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
else:
qinfo = None
- tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
+ # Extra meta data for the desc.json
+ tensMeta = {}
+
+ # Check we are using the new testArgs interface with an argsDict dictionary
+ if len(testArgs) == 1 and isinstance(testArgs[0], dict):
+ argsDict = testArgs[0]
+ assert "dg_type" in argsDict
+ tvgInfo = tvgen_fcn(
+ self, opName, dtypeList, shapeList, argsDict, error_name
+ )
+ if tvgInfo.dataGenDict:
+ tensMeta["data_gen"] = tvgInfo.dataGenDict
+ tens = tvgInfo.tensorList
+ else:
+ tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
try:
if error_if_validators is None:
if qinfo is not None:
- resultName = build_fcn(self, op, *tens, *testArgs, qinfo)
+ result = build_fcn(self, op, *tens, *testArgs, qinfo)
else:
- resultName = build_fcn(self, op, *tens, *testArgs)
+ result = build_fcn(self, op, *tens, *testArgs)
else:
if qinfo is not None:
- resultName = build_fcn(
+ result = build_fcn(
self,
op,
*tens,
@@ -2522,7 +2618,7 @@ class TosaTestGen:
qinfo=qinfo,
)
else:
- resultName = build_fcn(
+ result = build_fcn(
self,
op,
*tens,
@@ -2534,9 +2630,16 @@ class TosaTestGen:
print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
raise e
- if resultName:
+ if result:
# The test is valid, serialize it
- self.serialize("test")
+ if isinstance(result, TosaTestGen.BuildInfo) and result.complianceDict:
+ # Add the compliance meta data
+ # NOTE: This currently expects only one result output
+ tensMeta["compliance"] = {
+ "version": "0.1",
+ "tensors": {result.resultTensor.name: result.complianceDict},
+ }
+ self.serialize("test", tensMeta)
else:
# The test is not valid
print(f"Invalid ERROR_IF test created: {opName} {testStr}")
@@ -2865,7 +2968,7 @@ class TosaTestGen:
"build_fcn": (
build_matmul,
TosaTensorGen.tgMatmul,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agMatMul,
),
"qgen": TosaQuantGen.qgMatmul,
@@ -2878,6 +2981,10 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ "int": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
},
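
The new per-op "data_gen" entry advertises which generator families apply to each dtype class; a sketch of how a consumer might select one (the fallback and the selection logic are assumptions, not taken from this diff):

    # Assumed selection logic - illustrative only
    dg_types = op.get("data_gen", {"fp": (gtu.DataGenType.PSEUDO_RANDOM,)})
    key = "fp" if dtype in (DType.FP16, DType.BF16, DType.FP32) else "int"
    dg_type = dg_types[key][self.rng.integers(len(dg_types[key]))]
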
"max_pool2d": {
"op": Op.MAX_POOL2D,
@@ -4446,7 +4553,7 @@ class OutputShaper:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
- wrong_dtypes = list(usableDTypes(excludes=excludes))
+ wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@@ -4508,7 +4615,7 @@ class OutputShaper:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
- wrong_dtypes = list(usableDTypes(excludes=excludes))
+ wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@@ -4559,7 +4666,7 @@ class OutputShaper:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
- wrong_dtypes = list(usableDTypes(excludes=excludes))
+ wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@@ -4711,7 +4818,7 @@ class OutputShaper:
bad_dim = rng.choice(range(len(output_shape)))
output_shape[bad_dim] -= rng.choice([1, 2])
elif error_name == ErrorIf.RankMismatch:
- output_shape = get_rank_mismatch_shape(rng, output_shape)
+ output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4806,7 +4913,7 @@ class OutputShaper:
elif error_name == ErrorIf.InputSizeStartLengthMismatch:
output_shape = input.shape.copy()
elif error_name == ErrorIf.RankMismatch:
- output_shape = get_rank_mismatch_shape(rng, output_shape)
+ output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
return ser.addOutput(output_shape, outputDType)
@@ -4820,7 +4927,7 @@ class OutputShaper:
output_shape[i] = a.shape[i] * multiples[i]
if error_name == ErrorIf.RankMismatch:
- output_shape = get_rank_mismatch_shape(rng, output_shape)
+ output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4853,7 +4960,7 @@ class OutputShaper:
for i in range(len(output_shape)):
output_shape[i] += rng.integers(1, 10)
elif error_name == ErrorIf.RankMismatch:
- output_shape = get_rank_mismatch_shape(rng, output_shape)
+ output_shape = gtu.get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4980,21 +5087,21 @@ class OutputShaper:
oh = max(oh, 1)
ow = max(ow, 1)
if error_name != ErrorIf.MaxDimExceeded:
- oh = min(oh, MAX_RESIZE_DIMENSION - 1)
- ow = min(ow, MAX_RESIZE_DIMENSION - 1)
+ oh = min(oh, gtu.MAX_RESIZE_DIMENSION - 1)
+ ow = min(ow, gtu.MAX_RESIZE_DIMENSION - 1)
if error_name == ErrorIf.ResizeOutputShapeMismatch:
choices = [1, 2, 3]
change = rng.choice(choices)
# increment in multiples of scale_y/x_d so we don't hit non-integer error case
if change in [1, 3]:
- if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
+ if oh + scale_y_d >= gtu.MAX_RESIZE_DIMENSION:
oh -= scale_y_d
assert oh > 0 # Should have been caught in agResize
else:
oh += scale_y_d
if change in [2, 3]:
- if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
+ if ow + scale_x_d >= gtu.MAX_RESIZE_DIMENSION:
ow -= scale_x_d
assert ow > 0 # Should have been caught in agResize
else:
@@ -5051,7 +5158,7 @@ class OutputShaper:
excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
- wrong_dtypes = list(usableDTypes(excludes=excludes))
+ wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(output_shape, out_dtype)
@@ -5075,7 +5182,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
excludes = [DType.FP32]
- wrong_dtypes = list(usableDTypes(excludes=excludes))
+ wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
output_dtype = rng.choice(wrong_dtypes)
elif error_name == ErrorIf.BatchMismatch:
output_shape[0] += rng.integers(1, 10)
@@ -5100,7 +5207,7 @@ class OutputShaper:
output_dtype = value.dtype
if error_name == ErrorIf.WrongOutputType:
excludes = [DType.FP32]
- wrong_dtypes = list(usableDTypes(excludes=excludes))
+ wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
output_dtype = rng.choice(wrong_dtypes)
elif error_name == ErrorIf.BatchMismatch:
output_shape[0] += rng.integers(1, 10)