aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLes Bell <les.bell@arm.com>2021-07-23 09:43:31 +0100
committerDominic Symes <dominic.symes@arm.com>2021-07-29 16:38:48 +0000
commit30e4680b789c728d503cf76ee316468969943d51 (patch)
tree8980c761beb6a991a229b7896d5ce16a51fde543
parent18e2666b159eb75be9cafe7531d1512b6b24abc8 (diff)
downloadreference_model-30e4680b789c728d503cf76ee316468969943d51.tar.gz
fix quantization zero-point generation
* for int8 and uint8 only * not for int16 * pass all data types to the quantizaion generator functions as some operators (all convolutions, fully_connected) may use different types for the input and weight tensors * also allows input type filtering for convolution operators Change-Id: Iea3d00f03807a8db35b40d4e8988929ec6549b44 Signed-off-by: Les Bell <les.bell@arm.com>
-rw-r--r--verif/tosa_test_gen.py44
1 files changed, 22 insertions, 22 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 2ccbb0a..c458538 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -55,45 +55,45 @@ class TosaQuantGen:
pass
@staticmethod
- def needsQinfo(op, dtype):
- if dtype == DType.INT8 or dtype == DType.INT16:
- return True
- return False
+ def getQinfo(testGen, dtype):
+ if dtype == DType.INT8:
+ return testGen.randInt(-128, 128)
+ if dtype == DType.UINT8:
+ return testGen.randInt(0, 256)
+ return 0
@staticmethod
def qgUnary(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- if TosaQuantGen.needsQinfo(op, dtype):
- qinfo.UnaryQuantInfo(testGen.randInt(), testGen.randInt())
- else:
- qinfo.UnaryQuantInfo(0, 0)
+ qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype))
return qinfo
@staticmethod
- def qgConv(testGen, op, dtype):
+ def qgConv(testGen, op, dtype_or_dtypeList):
qinfo = ts.TosaSerializerQuantInfo()
- if TosaQuantGen.needsQinfo(op, dtype):
- qinfo.ConvQuantInfo(testGen.randInt(), testGen.randInt())
+ if isinstance(dtype_or_dtypeList, list):
+ # a list of [input, weights, accumulator] dtypes
+ dtypeList = dtype_or_dtypeList
else:
- qinfo.ConvQuantInfo(0, 0)
+ # an int, [input, weights, accumulator] dtypes are the same
+ dtypeList = [dtype_or_dtypeList] * 3
+ input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
+ weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+ qinfo.ConvQuantInfo(input_zp, weights_zp)
return qinfo
@staticmethod
def qgMatmul(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- if TosaQuantGen.needsQinfo(op, dtype):
- qinfo.MatMulQuantInfo(testGen.randInt(), testGen.randInt())
- else:
- qinfo.MatMulQuantInfo(0, 0)
+ qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype))
return qinfo
@staticmethod
def qgPad(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- if TosaQuantGen.needsQinfo(op, dtype):
- qinfo.PadQuantInfo(testGen.randInt())
- else:
- qinfo.PadQuantInfo(0)
+ qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
return qinfo
@staticmethod
@@ -1636,7 +1636,7 @@ class TosaTestGen:
# Filter tests based on dtype?
if dtypeFilter is not None:
- if t not in dtypeFilter:
+ if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
continue
# Create the placeholder and const tensors
@@ -1836,7 +1836,7 @@ class TosaTestGen:
tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
if qgen is not None:
- qinfo = qgen(self, op, dtypeList[0])
+ qinfo = qgen(self, op, dtype_or_dtypeList)
else:
qinfo = None