aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py79
1 files changed, 62 insertions, 17 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a655a50..f598377 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -622,6 +622,28 @@ class TosaTensorGen:
return new_shapeList
+ @staticmethod
+ def tgShape(testGen, opName, rank, error_name=None):
+ pl, const = opName["operands"]
+ shape = [rank]
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
+
+ shape_list = []
+ for i in range(pl + const):
+ shape_list.append(shape.copy())
+
+ # Generates an input rank mismatch for operators with more than one input
+ if error_name == ErrorIf.RankMismatch:
+ if rank == 1 and i != 1:
+ shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
+ elif i != 1:
+ shape = testGen.makeShape(rank + testGen.rng.choice([-1, 1]))
+
+ return shape_list
+
class TosaTensorValuesGen:
"""Tensor Value generators create the random data for each tensor in each test."""
@@ -891,7 +913,7 @@ class TosaTensorValuesGen:
@staticmethod
def tvgAddSub(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
- if dtypeList[0] == DType.INT32 and error_name is None:
+ if dtypeList[0] in (DType.INT32, DType.SHAPE) and error_name is None:
# Make sure the integer operation does not cause value saturation - where
# the number wraps due to limited number of bits to store the answer
op = testGen.TOSA_OP_LIST[opName]
@@ -900,9 +922,10 @@ class TosaTensorValuesGen:
pCount == 2 and cCount == 0
), "Op.ADD / Op.SUB must have 2 placeholders, 0 consts"
tens_ser_list = []
- add = op["op"] == Op.ADD
- a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
- b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
+ add = op["op"] in (Op.ADD, Op.ADD_SHAPE)
+ data_range = testGen.args.tensor_shape_range
+ a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0], data_range)
+ b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1], data_range)
if add:
res_arr = np.add(a_arr, b_arr, dtype=np.int64)
else:
@@ -1138,12 +1161,15 @@ class TosaTensorValuesGen:
tens_ser_list = []
# Make sure multiply result in int32 range
- shift = argsDict["shift"]
+ if dtypeList[0] == DType.SHAPE:
+ shift = 0
+ else:
+ shift = argsDict["shift"]
if dtypeList[0] == DType.INT8:
num_bits = 8
elif dtypeList[0] == DType.INT16:
num_bits = 16
- elif dtypeList[0] == DType.INT32:
+ elif dtypeList[0] in (DType.INT32, DType.SHAPE):
num_bits = 32
elif error_name == ErrorIf.WrongInputType:
num_bits = 8
@@ -1151,8 +1177,12 @@ class TosaTensorValuesGen:
raise Exception("OpMul: invalid input dtype")
for idx, shape in enumerate(shapeList[:]):
- low = -(2 ** (num_bits - 1))
- high = (2 ** (num_bits - 1)) - 1
+ if dtypeList[idx] == DType.SHAPE:
+ low = testGen.args.tensor_shape_range[0]
+ high = testGen.args.tensor_shape_range[1]
+ else:
+ low = -(2 ** (num_bits - 1))
+ high = (2 ** (num_bits - 1)) - 1
a_arr = np.int32(
testGen.rng.integers(low=low, high=high, size=shapeList[0])
@@ -1182,12 +1212,20 @@ class TosaTensorValuesGen:
a_arr = a_arr // 2
b_arr = b_arr // 2
- tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
- )
- tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
- )
+ if dtypeList[0] == DType.SHAPE:
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
+ )
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
+ )
+ else:
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ )
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+ )
return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
@@ -1199,9 +1237,16 @@ class TosaTensorValuesGen:
if testGen.args.num_const_inputs_concat == 0:
count = len(shapeList)
- shapeList = TosaTensorGen.tgConcatConstInput(
- testGen, shapeList, argsDict["axis"], error_name
- )
+ op = testGen.TOSA_OP_LIST[opName]
+ if op["op"] == Op.CONCAT_SHAPE:
+ # Set the axis to 0
+ shapeList = TosaTensorGen.tgConcatConstInput(
+ testGen, shapeList, 0, error_name
+ )
+ else:
+ shapeList = TosaTensorGen.tgConcatConstInput(
+ testGen, shapeList, argsDict["axis"], error_name
+ )
# Override default pCount/cCount for operator
argsDict["p_count"] = count