Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--  verif/tosa_test_gen.py  64
1 file changed, 27 insertions(+), 37 deletions(-)
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 24d7b7b..57030e7 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
-# Copyright (c) 2020, ARM Limited.
+# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -50,7 +50,7 @@ class TosaQuantGen:
@staticmethod
def needsQinfo(op, dtype):
- if dtype == DType.AINT8 or dtype == DType.INT8:
+ if dtype == DType.INT8:
return True
return False
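
The AINT8 removal collapses the quantization-info check to a single dtype test. A minimal stand-alone sketch of the same predicate, using a stand-in DType enum rather than the serializer's real one:

from enum import Enum, auto

class DType(Enum):          # stand-in for the serializer's DType
    BOOL = auto(); INT4 = auto(); INT8 = auto(); UINT8 = auto()
    INT16 = auto(); INT32 = auto(); INT48 = auto(); FLOAT = auto()

def needsQinfo(op, dtype):
    # After AINT8 was folded into INT8, only INT8 tensors carry
    # zero-point quantization info.
    return dtype == DType.INT8

assert needsQinfo(None, DType.INT8)
assert not needsQinfo(None, DType.INT16)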
@@ -498,7 +498,7 @@ class TosaArgGen:
arg_list = []
# Enumerate the output types here
- for dtype in [ DType.AINT8, DType.INT16, DType.INT32 ]:
+ for dtype in [ DType.INT8, DType.INT16, DType.INT32 ]:
for scale32 in [ False, True ]:
for double_round in [ False, True ]:
for per_channel in [ False, True ]:
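
The nested loops above enumerate every combination of output dtype and rescale flags. A quick worked count of what agRescale now produces per input dtype (a sketch using only the loop structure shown):

from itertools import product

out_dtypes = ['INT8', 'INT16', 'INT32']   # output types enumerated above
flags = [[False, True]] * 3               # scale32, double_round, per_channel

combos = list(product(out_dtypes, *flags))
print(len(combos))                        # 3 * 2 * 2 * 2 = 24 combinations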
@@ -790,8 +790,6 @@ class TosaTestGen:
if dtype == DType.BOOL:
np_dt = np.bool
return np.bool_(self.rng.choice(a=[False, True], size=shape))
- elif dtype == DType.AINT8:
- return np.int32(self.rng.integers(low=0, high=256, size=shape))
elif dtype == DType.INT4:
return np.int32(self.rng.integers(low=-7, high=8, size=shape))
elif dtype == DType.INT8:
@@ -845,8 +843,6 @@ class TosaTestGen:
return self.rng.choice([False, True])
elif dtype == DType.INT4:
low, high = (-7, 8)
- elif dtype == DType.AINT8:
- low, high = (0, 256)
elif dtype == DType.INT8:
low, high = (-127, 128)
elif dtype == DType.INT16:
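
The (low, high) pairs feed numpy's Generator.integers, whose upper bound is exclusive, so INT8 values land in the symmetric range [-127, 127]. A small sketch demonstrating the two ranges visible in this hunk:

import numpy as np

rng = np.random.default_rng(0)

# (low, high) pairs from the hunk above; high is exclusive, matching
# np.random.Generator.integers, so INT8 draws land in [-127, 127].
RANGES = {'INT4': (-7, 8), 'INT8': (-127, 128)}

low, high = RANGES['INT8']
vals = rng.integers(low=low, high=high, size=1000)
assert vals.min() >= -127 and vals.max() <= 127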
@@ -874,12 +870,12 @@ class TosaTestGen:
def typeStr(self, t):
if t == DType.BOOL:
return 'b'
- elif t == DType.AINT8:
- return 'a8'
elif t == DType.INT4:
return 'i4'
elif t == DType.INT8:
return 'i8'
+ elif t == DType.UINT8:
+ return 'u8'
elif t == DType.INT16:
return 'i16'
elif t == DType.INT32:
@@ -893,14 +889,12 @@ class TosaTestGen:
def typeWidth(self, t):
''' Get the datatype width for integer types'''
- if t == DType.AINT8:
- return 8
- elif t == DType.UINT8:
- return 8
- elif t == DType.INT4:
+ if t == DType.INT4:
return 4
elif t == DType.INT8:
return 8
+ elif t == DType.UINT8:
+ return 8
elif t == DType.INT16:
return 16
elif t == DType.INT32:
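
Both typeStr() and typeWidth() are plain lookups; after this patch UINT8 sits between INT8 and INT16 in both. A table-driven sketch of the same mapping (the 'i32' string for INT32 is inferred, since its return is cut off above):

# (short name, bit width) per integer dtype, mirroring typeStr()/typeWidth()
TYPE_INFO = {
    'INT4':  ('i4', 4),
    'INT8':  ('i8', 8),
    'UINT8': ('u8', 8),
    'INT16': ('i16', 16),
    'INT32': ('i32', 32),
}

def type_str(t):
    return TYPE_INFO[t][0]

def type_width(t):
    return TYPE_INFO[t][1]

assert type_str('UINT8') == 'u8' and type_width('UINT8') == 8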
@@ -1030,7 +1024,7 @@ class TosaTestGen:
# Create bias here since the acc_t depends on (but isn't the same as) the input dtype
# The bias is OC
- if ifm.dtype == DType.AINT8 or ifm.dtype == DType.INT8:
+ if ifm.dtype == DType.INT8:
bias_type = DType.INT32
elif ifm.dtype == DType.INT16:
bias_type = DType.INT48
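
As the comment above notes, the bias accumulator type depends on, but differs from, the input dtype. A minimal sketch of the selection rule; the FLOAT branch is an assumption, since it falls outside this hunk:

def conv_bias_dtype(ifm_dtype):
    # Bias/accumulator rule from the hunk above; the FLOAT branch is an
    # assumption -- it is not visible in this diff.
    if ifm_dtype == 'INT8':
        return 'INT32'
    elif ifm_dtype == 'INT16':
        return 'INT48'
    return 'FLOAT'

assert conv_bias_dtype('INT8') == 'INT32'
assert conv_bias_dtype('INT16') == 'INT48'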
@@ -1267,13 +1261,13 @@ class TosaTestGen:
in_type_width = self.typeWidth(val.dtype)
out_type_width = self.typeWidth(out_dtype)
- if val.dtype == DType.AINT8:
+ if val.dtype == DType.INT8:
input_zp = self.randInt()
in_type_width = in_type_width + 1
else:
input_zp = 0
- if out_dtype == DType.AINT8:
+ if out_dtype == DType.INT8:
output_zp = self.randInt()
out_type_width = out_type_width + 1
else:
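
A nonzero zero point widens an INT8 operand's effective range to [-255, 255], which the generator models as one extra bit of type width. A worked sketch of the same bookkeeping:

def effective_widths(in_width, out_width, in_is_int8, out_is_int8):
    # An INT8 value minus a zero point spans up to [-255, 255], so a
    # nonzero zp costs one extra bit of effective range.
    return (in_width + (1 if in_is_int8 else 0),
            out_width + (1 if out_is_int8 else 0))

assert effective_widths(8, 32, True, False) == (9, 32)
assert effective_widths(16, 8, False, True) == (16, 9)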
@@ -1637,24 +1631,21 @@ class TosaTestGen:
# Tensor operator list
# 'op': op name
# 'operands': tuple of (placeholder, const) operands
- # 'rank': optional, restricts rank to tuple inclusive of (min, max), if not specified, defaults to (1, 4)
+ # 'rank': optional, restricts rank to tuple inclusive of (min, max),
+ # if not specified, defaults to (1, 4)
# 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
# 'types': array of datatypes to be tested
TYPE_FP = [ DType.FLOAT ]
- # Type with an aint8
- TYPE_INT = [ DType.AINT8, DType.INT16, DType.INT32 ] # Most operators support AINT8 instead of INT8, excludes INT4
- TYPE_INT_FP = [ DType.AINT8, DType.INT16, DType.INT32, DType.FLOAT ] # Most operators support AINT8 instead of INT8, excludes INT4
+ TYPE_INT = [ DType.INT8, DType.INT16, DType.INT32 ] # Excludes INT4
+ TYPE_INT_FP = [ DType.INT8, DType.INT16, DType.INT32, DType.FLOAT ] # Excludes INT4
- # Types with an int8
- TYPE_PURE_INT = [ DType.INT8, DType.INT16, DType.INT32 ] # Note: excludes INT4
- TYPE_PURE_INT_FP = [ DType.INT8, DType.INT16, DType.INT32, DType.FLOAT ] # Note: excludes INT4
TYPE_BOOL = [ DType.BOOL ]
TYPE_FI32 = [ DType.FLOAT, DType.INT32 ]
- TYPE_FIB = [ DType.FLOAT, DType.AINT8, DType.INT8, DType.INT16, DType.INT32, DType.BOOL ]
+ TYPE_FIB = [ DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL ]
TYPE_FI16 = [ DType.FLOAT, DType.INT16 ]
- TYPE_NARROW_INT_FP = [ DType.AINT8, DType.INT16, DType.FLOAT ]
+ TYPE_NARROW_INT_FP = [ DType.INT8, DType.INT16, DType.FLOAT ]
DEFAULT_RANK_RANGE = (1, 4)
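
Each TOSA_OP_LIST entry bundles the serializer op, operand counts, builder functions, and a dtype list like the ones defined above. A hypothetical sketch of how a driver might walk one entry (illustrative names only; the real loop is not in this diff):

# Hypothetical entry and consumer; TYPE_INT_FP is spelled out inline.
entry = {
    'op': 'MUL',
    'operands': (2, 0),               # (placeholder, const) tensor counts
    'types': ['INT8', 'INT16', 'INT32', 'FLOAT'],
}

for dtype in entry['types']:
    placeholders, consts = entry['operands']
    print(entry['op'], dtype, placeholders, consts)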
@@ -1670,7 +1661,7 @@ class TosaTestGen:
{ 'op': Op.ARITHMETIC_RIGHT_SHIFT,
'operands': (2, 0),
'build_fcn': (build_arithmetic_right_shift, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agArithmeticRightShift),
- 'types': TYPE_PURE_INT },
+ 'types': TYPE_INT },
'bitwise_and':
{ 'op': Op.BITWISE_AND,
@@ -1700,13 +1691,13 @@ class TosaTestGen:
{ 'op': Op.LOGICAL_LEFT_SHIFT,
'operands': (2, 0),
'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
- 'types': TYPE_PURE_INT },
+ 'types': TYPE_INT },
'logical_right_shift':
{ 'op': Op.LOGICAL_RIGHT_SHIFT,
'operands': (2, 0),
'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
- 'types': TYPE_PURE_INT },
+ 'types': TYPE_INT },
'logical_or':
{ 'op': Op.LOGICAL_OR,
@@ -1736,7 +1727,7 @@ class TosaTestGen:
{ 'op': Op.MUL,
'operands': (2, 0),
'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
- 'types': TYPE_PURE_INT_FP },
+ 'types': TYPE_INT_FP },
'pow':
{ 'op': Op.POW,
@@ -2101,7 +2092,7 @@ class TosaTestGen:
{ 'op': Op.RESCALE,
'operands': (1, 0),
'build_fcn': ( build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale ),
- 'types': [ DType.AINT8, DType.INT16, DType.INT32, DType.INT48 ] },
+ 'types': [ DType.INT8, DType.INT16, DType.INT32, DType.INT48 ] },
# Custom
# Not implemented.
@@ -2239,7 +2230,7 @@ class OutputShaper:
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
- if ifm.dtype == DType.AINT8 or ifm.dtype == DType.INT8:
+ if ifm.dtype == DType.INT8:
out_dtype = DType.INT32
elif ifm.dtype == DType.INT16:
out_dtype = DType.INT48
@@ -2269,7 +2260,7 @@ class OutputShaper:
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
- if ifm.dtype == DType.AINT8 or ifm.dtype == DType.INT8:
+ if ifm.dtype == DType.INT8:
out_dtype = DType.INT32
elif ifm.dtype == DType.INT16:
out_dtype = DType.INT48
@@ -2304,7 +2295,7 @@ class OutputShaper:
output_shape = [input.shape[0], filter.shape[0]]
- if input.dtype == DType.AINT8 or input.dtype == DType.INT8:
+ if input.dtype == DType.INT8:
out_dtype = DType.INT32
elif input.dtype == DType.INT16:
out_dtype = DType.INT48
@@ -2323,8 +2314,7 @@ class OutputShaper:
output_shape = [a.shape[0], b.shape[1]]
-
- if a.dtype == DType.AINT8:
+ if a.dtype == DType.INT8:
out_dtype = DType.INT32
elif a.dtype == DType.INT16:
out_dtype = DType.INT48
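
The matmul shaper combines a shape rule with the same dtype widening used by the conv shapers in this patch. A compact sketch of both, assuming 2-D operands:

def matmul_output(a_shape, b_shape, a_dtype):
    # Shape rule from the hunk above: [M, K] x [K, N] -> [M, N], with the
    # accumulator dtype widened exactly as in the conv shapers.
    out_shape = [a_shape[0], b_shape[1]]
    out_dtype = {'INT8': 'INT32', 'INT16': 'INT48'}.get(a_dtype, a_dtype)
    return out_shape, out_dtype

assert matmul_output([4, 8], [8, 16], 'INT8') == ([4, 16], 'INT32')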
@@ -2478,7 +2468,7 @@ class OutputShaper:
@staticmethod
def transposeConv2DOp(ser, ifm, output_shape):
- if ifm.dtype == DType.AINT8 or ifm.dtype == DType.INT8:
+ if ifm.dtype == DType.INT8:
out_dtype = DType.INT32
elif ifm.dtype == DType.INT16:
out_dtype = DType.INT48