aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py114
1 files changed, 76 insertions, 38 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 38365d0..7c2b9de 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -630,6 +630,8 @@ class TosaTestGen:
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
+ weight_shape=filter.shape,
+ output_shape=result_tens.shape,
):
return None
@@ -692,6 +694,8 @@ class TosaTestGen:
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
+ weight_shape=filter.shape,
+ output_shape=result_tens.shape,
):
return None
@@ -715,7 +719,7 @@ class TosaTestGen:
error_name=None,
qinfo=None,
):
- assert len(outpad) == 2
+ assert len(outpad) == 4
result_tens = OutputShaper.transposeConv2DOp(
self.ser, self.rng, ifm, output_shape, error_name
)
@@ -753,8 +757,9 @@ class TosaTestGen:
output_list=output_list,
pad=outpad,
stride=stride,
- dilation=dilation,
input_shape=ifm.shape,
+ weight_shape=filter.shape,
+ output_shape=result_tens.shape,
):
return None
@@ -816,6 +821,8 @@ class TosaTestGen:
stride=strides,
dilation=dilations,
input_shape=ifm.shape,
+ weight_shape=filter.shape,
+ output_shape=result_tens.shape,
):
return None
@@ -2393,6 +2400,7 @@ class TosaTestGen:
TosaErrorValidator.evOutputZeroPointNotZero,
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evPoolingOutputShapeMismatch,
+ TosaErrorValidator.evPoolingOutputShapeNonInteger,
),
},
# Templated operator. Filled in by createDynamicOpLists
@@ -2420,6 +2428,8 @@ class TosaTestGen:
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evConvOutputShapeMismatch,
+ TosaErrorValidator.evConvOutputShapeNonInteger,
),
"template": True,
},
@@ -2448,6 +2458,8 @@ class TosaTestGen:
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evConvOutputShapeMismatch,
+ TosaErrorValidator.evConvOutputShapeNonInteger,
),
"template": True,
},
@@ -2477,6 +2489,8 @@ class TosaTestGen:
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evConvOutputShapeMismatch,
+ TosaErrorValidator.evConvOutputShapeNonInteger,
),
"template": True,
},
@@ -2546,6 +2560,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evPadLargerEqualKernel,
TosaErrorValidator.evPoolingOutputShapeMismatch,
+ TosaErrorValidator.evPoolingOutputShapeNonInteger,
),
},
# Templated operator. Filled in by createDynamicOpLists
@@ -2574,8 +2589,8 @@ class TosaTestGen:
TosaErrorValidator.evWeightZeroPointNotZero,
TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evStrideSmallerOne,
- TosaErrorValidator.evDilationSmallerOne,
TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evConvOutputShapeMismatch,
),
"template": True,
},
@@ -3887,30 +3902,30 @@ class OutputShaper:
# Filter: OHWI
# OFM: NHWC
- if len(padding) == 2:
- # Expand padding to 4 parameters in the case of transpose_conv2d
- # From H,W to T,B,L,R
- padding = [padding[0], padding[0], padding[1], padding[1]]
-
h = (
ifm.shape[1]
- - filter.shape[1]
- - (filter.shape[1] - 1) * (dilations[0] - 1)
+ - 1
+ padding[0]
+ padding[1]
+ - (filter.shape[1] - 1) * dilations[0]
) // strides[0] + 1
w = (
ifm.shape[2]
- - filter.shape[2]
- - (filter.shape[2] - 1) * (dilations[1] - 1)
+ - 1
+ padding[2]
+ padding[3]
+ - (filter.shape[2] - 1) * dilations[1]
) // strides[1] + 1
- # Avoid illegal dimensions, which can be generated in error_if tests
- h = max(h, 1)
- w = max(w, 1)
+ if error_name == ErrorIf.ConvOutputShapeMismatch:
+ choices = [1, 2, 3]
+ change = rng.choice(choices)
+ # increment in multiples of stride to not hit non-integer error case
+ if change in [1, 3]:
+ h = h + (rng.choice(choices) * strides[0])
+ if change in [2, 3]:
+ w = w + (rng.choice(choices) * strides[1])
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
@@ -3941,32 +3956,38 @@ class OutputShaper:
d = (
ifm.shape[1]
- - filter.shape[1]
- - (filter.shape[1] - 1) * (dilations[0] - 1)
+ - 1
+ padding[0]
+ padding[1]
+ - (filter.shape[1] - 1) * dilations[0]
) // strides[0] + 1
h = (
ifm.shape[2]
- - filter.shape[2]
- - (filter.shape[2] - 1) * (dilations[1] - 1)
+ - 1
+ padding[2]
+ padding[3]
+ - (filter.shape[2] - 1) * dilations[1]
) // strides[1] + 1
w = (
ifm.shape[3]
- - filter.shape[3]
- - (filter.shape[3] - 1) * (dilations[2] - 1)
+ - 1
+ padding[4]
+ padding[5]
+ - (filter.shape[3] - 1) * dilations[2]
) // strides[2] + 1
- # Avoid illegal dimensions, which can be generated in error_if tests
- d = max(d, 1)
- h = max(h, 1)
- w = max(w, 1)
+ if error_name == ErrorIf.ConvOutputShapeMismatch:
+ choices = [1, 2, 3, 4]
+ change = rng.choice(choices)
+ # increment in multiples of stride to not hit non-integer error case
+ if change in [1, 4]:
+ d = d + (rng.choice(choices) * strides[0])
+ if change in [2, 4]:
+ h = h + (rng.choice(choices) * strides[1])
+ if change in [3, 4]:
+ w = w + (rng.choice(choices) * strides[2])
ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
@@ -3995,25 +4016,31 @@ class OutputShaper:
# IFM: NHWC
# Filter: HWCM
# OFM: NHW C*M
+
h = (
ifm.shape[1]
- - filter.shape[0]
- - (filter.shape[0] - 1) * (dilations[0] - 1)
+ - 1
+ padding[0]
+ padding[1]
+ - (filter.shape[0] - 1) * dilations[0]
) // strides[0] + 1
w = (
ifm.shape[2]
- - filter.shape[1]
- - (filter.shape[1] - 1) * (dilations[1] - 1)
+ - 1
+ padding[2]
+ padding[3]
+ - (filter.shape[1] - 1) * dilations[1]
) // strides[1] + 1
- # Avoid illegal dimensions, which can be generated in error_if tests
- h = max(h, 1)
- w = max(w, 1)
+ if error_name == ErrorIf.ConvOutputShapeMismatch:
+ choices = [1, 2, 3]
+ change = rng.choice(choices)
+ # increment in multiples of stride to not hit non-integer error case
+ if change in [1, 3]:
+ h = h + (rng.choice(choices) * strides[0])
+ if change in [2, 3]:
+ w = w + (rng.choice(choices) * strides[1])
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
@@ -4043,14 +4070,17 @@ class OutputShaper:
h = 1
w = 1
else:
- h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
- w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
+ h = (ifm.shape[1] + pad[0] + pad[1] - kernel[0]) // stride[0] + 1
+ w = (ifm.shape[2] + pad[2] + pad[3] - kernel[1]) // stride[1] + 1
if error_name == ErrorIf.PoolingOutputShapeMismatch:
- choices = [1, 2, 3, 4, 5]
- h = h + rng.choice(choices)
- w = w + rng.choice(choices)
-
+ choices = [1, 2, 3]
+ change = rng.choice(choices)
+ # increment in multiples of stride to not hit non-integer error case
+ if change in [1, 3]:
+ h = h + (rng.choice(choices) * stride[0])
+ if change in [2, 3]:
+ w = w + (rng.choice(choices) * stride[1])
ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
if error_name == ErrorIf.WrongOutputType:
@@ -4468,6 +4498,14 @@ class OutputShaper:
@staticmethod
def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
+ if error_name == ErrorIf.ConvOutputShapeMismatch:
+ choices = [1, 2, 3]
+ change = rng.choice(choices)
+ if change in [1, 3]:
+ output_shape[1] = output_shape[1] + rng.choice(choices)
+ if change in [2, 3]:
+ output_shape[2] = output_shape[2] + rng.choice(choices)
+
if ifm.dtype == DType.INT8:
out_dtype = DType.INT32
elif ifm.dtype == DType.INT16: