diff options
author | Les Bell <les.bell@arm.com> | 2021-07-23 09:43:31 +0100 |
---|---|---|
committer | Dominic Symes <dominic.symes@arm.com> | 2021-07-29 16:38:48 +0000 |
commit | 30e4680b789c728d503cf76ee316468969943d51 (patch) | |
tree | 8980c761beb6a991a229b7896d5ce16a51fde543 /verif | |
parent | 18e2666b159eb75be9cafe7531d1512b6b24abc8 (diff) | |
download | reference_model-30e4680b789c728d503cf76ee316468969943d51.tar.gz |
fix quantization zero-point generation
* for int8 and uint8 only
* not for int16
* pass all data types to the quantizaion generator functions as some
operators (all convolutions, fully_connected) may use different types
for the input and weight tensors
* also allows input type filtering for convolution operators
Change-Id: Iea3d00f03807a8db35b40d4e8988929ec6549b44
Signed-off-by: Les Bell <les.bell@arm.com>
Diffstat (limited to 'verif')
-rw-r--r-- | verif/tosa_test_gen.py | 44 |
1 files changed, 22 insertions, 22 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 2ccbb0a..c458538 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -55,45 +55,45 @@ class TosaQuantGen: pass @staticmethod - def needsQinfo(op, dtype): - if dtype == DType.INT8 or dtype == DType.INT16: - return True - return False + def getQinfo(testGen, dtype): + if dtype == DType.INT8: + return testGen.randInt(-128, 128) + if dtype == DType.UINT8: + return testGen.randInt(0, 256) + return 0 @staticmethod def qgUnary(testGen, op, dtype): qinfo = ts.TosaSerializerQuantInfo() - if TosaQuantGen.needsQinfo(op, dtype): - qinfo.UnaryQuantInfo(testGen.randInt(), testGen.randInt()) - else: - qinfo.UnaryQuantInfo(0, 0) + qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype), + TosaQuantGen.getQinfo(testGen, dtype)) return qinfo @staticmethod - def qgConv(testGen, op, dtype): + def qgConv(testGen, op, dtype_or_dtypeList): qinfo = ts.TosaSerializerQuantInfo() - if TosaQuantGen.needsQinfo(op, dtype): - qinfo.ConvQuantInfo(testGen.randInt(), testGen.randInt()) + if isinstance(dtype_or_dtypeList, list): + # a list of [input, weights, accumulator] dtypes + dtypeList = dtype_or_dtypeList else: - qinfo.ConvQuantInfo(0, 0) + # an int, [input, weights, accumulator] dtypes are the same + dtypeList = [dtype_or_dtypeList] * 3 + input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0]) + weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1]) + qinfo.ConvQuantInfo(input_zp, weights_zp) return qinfo @staticmethod def qgMatmul(testGen, op, dtype): qinfo = ts.TosaSerializerQuantInfo() - if TosaQuantGen.needsQinfo(op, dtype): - qinfo.MatMulQuantInfo(testGen.randInt(), testGen.randInt()) - else: - qinfo.MatMulQuantInfo(0, 0) + qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype), + TosaQuantGen.getQinfo(testGen, dtype)) return qinfo @staticmethod def qgPad(testGen, op, dtype): qinfo = ts.TosaSerializerQuantInfo() - if TosaQuantGen.needsQinfo(op, dtype): - qinfo.PadQuantInfo(testGen.randInt()) - else: - qinfo.PadQuantInfo(0) + qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype)) return qinfo @staticmethod @@ -1636,7 +1636,7 @@ class TosaTestGen: # Filter tests based on dtype? if dtypeFilter is not None: - if t not in dtypeFilter: + if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)): continue # Create the placeholder and const tensors @@ -1836,7 +1836,7 @@ class TosaTestGen: tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:])) if qgen is not None: - qinfo = qgen(self, op, dtypeList[0]) + qinfo = qgen(self, op, dtype_or_dtypeList) else: qinfo = None |