aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-07-27 09:12:49 +0100
committerEric Kunze <eric.kunze@arm.com>2021-08-19 15:01:20 +0000
commit818ab900ed8e64f43aeebff9924ad883fc349e64 (patch)
treee798900fd1ad317396134668b40a73d4f7c57bdf
parenta9017401461224b9bc81e7b1c770ca6091e0e3fb (diff)
downloadreference_model-818ab900ed8e64f43aeebff9924ad883fc349e64.tar.gz
Produce Concat tests with multiple input tensors
* Concat tests now contain between 2 and 5 input tensors concatenated together * Both input and const tensors are used as inputs to the operator * Option to add in const tensor inputs (this is slow), defaults to original behaviour Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Change-Id: I2a0cc622d31aceab8d24521668d0aae040ba73b1
-rw-r--r--verif/tosa_test_gen.py107
-rwxr-xr-xverif/tosa_verif_build_tests.py9
2 files changed, 98 insertions, 18 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 5c25f8e..5138e3f 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -330,6 +330,46 @@ class TosaTensorGen:
return [a_shape, b_shape]
+ @staticmethod
+ def tgConcat(testGen, opName, rank):
+ pl, const = opName["operands"]
+ shape = testGen.makeShape(rank)
+
+ # Create extra tensors to concat.
+ # Take into account value of pl when getting maximum number of concats
+ num_tensors = testGen.randInt(0, 4)
+ shape_list = []
+ for i in range(pl + const + num_tensors):
+ shape_list.append(shape.copy())
+
+ return shape_list
+
+ @staticmethod
+ def tgConcatConstInput(testGen, shapeList, axis):
+ # Split concat shape along axis to allow for multiple const inputs
+ # without making too many large tensors
+ shape = shapeList[0]
+ if len(shapeList) == 2 or shape[axis] < len(shapeList):
+ return shapeList
+
+ new_shapeList = [shape.copy()]
+ length_on_axis = shape[axis]
+ remaining_length = length_on_axis
+ for i in range(len(shapeList)-2):
+ # Calculate split on axis and remaining value
+ split_shape_val = int(shape[axis] / 2)
+ remaining_length = remaining_length - split_shape_val
+
+ # Append new shape, and set remaining shape
+ shape[axis] = split_shape_val
+ new_shapeList.append(shape.copy())
+ shape[axis] = remaining_length
+ if i == len(shapeList) - 3:
+ new_shapeList.append(shape.copy())
+
+ return new_shapeList
+
+
class TosaArgGen:
"""Argument generators create exhaustive or random lists of attributes for operators that take
@@ -1263,13 +1303,23 @@ class TosaTestGen:
self.ser.addOperator(op, [a.name], [result_tens.name])
return result_tens
- def build_concat(self, op, a, b, axis):
- result_tens = OutputShaper.concatOp(self.ser, a, b, axis)
+ def build_concat(self, op, *a):
+ assert (type(a[-1]) == int)
+
+ # To store variable length list of input tensors we need to store axis along with it
+ axis = a[-1]
+ a = a[:-1]
+
+ result_tens = OutputShaper.concatOp(self.ser, axis, *a)
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
- self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
+ input_tensor_names = []
+ for tensor in a:
+ input_tensor_names.append(tensor.name)
+
+ self.ser.addOperator(op, input_tensor_names, [result_tens.name], attr)
def build_pad(self, op, a, padding, qinfo):
result_tens = OutputShaper.padOp(self.ser, a, padding)
@@ -1708,19 +1758,22 @@ class TosaTestGen:
if isinstance(dtype_or_dtypeList, list):
dtypeList = dtype_or_dtypeList
+ elif op['op'] == Op.CONCAT:
+ dtypeList = [dtype_or_dtypeList] * len(shapeList)
else:
dtypeList = [dtype_or_dtypeList] * (num_operands)
- assert (
- len(shapeList) == num_operands
- ), "shapeList length {} must match number of operands {}".format(
- len(shapeList), num_operands
- )
- assert (
- len(dtypeList) == num_operands
- ), "dtypeList length {} must match number of operands {}".format(
- len(dtypeList), num_operands
- )
+ if op['op'] != Op.CONCAT:
+ assert (
+ len(shapeList) == num_operands
+ ), "shapeList length {} must match number of operands {}".format(
+ len(shapeList), num_operands
+ )
+ assert (
+ len(dtypeList) == num_operands
+ ), "dtypeList length {} must match number of operands {}".format(
+ len(dtypeList), num_operands
+ )
try:
qgen = op["qgen"]
@@ -1850,6 +1903,18 @@ class TosaTestGen:
)
tens.extend(placeholders)
+ elif op["op"] == Op.CONCAT:
+ count = len(shapeList) - self.args.num_const_inputs_concat
+ if count < 1:
+ count = 1
+ if self.args.num_const_inputs_concat == 0:
+ count = len(shapeList)
+
+ shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
+ tens.extend(
+ self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
+ )
+ tens.extend(self.buildConstTensors(shapeList[count:], dtypeList[count:]))
else:
tens.extend(
self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
@@ -2336,7 +2401,7 @@ class TosaTestGen:
"concat": {
"op": Op.CONCAT,
"operands": (2, 0),
- "build_fcn": (build_concat, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
+ "build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
"types": TYPE_FIB,
},
"pad": {
@@ -2694,12 +2759,18 @@ class OutputShaper:
return ser.addOutput(output_shape, out_dtype)
@staticmethod
- def concatOp(ser, a, b, axis):
+ def concatOp(ser, axis, *a):
+ input1 = a[0]
+ remaining_inputs = a[1:]
- output_shape = a.shape.copy()
- output_shape[axis] = a.shape[axis] + b.shape[axis]
+ output_shape = input1.shape.copy()
- return ser.addOutput(output_shape, a.dtype)
+ output_shape[axis] = input1.shape[axis]
+
+ for tensor in remaining_inputs:
+ output_shape[axis] += tensor.shape[axis]
+
+ return ser.addOutput(output_shape, input1.dtype)
@staticmethod
def padOp(ser, a, padding):
diff --git a/verif/tosa_verif_build_tests.py b/verif/tosa_verif_build_tests.py
index 15482e6..343d8d4 100755
--- a/verif/tosa_verif_build_tests.py
+++ b/verif/tosa_verif_build_tests.py
@@ -192,6 +192,15 @@ def parseArgs():
help="Create test with a particular DType (may be repeated)",
)
+ parser.add_argument(
+ "--num-const-inputs-concat",
+ dest="num_const_inputs_concat",
+ default=0,
+ choices=[0, 1, 2, 3],
+ type=int,
+ help="Allow constant input tensors for concat operator",
+ )
+
args = parser.parse_args()
return args