aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/arg_gen.py
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-06-01 22:45:26 +0000
committerDominic Symes <dominic.symes@arm.com>2023-06-15 18:24:04 +0000
commitcb7201e173961760c042cade591afe763c949c8f (patch)
tree7d30d408d9237e73510e0ef78e9163856dd5a48b /verif/frameworks/arg_gen.py
parent41df428ed5e3b07f0a497fc504f1eddb8e115188 (diff)
downloadreference_model-cb7201e173961760c042cade591afe763c949c8f.tar.gz
Add accumtype to tf/tfl framework tests
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: Ie45cc27433f5dbce3fadc90014dc5cc8e36a9950
Diffstat (limited to 'verif/frameworks/arg_gen.py')
-rw-r--r--verif/frameworks/arg_gen.py33
1 files changed, 31 insertions, 2 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py
index 5de995b..a25c205 100644
--- a/verif/frameworks/arg_gen.py
+++ b/verif/frameworks/arg_gen.py
@@ -3,6 +3,21 @@
import math
import numpy as np
+from tosa.DType import DType
+
+DTYPE_ATTRIBUTES = {
+ DType.BOOL: {"str": "b", "width": 1},
+ DType.INT4: {"str": "i4", "width": 4},
+ DType.INT8: {"str": "i8", "width": 8},
+ DType.UINT8: {"str": "u8", "width": 8},
+ DType.INT16: {"str": "i16", "width": 16},
+ DType.UINT16: {"str": "u16", "width": 16},
+ DType.INT32: {"str": "i32", "width": 32},
+ DType.INT48: {"str": "i48", "width": 48},
+ DType.FP16: {"str": "f16", "width": 16},
+ DType.BF16: {"str": "bf16", "width": 16},
+ DType.FP32: {"str": "f32", "width": 32},
+}
class ArgGen:
@@ -13,6 +28,12 @@ class ArgGen:
def __init__(self):
pass
+ def typeStr(dtype):
+ if dtype in DTYPE_ATTRIBUTES:
+ return DTYPE_ATTRIBUTES[dtype]["str"]
+ else:
+ raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype))
+
@staticmethod
def agNone(op, shapes, rng):
"""A trivial argument generator for operators that only take tensor
@@ -332,10 +353,18 @@ class ArgGen:
# Not an exact integer output
continue
+ # Note: tf.nn.avg_pool2d API doesn't support setting accumtype
+ # setting a dummy value to the test name as an reminder
+ accum_dtype = ArgGen.typeStr(DType.INT32)
arg_list.append(
[
- "_st{}{}_pad{}_kern{}{}".format(
- stride_h, stride_w, padding, kernel_h, kernel_w
+ "_st{}{}_pad{}_kern{}{}_acc{}".format(
+ stride_h,
+ stride_w,
+ padding,
+ kernel_h,
+ kernel_w,
+ accum_dtype,
),
[
[stride_h, stride_w],