diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index c596645..253e8ee 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1,8 +1,8 @@ # Copyright (c) 2021-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import itertools +import logging import math -import warnings import generator.tosa_utils as gtu import numpy as np @@ -16,6 +16,9 @@ from tosa.ResizeMode import ResizeMode # DTypeNames, DType, Op and ResizeMode are convenience variables to the # flatc-generated types that should be enums, but aren't +logging.basicConfig() +logger = logging.getLogger("tosa_verif_build_tests") + class TosaQuantGen: """QuantizedInfo random generator helper functions. @@ -131,8 +134,9 @@ class TosaQuantGen: shift = shift + 1 shift = (-shift) + scaleBits - # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format( - # scaleFp, scaleBits, m, multiplier, shift)) + logger.debug( + f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}" + ) # Adjust multiplier such that shift is in allowed value range. if shift == 0: @@ -690,8 +694,9 @@ class TosaTensorValuesGen: # Invalid data range from low to high created due to user # constraints revert to using internal ranges as they are # known to work - msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})" - warnings.warn(msg) + logger.info( + f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})" + ) data_range = (low_val, high_val) return data_range return None @@ -1856,7 +1861,7 @@ class TosaArgGen: if "shape" in args_dict else "" ) - print( + logger.info( f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}" ) continue @@ -2503,7 +2508,7 @@ class TosaArgGen: arg_list.append((name, args_dict)) if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0: - warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}") + logger.info(f"No ErrorIf test created for input shape: {shapeList[0]}") arg_list = TosaArgGen._add_data_generators( testGen, @@ -2683,7 +2688,9 @@ class TosaArgGen: remainder_w = partial_w % s[1] output_h = partial_h // s[0] + 1 output_w = partial_w // s[1] + 1 - # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w) + logger.debug( + f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})" + ) if ( # the parameters must produce integer exact output error_name != ErrorIf.PoolingOutputShapeNonInteger @@ -2920,7 +2927,9 @@ class TosaArgGen: # Cap the scaling at 2^15 - 1 for scale16 scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0) - # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr)) + logger.debug( + f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}" + ) multiplier_arr = np.int32(np.zeros(shape=[nc])) shift_arr = np.int32(np.zeros(shape=[nc])) |