diff options
author | Jerry Ge <jerry.ge@arm.com> | 2023-06-01 22:45:26 +0000 |
---|---|---|
committer | Dominic Symes <dominic.symes@arm.com> | 2023-06-15 18:24:04 +0000 |
commit | cb7201e173961760c042cade591afe763c949c8f (patch) | |
tree | 7d30d408d9237e73510e0ef78e9163856dd5a48b /verif | |
parent | 41df428ed5e3b07f0a497fc504f1eddb8e115188 (diff) | |
download | reference_model-cb7201e173961760c042cade591afe763c949c8f.tar.gz |
Add accumtype to tf/tfl framework tests
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Ie45cc27433f5dbce3fadc90014dc5cc8e36a9950
Diffstat (limited to 'verif')
-rw-r--r-- | verif/frameworks/arg_gen.py | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index 5de995b..a25c205 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -3,6 +3,21 @@ import math import numpy as np +from tosa.DType import DType + +DTYPE_ATTRIBUTES = { + DType.BOOL: {"str": "b", "width": 1}, + DType.INT4: {"str": "i4", "width": 4}, + DType.INT8: {"str": "i8", "width": 8}, + DType.UINT8: {"str": "u8", "width": 8}, + DType.INT16: {"str": "i16", "width": 16}, + DType.UINT16: {"str": "u16", "width": 16}, + DType.INT32: {"str": "i32", "width": 32}, + DType.INT48: {"str": "i48", "width": 48}, + DType.FP16: {"str": "f16", "width": 16}, + DType.BF16: {"str": "bf16", "width": 16}, + DType.FP32: {"str": "f32", "width": 32}, +} class ArgGen: @@ -13,6 +28,12 @@ class ArgGen: def __init__(self): pass + def typeStr(dtype): + if dtype in DTYPE_ATTRIBUTES: + return DTYPE_ATTRIBUTES[dtype]["str"] + else: + raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype)) + @staticmethod def agNone(op, shapes, rng): """A trivial argument generator for operators that only take tensor @@ -332,10 +353,18 @@ class ArgGen: # Not an exact integer output continue + # Note: tf.nn.avg_pool2d API doesn't support setting accumtype + # setting a dummy value to the test name as an reminder + accum_dtype = ArgGen.typeStr(DType.INT32) arg_list.append( [ - "_st{}{}_pad{}_kern{}{}".format( - stride_h, stride_w, padding, kernel_h, kernel_w + "_st{}{}_pad{}_kern{}{}_acc{}".format( + stride_h, + stride_w, + padding, + kernel_h, + kernel_w, + accum_dtype, ), [ [stride_h, stride_w], |