aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py47
1 files changed, 37 insertions, 10 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 0851aca..592c491 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -254,19 +254,16 @@ class TosaTensorGen:
return shape_list
@staticmethod
- def tgBroadcastFuzz(testGen, op, rank, error_name=None):
+ def _get_broadcast_shapes(testGen, num_shapes, rank, error_name=None):
shape = testGen.makeShape(rank)
-
- pl, const = op["operands"]
-
shape_list = []
# Choose one of the inputs to broadcast
# Note: Simplifies OutputShaper code if we don't change first shape for errors
- bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
+ bcast_idx = testGen.randInt(0 if error_name is None else 1, num_shapes)
fuzz_idx = testGen.randInt(0, rank)
- for i in range(pl + const):
+ for i in range(num_shapes):
shape_bcast = shape.copy()
# To test broadcasting, the chosen fuzz index dimension should not be 1
@@ -295,6 +292,22 @@ class TosaTensorGen:
return shape_list
@staticmethod
+ def tgBroadcastFuzz(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+ num_shapes = pl + const
+ return TosaTensorGen._get_broadcast_shapes(
+ testGen, num_shapes, rank, error_name
+ )
+
+ @staticmethod
+ def tgMul(testGen, op, rank, error_name=None):
+ # Get broadcast shapes for the first 2 inputs as the 3rd is shift
+ shape_list = TosaTensorGen._get_broadcast_shapes(testGen, 2, rank, error_name)
+ # Add a single dimension tensor for shift
+ shape_list.append([1])
+ return shape_list
+
+ @staticmethod
def tgConv2D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
@@ -727,7 +740,12 @@ class TosaTensorValuesGen:
# Ignore lazy data gen option and create data array using any range limits
if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None:
- arr = np.int64(argsDict["fixed_data"][idx])
+ if dtype == DType.SHAPE:
+ arr = np.int64(argsDict["fixed_data"][idx])
+ elif dtype == DType.INT8:
+ arr = np.int8(argsDict["fixed_data"][idx])
+ else:
+ assert False, "Unsupported fixed_data type"
else:
arr = testGen.getRandTensor(shape, dtype, data_range)
if roundMode:
@@ -1147,6 +1165,13 @@ class TosaTensorValuesGen:
if data_range:
argsDict["data_range"] = data_range
+ if dtypeList[0] != DType.SHAPE:
+ # Need to supply shift tensor for MUL (not needed for MUL_SHAPE)
+ dtypeList[2] = DType.INT8
+ shapeList[2] = [1]
+ # Create a new list for the pre-generated data in argsDict["fixed_data"]
+ argsDict["fixed_data"] = [None, None, [argsDict["shift"]]]
+
return TosaTensorValuesGen.tvgLazyGenDefault(
testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@@ -1154,9 +1179,6 @@ class TosaTensorValuesGen:
# Integer test
op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
- assert (
- pCount == 2 and cCount == 0
- ), "Op.MUL must have 2 placeholders, 0 consts"
tens_ser_list = []
@@ -1213,6 +1235,7 @@ class TosaTensorValuesGen:
b_arr = b_arr // 2
if dtypeList[0] == DType.SHAPE:
+ # MUL_SHAPE with 2 inputs
tens_ser_list.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr_64)
)
@@ -1220,12 +1243,16 @@ class TosaTensorValuesGen:
testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr_64)
)
else:
+ # MUL with 3 inputs (3rd is shift)
tens_ser_list.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
)
tens_ser_list.append(
testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
)
+ tens_ser_list.append(
+ testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
+ )
return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)