aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py103
1 files changed, 55 insertions, 48 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 262a652..b0e7c8c 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -216,20 +216,19 @@ class TosaTestGen:
# build_placeholder returns an int, ABS/other ops does not
if isinstance(op, int):
- self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
+ self.ser.addOperator(op, a.name, result_tens.name, None)
return result_tens
elif op["op"] == Op.IDENTITY:
- self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
+ self.ser.addOperator(op["op"], a.name, result_tens.name, None)
return result_tens
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongOutputType:
if result_tens.dtype not in [DType.INT8, DType.UINT8]:
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, a.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, a.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
@@ -255,7 +254,12 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = None
+ if op["op"] == Op.NEGATE:
+ attr = ts.TosaSerializerAttribute()
+ attr.NegateAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
@@ -542,11 +546,10 @@ class TosaTestGen:
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType:
if input.dtype not in [DType.INT8, DType.UINT8]:
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, input.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, input.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error if checks.
input_list = [input.name]
@@ -577,10 +580,13 @@ class TosaTestGen:
):
return None
+ if qinfo is None:
+ qinfo = [0, 0]
+
attr = ts.TosaSerializerAttribute()
- attr.PoolAttribute(kernel, stride, pad)
+ attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_conv2d(
@@ -606,11 +612,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -642,9 +647,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_conv3d(
@@ -670,11 +675,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -706,9 +710,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_transpose_conv2d(
@@ -734,11 +738,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -769,9 +772,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.TransposeConvAttribute(out_pad, stride, output_shape)
+ attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_depthwise_conv2d(
@@ -796,11 +799,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -832,9 +834,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_fully_connected(
@@ -871,7 +873,10 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = ts.TosaSerializerAttribute()
+ attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
@@ -905,7 +910,10 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = ts.TosaSerializerAttribute()
+ attr.MatMulAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
@@ -1164,7 +1172,7 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
@@ -2212,7 +2220,7 @@ class TosaTestGen:
else:
qinfo = None
- tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name)
+ tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
try:
if error_if_validators is None:
@@ -3425,7 +3433,6 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agPad,
),
- "qgen": TosaQuantGen.qgPad,
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,