aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/tosa_test_gen.py44
1 files changed, 22 insertions, 22 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 2ccbb0a..c458538 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -55,45 +55,45 @@ class TosaQuantGen:
pass
@staticmethod
- def needsQinfo(op, dtype):
- if dtype == DType.INT8 or dtype == DType.INT16:
- return True
- return False
+ def getQinfo(testGen, dtype):
+ if dtype == DType.INT8:
+ return testGen.randInt(-128, 128)
+ if dtype == DType.UINT8:
+ return testGen.randInt(0, 256)
+ return 0
@staticmethod
def qgUnary(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- if TosaQuantGen.needsQinfo(op, dtype):
- qinfo.UnaryQuantInfo(testGen.randInt(), testGen.randInt())
- else:
- qinfo.UnaryQuantInfo(0, 0)
+ qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype))
return qinfo
@staticmethod
- def qgConv(testGen, op, dtype):
+ def qgConv(testGen, op, dtype_or_dtypeList):
qinfo = ts.TosaSerializerQuantInfo()
- if TosaQuantGen.needsQinfo(op, dtype):
- qinfo.ConvQuantInfo(testGen.randInt(), testGen.randInt())
+ if isinstance(dtype_or_dtypeList, list):
+ # a list of [input, weights, accumulator] dtypes
+ dtypeList = dtype_or_dtypeList
else:
- qinfo.ConvQuantInfo(0, 0)
+ # an int, [input, weights, accumulator] dtypes are the same
+ dtypeList = [dtype_or_dtypeList] * 3
+ input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
+ weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+ qinfo.ConvQuantInfo(input_zp, weights_zp)
return qinfo
@staticmethod
def qgMatmul(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- if TosaQuantGen.needsQinfo(op, dtype):
- qinfo.MatMulQuantInfo(testGen.randInt(), testGen.randInt())
- else:
- qinfo.MatMulQuantInfo(0, 0)
+ qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
+ TosaQuantGen.getQinfo(testGen, dtype))
return qinfo
@staticmethod
def qgPad(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- if TosaQuantGen.needsQinfo(op, dtype):
- qinfo.PadQuantInfo(testGen.randInt())
- else:
- qinfo.PadQuantInfo(0)
+ qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
return qinfo
@staticmethod
@@ -1636,7 +1636,7 @@ class TosaTestGen:
# Filter tests based on dtype?
if dtypeFilter is not None:
- if t not in dtypeFilter:
+ if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
continue
# Create the placeholder and const tensors
@@ -1836,7 +1836,7 @@ class TosaTestGen:
tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
if qgen is not None:
- qinfo = qgen(self, op, dtypeList[0])
+ qinfo = qgen(self, op, dtype_or_dtypeList)
else:
qinfo = None