diff options
-rw-r--r-- | verif/tosa_test_gen.py | 59 |
1 files changed, 43 insertions, 16 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index b9cca18..3702142 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -1111,28 +1111,55 @@ class TosaArgGen: ) ) else: - shift = 11 - unit = float(1 << shift) - stride_y = int(round(fp_stride_y * unit)) - stride_x = int(round(fp_stride_x * unit)) - offset_y = int(round(fp_offset_y * unit)) - offset_x = int(round(fp_offset_x * unit)) - - while ( - stride_y >= (16 << shift) - or stride_x >= (16 << shift) - or offset_y >= (16 << shift) - or offset_x >= (16 << shift) - or offset_y <= (-16 << shift) - or offset_x <= (-16 << shift) - ): - shift = shift - 1 + shift = testGen.randInt(1,12) + # Now search for a shift value (1 to 11) that will produce + # a valid and predictable resize operation + count = 0 + while (count < 12): unit = float(1 << shift) stride_y = int(round(fp_stride_y * unit)) stride_x = int(round(fp_stride_x * unit)) offset_y = int(round(fp_offset_y * unit)) offset_x = int(round(fp_offset_x * unit)) + if ( + stride_y >= (16 << shift) + or stride_x >= (16 << shift) + or offset_y >= (16 << shift) + or offset_x >= (16 << shift) + or offset_y <= (-16 << shift) + or offset_x <= (-16 << shift) + ): + # Change the shift value and check again + count += 1 + shift = (shift % 11) + 1 + continue + + def RESIZE_REQUIRE_CALC(length_in, length_out, stride, offset, shift): + # Perform the pseudo loop to look for out of bounds + for pos in range(0,length_out): + a = pos * stride + offset + ia = a >> shift + ia0 = max(ia, 0) + ia1 = min(ia+1, length_in-1) + if ia0 > ia1: + # Found a problem value + break + return ia0, ia1 + + iy0, iy1 = RESIZE_REQUIRE_CALC(ifm_shape[1], output_dims[0], stride_y, offset_y, shift) + ix0, ix1 = RESIZE_REQUIRE_CALC(ifm_shape[2], output_dims[1], stride_x, offset_x, shift) + if ix0 > ix1 or iy0 > iy1: + # Change the shift value and check again + count += 1 + shift = (shift % 11) + 1 + continue + break + + if count >= 12: + # Couldn't find a good set of values for this test, skip it + continue + stride = [stride_y, stride_x] offset = [offset_y, offset_x] |