aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2022-06-07 05:20:44 +0000
committerEric Kunze <eric.kunze@arm.com>2022-06-15 11:38:04 -0700
commitb5fabec33abeca2d92c20c7b094fa3f113d0ddd8 (patch)
tree9c7d946012c7a70a7fcb237daa4376d7b65c6f76 /verif
parent24594f55ee3bf0e95c764e51b94c3ec7f9cfa54a (diff)
downloadreference_model-b5fabec33abeca2d92c20c7b094fa3f113d0ddd8.tar.gz
Remove quantization info from serialization attributes
Any needed information moves into the attributes for each operator. New serialization library version removes teh quantization information attributes from the schema Signed-off-by: Eric Kunze <eric.kunze@arm.com> Change-Id: Icf6165687ab1fd34a01f64c01b0b92b2820e72fa
Diffstat (limited to 'verif')
-rw-r--r--verif/generator/tosa_arg_gen.py119
-rw-r--r--verif/generator/tosa_error_if.py7
-rw-r--r--verif/generator/tosa_test_gen.py103
3 files changed, 109 insertions, 120 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a27d849..8e00fab 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -4,7 +4,6 @@ import itertools
import math
import numpy as np
-import serializer.tosa_serializer as ts
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
@@ -26,7 +25,7 @@ class TosaQuantGen:
pass
@staticmethod
- def getQinfo(testGen, dtype, error_name=None):
+ def getZeroPoint(testGen, dtype, error_name=None):
if dtype == DType.INT8:
return testGen.randInt(-128, 128)
@@ -45,27 +44,25 @@ class TosaQuantGen:
@staticmethod
def qgUnary(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
elif error_name == ErrorIf.OutputZeroPointNotZero:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ ]
else:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
return qinfo
@staticmethod
def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if isinstance(dtype_or_dtypeList, list):
# a list of [input, weights, accumulator] dtypes
dtypeList = dtype_or_dtypeList
@@ -74,40 +71,34 @@ class TosaQuantGen:
dtypeList = [dtype_or_dtypeList] * 3
if error_name == ErrorIf.InputZeroPointNotZero:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
+ ]
elif error_name == ErrorIf.WeightZeroPointNotZero:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
+ ]
else:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
-
- qinfo.ConvQuantInfo(input_zp, weights_zp)
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
+ ]
return qinfo
@staticmethod
def qgMatmul(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ ]
else:
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
- return qinfo
-
- @staticmethod
- def qgPad(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
- if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
- else:
- qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
return qinfo
@staticmethod
@@ -550,7 +541,7 @@ class TosaTensorValuesGen:
pass
@staticmethod
- def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
tens = []
@@ -562,7 +553,7 @@ class TosaTensorValuesGen:
return tens
@staticmethod
- def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -582,11 +573,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
# Make sure the operation does not cause value saturation - where
# the number wraps due to limited number of bits to store the answer
@@ -651,12 +642,12 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgCondIfWhileLoop(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ testGen, op, dtypeList, shapeList, testArgs, error_name=None
):
if dtypeList[0] in (
DType.INT32,
@@ -689,12 +680,12 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgArithmeticRightShift(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ testGen, op, dtypeList, shapeList, testArgs, error_name=None
):
pCount, cCount = op["operands"]
# Force value of operand[1] to be within [0, num_bits]
@@ -722,16 +713,16 @@ class TosaTensorValuesGen:
return placeholders
@staticmethod
- def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
# Set datatype of condition tensor to boolean
dtypeList[0] = DType.BOOL
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -765,11 +756,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgMul(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -839,11 +830,11 @@ class TosaTensorValuesGen:
return tens
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
count = len(shapeList) - testGen.args.num_const_inputs_concat
if count < 1:
count = 1
@@ -866,9 +857,7 @@ class TosaTensorValuesGen:
return tens
@staticmethod
- def tvgLogicalShift(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
- ):
+ def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
@@ -886,7 +875,7 @@ class TosaTensorValuesGen:
return placeholders
@staticmethod
- def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -924,13 +913,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgReduceSum(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
- ):
+ def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32:
pCount, cCount = op["operands"]
assert (
@@ -949,7 +936,7 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 1967d8a..b331a42 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1003,12 +1003,7 @@ class TosaErrorValidator:
Generally input_zp is index 0, output_zp is index 1
"""
- if isinstance(qinfo, tuple):
- zero_point = qinfo[index]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- zero_point = qinfo.ints[index][1]
- return zero_point
+ return qinfo[index]
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 262a652..b0e7c8c 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -216,20 +216,19 @@ class TosaTestGen:
# build_placeholder returns an int, ABS/other ops does not
if isinstance(op, int):
- self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
+ self.ser.addOperator(op, a.name, result_tens.name, None)
return result_tens
elif op["op"] == Op.IDENTITY:
- self.ser.addOperator(op["op"], a.name, result_tens.name, None, qinfo)
+ self.ser.addOperator(op["op"], a.name, result_tens.name, None)
return result_tens
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongOutputType:
if result_tens.dtype not in [DType.INT8, DType.UINT8]:
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, a.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, a.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
@@ -255,7 +254,12 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = None
+ if op["op"] == Op.NEGATE:
+ attr = ts.TosaSerializerAttribute()
+ attr.NegateAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
@@ -542,11 +546,10 @@ class TosaTestGen:
# Ensure new output type has correct qinfo
if error_name == ErrorIf.WrongInputType:
if input.dtype not in [DType.INT8, DType.UINT8]:
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, input.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, input.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error if checks.
input_list = [input.name]
@@ -577,10 +580,13 @@ class TosaTestGen:
):
return None
+ if qinfo is None:
+ qinfo = [0, 0]
+
attr = ts.TosaSerializerAttribute()
- attr.PoolAttribute(kernel, stride, pad)
+ attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_conv2d(
@@ -606,11 +612,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -642,9 +647,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_conv3d(
@@ -670,11 +675,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -706,9 +710,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_transpose_conv2d(
@@ -734,11 +738,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -769,9 +772,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.TransposeConvAttribute(out_pad, stride, output_shape)
+ attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_depthwise_conv2d(
@@ -796,11 +799,10 @@ class TosaTestGen:
DType.INT8,
DType.UINT8,
):
- qinfo = ts.TosaSerializerQuantInfo()
- qinfo.ConvQuantInfo(
- TosaQuantGen.getQinfo(self, ifm.dtype),
- TosaQuantGen.getQinfo(self, result_tens.dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(self, ifm.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ ]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
@@ -832,9 +834,9 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_fully_connected(
@@ -871,7 +873,10 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = ts.TosaSerializerAttribute()
+ attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
@@ -905,7 +910,10 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, None, qinfo)
+ attr = ts.TosaSerializerAttribute()
+ attr.MatMulAttribute(qinfo[0], qinfo[1])
+
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
@@ -1164,7 +1172,7 @@ class TosaTestGen:
):
return None
- self.ser.addOperator(op["op"], input_list, output_list, attr, qinfo)
+ self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None):
@@ -2212,7 +2220,7 @@ class TosaTestGen:
else:
qinfo = None
- tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, qinfo, error_name)
+ tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
try:
if error_if_validators is None:
@@ -3425,7 +3433,6 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agPad,
),
- "qgen": TosaQuantGen.qgPad,
"types": TYPE_FIB,
"error_if_validators": (
TosaErrorValidator.evWrongInputType,