aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py119
1 files changed, 53 insertions, 66 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a27d849..8e00fab 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -4,7 +4,6 @@ import itertools
import math
import numpy as np
-import serializer.tosa_serializer as ts
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from serializer.tosa_serializer import DTypeNames
@@ -26,7 +25,7 @@ class TosaQuantGen:
pass
@staticmethod
- def getQinfo(testGen, dtype, error_name=None):
+ def getZeroPoint(testGen, dtype, error_name=None):
if dtype == DType.INT8:
return testGen.randInt(-128, 128)
@@ -45,27 +44,25 @@ class TosaQuantGen:
@staticmethod
def qgUnary(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
elif error_name == ErrorIf.OutputZeroPointNotZero:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ ]
else:
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
return qinfo
@staticmethod
def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if isinstance(dtype_or_dtypeList, list):
# a list of [input, weights, accumulator] dtypes
dtypeList = dtype_or_dtypeList
@@ -74,40 +71,34 @@ class TosaQuantGen:
dtypeList = [dtype_or_dtypeList] * 3
if error_name == ErrorIf.InputZeroPointNotZero:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0], error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
+ ]
elif error_name == ErrorIf.WeightZeroPointNotZero:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1], error_name),
+ ]
else:
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
-
- qinfo.ConvQuantInfo(input_zp, weights_zp)
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[0]),
+ TosaQuantGen.getZeroPoint(testGen, dtypeList[1]),
+ ]
return qinfo
@staticmethod
def qgMatmul(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- TosaQuantGen.getQinfo(testGen, dtype, error_name),
- )
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ TosaQuantGen.getZeroPoint(testGen, dtype, error_name),
+ ]
else:
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype),
- )
- return qinfo
-
- @staticmethod
- def qgPad(testGen, op, dtype, error_name=None):
- qinfo = ts.TosaSerializerQuantInfo()
- if error_name == ErrorIf.InputZeroPointNotZero:
- qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype, error_name))
- else:
- qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
+ qinfo = [
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ TosaQuantGen.getZeroPoint(testGen, dtype),
+ ]
return qinfo
@staticmethod
@@ -550,7 +541,7 @@ class TosaTensorValuesGen:
pass
@staticmethod
- def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
tens = []
@@ -562,7 +553,7 @@ class TosaTensorValuesGen:
return tens
@staticmethod
- def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -582,11 +573,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgAddSub(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
# Make sure the operation does not cause value saturation - where
# the number wraps due to limited number of bits to store the answer
@@ -651,12 +642,12 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgCondIfWhileLoop(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ testGen, op, dtypeList, shapeList, testArgs, error_name=None
):
if dtypeList[0] in (
DType.INT32,
@@ -689,12 +680,12 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
def tvgArithmeticRightShift(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
+ testGen, op, dtypeList, shapeList, testArgs, error_name=None
):
pCount, cCount = op["operands"]
# Force value of operand[1] to be within [0, num_bits]
@@ -722,16 +713,16 @@ class TosaTensorValuesGen:
return placeholders
@staticmethod
- def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
# Set datatype of condition tensor to boolean
dtypeList[0] = DType.BOOL
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgIntDiv(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -765,11 +756,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgMul(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgMul(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -839,11 +830,11 @@ class TosaTensorValuesGen:
return tens
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgConcat(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
count = len(shapeList) - testGen.args.num_const_inputs_concat
if count < 1:
count = 1
@@ -866,9 +857,7 @@ class TosaTensorValuesGen:
return tens
@staticmethod
- def tvgLogicalShift(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
- ):
+ def tvgLogicalShift(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
@@ -886,7 +875,7 @@ class TosaTensorValuesGen:
return placeholders
@staticmethod
- def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None):
+ def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if error_name is None:
pCount, cCount = op["operands"]
assert (
@@ -924,13 +913,11 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)
@staticmethod
- def tvgReduceSum(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name=None
- ):
+ def tvgReduceSum(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32:
pCount, cCount = op["operands"]
assert (
@@ -949,7 +936,7 @@ class TosaTensorValuesGen:
return placeholders
else:
return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, qinfo, error_name
+ testGen, op, dtypeList, shapeList, testArgs, error_name
)