aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py130
1 files changed, 76 insertions, 54 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 583e1ed..eeb0ac7 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -13,6 +13,7 @@ from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
+from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
from tosa.DType import DType
from tosa.Op import Op
@@ -1450,12 +1451,9 @@ class TosaTestGen:
op,
input,
mode,
- stride,
+ scale,
offset,
- shift,
- stride_fp,
- offset_fp,
- output_dims,
+ border,
input_dtype,
output_dtype,
validator_fcns,
@@ -1466,12 +1464,9 @@ class TosaTestGen:
self.rng,
input,
mode,
- stride,
+ scale,
offset,
- shift,
- stride_fp,
- offset_fp,
- output_dims,
+ border,
input_dtype,
output_dtype,
error_name,
@@ -1492,15 +1487,13 @@ class TosaTestGen:
error_name,
op=op,
mode=mode,
- shift=shift,
+ scale=scale,
input_dtype=input_dtype,
output_dtype=output_dtype,
input_shape=input.shape,
- output_shape=output_dims,
+ output_shape=result_tens.shape,
offset=offset,
- offset_fp=offset_fp,
- stride=stride,
- stride_fp=stride_fp,
+ border=border,
input_list=input_list,
output_list=output_list,
result_tensor=result_tens,
@@ -1510,9 +1503,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
- attr.ResizeAttribute(
- output_dims, stride, offset, shift, stride_fp, offset_fp, mode
- )
+ attr.ResizeAttribute(scale, offset, border, mode)
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -3619,18 +3610,16 @@ class TosaTestGen:
"types": [DType.INT8, DType.INT16, DType.FLOAT],
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
- TosaInvalidValidator.ivBadStride,
),
"error_if_validators": (
TosaErrorValidator.evMaxDimExceeded,
- TosaErrorValidator.evStrideSmallerEqualZero,
- TosaErrorValidator.evStrideLargerDimension,
- TosaErrorValidator.evStrideLargerEqualMax,
- TosaErrorValidator.evOffsetSmallerEqualMin,
+ TosaErrorValidator.evScaleSmallerEqualZero,
+ TosaErrorValidator.evScaleNLargerMax,
+ TosaErrorValidator.evScaleDLargerMax,
+ TosaErrorValidator.evOffsetSmallerMin,
TosaErrorValidator.evOffsetLargerEqualMax,
- TosaErrorValidator.evShiftNotZero,
- TosaErrorValidator.evShiftSmallerOne,
- TosaErrorValidator.evShiftLargerEleven,
+ TosaErrorValidator.evBorderSmallerMin,
+ TosaErrorValidator.evBorderLargerEqualMax,
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongRank,
@@ -3638,6 +3627,8 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evChannelMismatch,
+ TosaErrorValidator.evResizeOutputShapeMismatch,
+ TosaErrorValidator.evResizeOutputShapeNonInteger,
),
},
# Type conversion
@@ -4470,45 +4461,76 @@ class OutputShaper:
rng,
input,
mode,
- stride,
+ scale,
offset,
- shift,
- stride_fp,
- offset_fp,
- output_dims,
+ border,
input_dtype,
output_dtype,
error_name=None,
):
+ # Calculate OH, OW
+ scale_y_n = scale[0]
+ scale_y_d = scale[1]
+ scale_x_n = scale[2]
+ scale_x_d = scale[3]
+ if error_name == ErrorIf.ScaleSmallerEqualZero:
+ scale_y_n = max(scale_y_n, 1)
+ scale_y_d = max(scale_y_d, 1)
+ scale_x_n = max(scale_x_n, 1)
+ scale_x_d = max(scale_x_d, 1)
+
+ oh = ((input.shape[1] - 1) * scale_y_n - offset[0] + border[0]) // scale_y_d + 1
+ ow = ((input.shape[2] - 1) * scale_x_n - offset[1] + border[1]) // scale_x_d + 1
+
+ if error_name is not None:
+ # Make sure the output tensor is valid, which can occur when
+ # scale, offset or border have been changed for ERROR_IFs
+ oh = max(oh, 1)
+ ow = max(ow, 1)
+ if error_name != ErrorIf.MaxDimExceeded:
+ oh = min(oh, MAX_RESIZE_DIMENSION - 1)
+ ow = min(ow, MAX_RESIZE_DIMENSION - 1)
+
+ if error_name == ErrorIf.ResizeOutputShapeMismatch:
+ choices = [1, 2, 3]
+ change = rng.choice(choices)
+ # increment in multiples of scale_y/x_d so we don't hit non-integer error case
+ if change in [1, 3]:
+ if oh + scale_y_d >= MAX_RESIZE_DIMENSION:
+ oh -= scale_y_d
+ assert oh > 0 # Should have been caught in agResize
+ else:
+ oh += scale_y_d
+ if change in [2, 3]:
+ if ow + scale_x_d >= MAX_RESIZE_DIMENSION:
+ ow -= scale_x_d
+ assert ow > 0 # Should have been caught in agResize
+ else:
+ ow += scale_x_d
+
if error_name == ErrorIf.WrongRank:
output_dims = [
input.shape[0],
- output_dims[0],
- output_dims[0],
+ oh,
+ ow,
+ input.shape[0],
+ ]
+ elif error_name == ErrorIf.BatchMismatch:
+ output_dims = [
+ input.shape[0] + rng.integers(1, 10),
+ oh,
+ ow,
+ input.shape[3],
+ ]
+ elif error_name == ErrorIf.ChannelMismatch:
+ output_dims = [
input.shape[0],
+ oh,
+ ow,
+ input.shape[3] + rng.integers(1, 10),
]
else:
- if error_name == ErrorIf.BatchMismatch:
- output_dims = [
- input.shape[0] + rng.integers(1, 10),
- output_dims[0],
- output_dims[1],
- input.shape[3],
- ]
- elif error_name == ErrorIf.ChannelMismatch:
- output_dims = [
- input.shape[0],
- output_dims[0],
- output_dims[1],
- input.shape[3] + rng.integers(1, 10),
- ]
- else:
- output_dims = [
- input.shape[0],
- output_dims[0],
- output_dims[1],
- input.shape[3],
- ]
+ output_dims = [input.shape[0], oh, ow, input.shape[3]]
return serializer.addOutput(output_dims, output_dtype)