aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-12-07 16:35:28 +0000
committerEric Kunze <eric.kunze@arm.com>2023-12-14 17:56:51 +0000
commita8420add949564053495ef78f3213f163c30fb9a (patch)
tree4c5e2783433e9443b2ed02e5e25c51cc5de2affd /verif/generator/tosa_arg_gen.py
parent81db5d2f275f69cc0d3e8687af57bdba99971042 (diff)
downloadreference_model-a8420add949564053495ef78f3213f163c30fb9a.tar.gz
Main Compliance testing for SCATTER and GATHER
Added indices shuffling and random INT32 support to generate lib with testing of these new random generator modes Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I058d8b092470228075e8fe69c2ededa639163003
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py150
1 files changed, 141 insertions, 9 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 35253e0..50811ac 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -204,7 +204,7 @@ class TosaTensorGen:
return shape_list
@staticmethod
- def tgScatter(testGen, opName, rank, error_name=None):
+ def tgGather(testGen, opName, rank, error_name=None):
pl, const = opName["operands"]
assert pl == 2
@@ -212,12 +212,31 @@ class TosaTensorGen:
if error_name != ErrorIf.WrongRank:
assert rank == 3
+ values_shape = testGen.makeShape(rank)
+ values_shape = testGen.constrictBatchSize(values_shape)
+
+ N = values_shape[0]
+ W = testGen.makeDimension()
+ indices_shape = [N, W]
+
+ shape_list = [values_shape, indices_shape]
+ return shape_list
+
+ @staticmethod
+ def tgScatter(testGen, opName, rank, error_name=None):
+ pl, const = opName["operands"]
+
+ assert pl == 3
+ assert const == 0
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 3
+
values_in_shape = testGen.makeShape(rank)
- K = values_in_shape[1]
+ values_in_shape = testGen.constrictBatchSize(values_in_shape)
- # ignore max batch size if target shape is set
- if testGen.args.max_batch_size and not testGen.args.target_shapes:
- values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)
+ N = values_in_shape[0]
+ K = values_in_shape[1]
+ C = values_in_shape[2]
# Make sure W is not greater than K, as we can only write each output index
# once (having a W greater than K means that you have to repeat a K index)
@@ -225,11 +244,12 @@ class TosaTensorGen:
W_max = min(testGen.args.tensor_shape_range[1], K)
W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min
- input_shape = [values_in_shape[0], W, values_in_shape[2]]
+ input_shape = [N, W, C]
shape_list = []
- shape_list.append(values_in_shape.copy())
- shape_list.append(input_shape.copy())
+ shape_list.append(values_in_shape)
+ shape_list.append([N, W]) # indices
+ shape_list.append(input_shape)
return shape_list
@@ -695,6 +715,13 @@ class TosaTensorValuesGen:
"round" in argsDict["data_range_list"][idx]
and argsDict["data_range_list"][idx]["round"] is True
)
+ if data_range is not None and dtype not in (
+ DType.FP16,
+ DType.FP32,
+ DType.BF16,
+ ):
+ # Change from inclusive to exclusive range
+ data_range = (data_range[0], data_range[1] + 1)
# Ignore lazy data gen option and create data array using any range limits
arr = testGen.getRandTensor(shape, dtype, data_range)
if roundMode:
@@ -732,13 +759,15 @@ class TosaTensorValuesGen:
# TODO - generate seed for this generator based on test
info["rng_seed"] = 42
+ data_range = None
if "data_range_list" in argsDict:
data_range = argsDict["data_range_list"][idx]["range"]
if "round" in argsDict["data_range_list"][idx]:
info["round"] = argsDict["data_range_list"][idx]["round"]
elif "data_range" in argsDict:
data_range = argsDict["data_range"]
- else:
+
+ if data_range is None:
data_range = testGen.getDTypeRange(
dtypeList[idx], high_inclusive=True
)
@@ -1455,6 +1484,109 @@ class TosaTensorValuesGen:
testGen, opName, dtypeList, shapeList, argsDict, error_name
)
+ @staticmethod
+ def tvgGather(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+ K = shapeList[0][1]
+
+ # Fix the type of the indices tensor
+ dtypeList[1] = DType.INT32
+
+ dtype = dtypeList[0]
+ if not gtu.dtypeIsSupportedByCompliance(dtype):
+ # Test unsupported by data generator
+ op = testGen.TOSA_OP_LIST[opName]
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.GATHER must have 2 placeholders, 0 consts"
+
+ tens_ser_list = []
+ for idx, shape in enumerate(shapeList):
+ dtype = dtypeList[idx]
+ if idx != 1:
+ arr = testGen.getRandTensor(shape, dtype)
+ tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
+ else:
+ # Limit data range of indices tensor upto K (exclusive)
+ arr = testGen.getRandTensor(shape, dtype, (0, K))
+ # To match old functionality - create indices as CONST
+ tens_ser_list.append(testGen.ser.addConst(shape, dtype, arr))
+
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
+
+ else:
+ # ERROR_IF or floating point test
+ # Use inclusive values upto index K for indices tensor
+ data_range_list = (
+ {"range": None},
+ {"range": (0, K - 1)},
+ )
+ argsDict["data_range_list"] = data_range_list
+
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
+
+ @staticmethod
+ def tvgScatter(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+ K = shapeList[0][1]
+ W = shapeList[2][1]
+
+ # Work out an indices tensor here with data that doesn't exceed the
+ # dimension K of the values_in tensor and does NOT repeat the same K
+ # location as needed by the spec:
+ # "It is not permitted to repeat the same output index within a single
+ # SCATTER operation and so each output index occurs at most once."
+ assert K >= W, "Op.SCATTER W must be smaller or equal to K"
+
+ # Fix the type of the indices tensor
+ dtypeList[1] = DType.INT32
+
+ dtype = dtypeList[0]
+ if not gtu.dtypeIsSupportedByCompliance(dtype):
+ # Test unsupported by data generator
+ op = testGen.TOSA_OP_LIST[opName]
+ pCount, cCount = op["operands"]
+ assert (
+ pCount == 3 and cCount == 0
+ ), "Op.SCATTER must have 3 placeholders, 0 consts"
+
+ tens_ser_list = []
+ for idx, shape in enumerate(shapeList):
+ dtype = dtypeList[idx]
+ if idx != 1:
+ arr = testGen.getRandTensor(shape, dtype)
+ tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
+ else:
+ # Create the indices array
+ assert dtype == DType.INT32, "Op.SCATTER unexpected indices type"
+ arr = []
+ for n in range(shape[0]):
+ # Get a shuffled list of output indices (0 to K-1) and
+ # limit length to W
+ arr.append(testGen.rng.permutation(K)[:W])
+ indices_arr = np.array(arr, dtype=np.int32) # (N, W)
+ # To match old functionality - create indices as CONST
+ tens_ser_list.append(
+ testGen.ser.addConst(shape, dtype, indices_arr)
+ )
+
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
+
+ else:
+ # ERROR_IF or floating point test
+ # Use inclusive values upto index K for indices tensor
+ data_range_list = (
+ {"range": None},
+ {"range": (0, K - 1)},
+ {"range": None},
+ )
+ argsDict["data_range_list"] = data_range_list
+
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
+
class TosaArgGen:
"""Argument generators create exhaustive or random lists of attributes for