From cb7201e173961760c042cade591afe763c949c8f Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Thu, 1 Jun 2023 22:45:26 +0000 Subject: Add accumtype to tf/tfl framework tests Signed-off-by: Jerry Ge Change-Id: Ie45cc27433f5dbce3fadc90014dc5cc8e36a9950 --- verif/frameworks/arg_gen.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) (limited to 'verif/frameworks') diff --git a/verif/frameworks/arg_gen.py b/verif/frameworks/arg_gen.py index 5de995b..a25c205 100644 --- a/verif/frameworks/arg_gen.py +++ b/verif/frameworks/arg_gen.py @@ -3,6 +3,21 @@ import math import numpy as np +from tosa.DType import DType + +DTYPE_ATTRIBUTES = { + DType.BOOL: {"str": "b", "width": 1}, + DType.INT4: {"str": "i4", "width": 4}, + DType.INT8: {"str": "i8", "width": 8}, + DType.UINT8: {"str": "u8", "width": 8}, + DType.INT16: {"str": "i16", "width": 16}, + DType.UINT16: {"str": "u16", "width": 16}, + DType.INT32: {"str": "i32", "width": 32}, + DType.INT48: {"str": "i48", "width": 48}, + DType.FP16: {"str": "f16", "width": 16}, + DType.BF16: {"str": "bf16", "width": 16}, + DType.FP32: {"str": "f32", "width": 32}, +} class ArgGen: @@ -13,6 +28,12 @@ class ArgGen: def __init__(self): pass + def typeStr(dtype): + if dtype in DTYPE_ATTRIBUTES: + return DTYPE_ATTRIBUTES[dtype]["str"] + else: + raise Exception("Unknown dtype, cannot convert to string: {}".format(dtype)) + @staticmethod def agNone(op, shapes, rng): """A trivial argument generator for operators that only take tensor @@ -332,10 +353,18 @@ class ArgGen: # Not an exact integer output continue + # Note: tf.nn.avg_pool2d API doesn't support setting accumtype + # setting a dummy value to the test name as an reminder + accum_dtype = ArgGen.typeStr(DType.INT32) arg_list.append( [ - "_st{}{}_pad{}_kern{}{}".format( - stride_h, stride_w, padding, kernel_h, kernel_w + "_st{}{}_pad{}_kern{}{}_acc{}".format( + stride_h, + stride_w, + padding, + kernel_h, + kernel_w, + accum_dtype, ), [ [stride_h, stride_w], -- cgit v1.2.1