aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-02-08 11:45:44 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-02-12 12:17:53 +0000
commit587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (patch)
tree2465605214e9723636bad2645508300421967052 /verif/generator/tosa_arg_gen.py
parent01e1c1c7f965ceb07e78a3b1ad063161c0f47b94 (diff)
downloadreference_model-587cc84c2b8c4b0d030b5e257c9a32461c0969b9.tar.gz
Update test builder internal interfaces
Move remaining ops from using testArgs to argsDict. All tvg/build_fcn function interfaces updated. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ie886fd931bd74608bda621363100bf8bfd7385e6
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py117
1 files changed, 79 insertions, 38 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 33e74b5..7ec0cfe 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -636,18 +636,6 @@ class TosaTensorValuesGen:
self.tensorList = tensorList
self.dataGenDict = dataGenDict
- @staticmethod
- def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
- pCount, cCount = op["operands"]
-
- tens = []
- tens.extend(
- testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
- )
- tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
-
- return tens
-
# Default high value for random numbers
TVG_FLOAT_HIGH_VALUE = {
DType.FP32: (1 << 128) - (1 << (127 - 23)),
@@ -969,7 +957,7 @@ class TosaTensorValuesGen:
@staticmethod
def tvgCondIfWhileLoop(
- testGen, op, dtypeList, shapeList, testArgs, error_name=None
+ testGen, opName, dtypeList, shapeList, argsDict, error_name=None
):
if dtypeList[0] in (
DType.INT32,
@@ -979,9 +967,10 @@ class TosaTensorValuesGen:
# Limit input tensors with cond_if_binary or while_loop to stop
# saturation of add/sub ops with int32 and keep all logical shift
# values between 0 to 31 for int16 or int8
+ op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
pRemain = pCount
- placeholders = []
+ tens_ser_list = []
for idx, shape in enumerate(shapeList[:]):
if dtypeList[0] == DType.INT32:
arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
@@ -990,32 +979,33 @@ class TosaTensorValuesGen:
testGen.rng.integers(low=0, high=32, size=shapeList[idx])
)
if pRemain > 0:
- placeholders.append(
+ tens_ser_list.append(
testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
)
pRemain -= 1
else:
- placeholders.append(
+ tens_ser_list.append(
testGen.ser.addConst(shape, dtypeList[idx], arr)
)
- return placeholders
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
else:
- return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, error_name
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
def tvgArithmeticRightShift(
- testGen, op, dtypeList, shapeList, testArgs, error_name=None
+ testGen, opName, dtypeList, shapeList, argsDict, error_name=None
):
+ op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
# Force value of operand[1] to be within [0, num_bits]
assert (
pCount == 2 and cCount == 0
), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
- placeholders = []
+ tens_ser_list = []
for idx, shape in enumerate(shapeList[:]):
if idx == 1:
if dtypeList[idx] == DType.INT8:
@@ -1030,23 +1020,23 @@ class TosaTensorValuesGen:
raise Exception("OpArithmeticRightShift: invalid input dtype")
else:
arr = testGen.getRandTensor(shape, dtypeList[idx])
- placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
+ tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
- return placeholders
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
@staticmethod
- def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
+ def tvgReshape(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
dtypeList[1] = DType.SHAPE
shapeList[1] = [len(argsDict["new_shape"])]
# Create a new list for the pre-generated data in argsDict["fixed_data"]
argsDict["fixed_data"] = [None, argsDict["new_shape"]]
return TosaTensorValuesGen.tvgLazyGenDefault(
- testGen, op, dtypeList, shapeList, argsDict, error_name
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
- def tvgPad(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
+ def tvgPad(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
# argsDict["pad"] is 2D array, need to flatten it to get list of values
pad_values = argsDict["pad"].flatten()
dtypeList[1] = DType.SHAPE
@@ -1055,11 +1045,11 @@ class TosaTensorValuesGen:
argsDict["fixed_data"] = [None, pad_values]
return TosaTensorValuesGen.tvgLazyGenDefault(
- testGen, op, dtypeList, shapeList, argsDict, error_name
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
- def tvgSlice(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
+ def tvgSlice(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
dtypeList[1] = DType.SHAPE
shapeList[1] = [len(argsDict["start"])]
dtypeList[2] = DType.SHAPE
@@ -1068,17 +1058,17 @@ class TosaTensorValuesGen:
argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
return TosaTensorValuesGen.tvgLazyGenDefault(
- testGen, op, dtypeList, shapeList, argsDict, error_name
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
- def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
+ def tvgTile(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
dtypeList[1] = DType.SHAPE
shapeList[1] = [len(argsDict["multiples"])]
argsDict["fixed_data"] = [None, argsDict["multiples"]]
return TosaTensorValuesGen.tvgLazyGenDefault(
- testGen, op, dtypeList, shapeList, argsDict, error_name
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
@@ -2776,10 +2766,23 @@ class TosaArgGen:
int(double_round),
int(per_channel),
),
- [outDtype, scale32, double_round, per_channel],
+ {
+ "output_dtype": outDtype,
+ "scale": scale32,
+ "double_round": double_round,
+ "per_channel": per_channel,
+ },
)
)
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ inDtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
@staticmethod
@@ -2808,9 +2811,20 @@ class TosaArgGen:
def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
- arg_list.append(("roundTrue", [True]))
- arg_list.append(("roundFalse", [False]))
+ for round in (True, False):
+ args_dict = {
+ "round": round,
+ }
+ arg_list.append((f"round{round}", args_dict))
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
@staticmethod
@@ -3414,9 +3428,18 @@ class TosaArgGen:
arg_list.append(
(
"",
- [table],
+ {"table": table},
)
)
+ # Now add data generator types
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
@@ -3426,15 +3449,33 @@ class TosaArgGen:
arg_list = []
for c in [False, True]:
- arg_list.append(("cond{}".format(int(c)), [c]))
+ arg_list.append(("cond{}".format(int(c)), {"condition": c}))
+ # Now add data generator types
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
# While loop: 0 iterations, 1, more than 1
arg_list = []
- for iter in [0, 1, 4]:
- arg_list.append(("iter{}".format(iter), [iter]))
+ for iterations in [0, 1, 4]:
+ arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
+ # Now add data generator types
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list