aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-11-01 11:29:56 +0000
committerEric Kunze <eric.kunze@arm.com>2023-11-10 16:35:17 +0000
commitbfc53031803338d9f0866f88f1d2deffd4928bcc (patch)
treeb43dbe24df3445639344662e76d1dc37341c2a4f /verif/generator/tosa_test_gen.py
parent7bf0cb990b55d5738c8dc4291686576654d2d8ab (diff)
downloadreference_model-bfc53031803338d9f0866f88f1d2deffd4928bcc.tar.gz
Main Compliance testing support for ARGMAX, REDUCE_SUM/MAX/MIN
Add extra tests for FP32 REDUCE_SUM that meet MIN_DOT_PRODUCTS. Plus improved dot product test generation skip information. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ia8198a9500ddddfc86c5bb84230b9a4edf5ffd50
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py156
1 files changed, 99 insertions, 57 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 556a0d8..3180cf5 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -658,12 +658,17 @@ class TosaTestGen:
)
return result_tens
- def build_argmax(self, op, a, axis, validator_fcns, error_name):
- result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
+ def build_argmax(
+ self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None
+ ):
+ assert len(inputs) == 1
+ a = inputs[0]
+ axis = args_dict["axis"]
+ result_tensor = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -678,9 +683,9 @@ class TosaTestGen:
axis=axis,
input_shape=a.shape,
input_dtype=a.dtype,
- output_shape=result_tens.shape,
- output_dtype=result_tens.dtype,
- result_tensors=[result_tens],
+ output_shape=result_tensor.shape,
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -691,7 +696,11 @@ class TosaTestGen:
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, inputs[0].dtype, args_dict, result_tensor, error_name
+ )
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_pool2d(
self,
@@ -1173,12 +1182,17 @@ class TosaTestGen:
return TosaTestGen.BuildInfo(result_tensor, compliance)
- def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
- result_tens = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
+ def build_reduce(
+ self, op, inputs, args_dict, validator_fcns, error_name=None, qinfo=None
+ ):
+ assert len(inputs) == 1
+ a = inputs[0]
+ axis = args_dict["axis"]
+ result_tensor = OutputShaper.reduceOp(self.ser, self.rng, a, axis, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -1192,10 +1206,10 @@ class TosaTestGen:
op=op,
axis=axis,
input_shape=a.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
input_dtype=a.dtype,
- output_dtype=result_tens.dtype,
- result_tensors=[result_tens],
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1206,7 +1220,16 @@ class TosaTestGen:
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ if op["op"] == Op.REDUCE_PRODUCT:
+ # TODO: Add compliance support!
+ compliance = None
+ else:
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_clamp(self, op, a, validator_fcns=None, error_name=None):
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
@@ -1373,25 +1396,24 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list)
return result_tens
- def build_concat(self, op, *a, validator_fcns=None, error_name=None):
+ def build_concat(
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ axis = args_dict["axis"]
if error_name != ErrorIf.WrongInputType:
- assert type(a[-1]) == int
+ assert type(axis) == int
- # To store variable length list of input tensors we need to store axis along with it
- axis = a[-1]
- a = a[:-1]
-
- result_tens = OutputShaper.concatOp(
- self.ser, self.rng, axis, *a, error_name=error_name
+ result_tensor = OutputShaper.concatOp(
+ self.ser, self.rng, axis, inputs, error_name=error_name
)
input_tensor_names = []
- for tensor in a:
+ for tensor in inputs:
input_tensor_names.append(tensor.name)
# Invalidate Input/Output list for error if checks.
input_list = input_tensor_names
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -1404,12 +1426,12 @@ class TosaTestGen:
error_name,
op=op,
axis=axis,
- input_shape=a[0].shape,
- output_shape=result_tens.shape,
- input_dtype=a[0].dtype,
- output_dtype=result_tens.dtype,
- inputs=a,
- result_tensors=[result_tens],
+ input_shape=inputs[0].shape,
+ output_shape=result_tensor.shape,
+ input_dtype=inputs[0].dtype,
+ output_dtype=result_tensor.dtype,
+ inputs=inputs,
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1420,7 +1442,7 @@ class TosaTestGen:
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+ return TosaTestGen.BuildInfo(result_tensor, None)
def build_pad(
self,
@@ -1483,17 +1505,20 @@ class TosaTestGen:
def build_dim(
self,
op,
- a,
- axis,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
- result_tens = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)
+ assert len(inputs) == 1
+ a = inputs[0]
+ axis = args_dict["axis"]
+ result_tensor = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -1508,9 +1533,9 @@ class TosaTestGen:
axis=axis,
input_shape=a.shape,
input_dtype=a.dtype,
- output_shape=result_tens.shape,
- output_dtype=result_tens.dtype,
- result_tensors=[result_tens],
+ output_shape=result_tensor.shape,
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1521,7 +1546,7 @@ class TosaTestGen:
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+ return TosaTestGen.BuildInfo(result_tensor, None)
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
result_tens = OutputShaper.reshapeOp(
@@ -1559,12 +1584,17 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
- def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
+ def build_reverse(
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ assert len(inputs) == 1
+ a = inputs[0]
+ axis = args_dict["axis"]
+ result_tensor = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -1578,10 +1608,10 @@ class TosaTestGen:
op=op,
axis=axis,
input_shape=a.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
input_dtype=a.dtype,
- output_dtype=result_tens.dtype,
- result_tensors=[result_tens],
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1592,7 +1622,7 @@ class TosaTestGen:
attr.AxisAttribute(axis)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+ return TosaTestGen.BuildInfo(result_tensor, None)
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
result_tens = OutputShaper.transposeOp(self.ser, self.rng, a, perms, error_name)
@@ -2898,7 +2928,7 @@ class TosaTestGen:
"build_fcn": (
build_argmax,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_NARROW_INT_FP,
@@ -2913,6 +2943,9 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
},
"avg_pool2d": {
"op": Op.AVG_POOL2D,
@@ -3853,7 +3886,7 @@ class TosaTestGen:
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_BOOL,
@@ -3875,7 +3908,7 @@ class TosaTestGen:
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_BOOL,
@@ -3897,7 +3930,7 @@ class TosaTestGen:
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_INT_FP,
@@ -3911,6 +3944,9 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
},
"reduce_min": {
"op": Op.REDUCE_MIN,
@@ -3919,7 +3955,7 @@ class TosaTestGen:
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_INT_FP,
@@ -3933,6 +3969,9 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
},
"reduce_product": {
"op": Op.REDUCE_PRODUCT,
@@ -3941,7 +3980,7 @@ class TosaTestGen:
"build_fcn": (
build_reduce,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_FP,
@@ -3977,6 +4016,9 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
},
# Data layout operators
"concat": {
@@ -4030,7 +4072,7 @@ class TosaTestGen:
"build_fcn": (
build_dim,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_FIB,
@@ -4069,7 +4111,7 @@ class TosaTestGen:
"build_fcn": (
build_reverse,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
"types": TYPE_FIB,
@@ -4892,9 +4934,9 @@ class OutputShaper:
return ser.addOutput(output_shape, out_dtype)
@staticmethod
- def concatOp(ser, rng, axis, *a, error_name=None):
- input1 = a[0]
- remaining_inputs = a[1:]
+ def concatOp(ser, rng, axis, inputs, error_name=None):
+ input1 = inputs[0]
+ remaining_inputs = inputs[1:]
# calculate the output shape, if possible, otherwise just use the first input shape
output_shape = input1.shape.copy()