aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py39
1 files changed, 11 insertions, 28 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 370570c..75ca634 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -192,10 +192,7 @@ class TosaTensorGen:
assert rank == 4
shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
+ shape = testGen.constrictBatchSize(shape)
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name and error_name != ErrorIf.MaxDimExceeded:
@@ -220,7 +217,7 @@ class TosaTensorGen:
# ignore max batch size if target shape is set
if testGen.args.max_batch_size and not testGen.args.target_shapes:
- values_in_shape[0] = (values_in_shape[0] % testGen.args.max_batch_size) + 1
+ values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)
W = testGen.randInt(
testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
@@ -282,10 +279,7 @@ class TosaTensorGen:
# IFM dimensions are NHWC
ifm_shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ ifm_shape = testGen.constrictBatchSize(ifm_shape)
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
@@ -297,7 +291,7 @@ class TosaTensorGen:
filter_hw = op["filter"]
# Generate a random OFM depth
- ofm_depth = testGen.makeShape(1)[0]
+ ofm_depth = testGen.makeDimension()
# The filter dimensions are OHWI
filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
@@ -316,10 +310,7 @@ class TosaTensorGen:
# IFM dimensions are NDHWC
ifm_shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ ifm_shape = testGen.constrictBatchSize(ifm_shape)
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
@@ -331,7 +322,7 @@ class TosaTensorGen:
filter_dhw = op["filter"]
# Generate a random OFM channel
- ofm_channel = testGen.makeShape(1)[0]
+ ofm_channel = testGen.makeDimension()
# The filter dimensions are ODHWI
filter_shape = np.asarray(
@@ -352,10 +343,7 @@ class TosaTensorGen:
# IFM dimensions are NHWC
ifm_shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ ifm_shape = testGen.constrictBatchSize(ifm_shape)
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
@@ -367,7 +355,7 @@ class TosaTensorGen:
filter_hw = op["filter"]
# Generate a random OFM depth
- ofm_depth = testGen.makeShape(1)[0]
+ ofm_depth = testGen.makeDimension()
# The filter dimensions are OHWI
filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
@@ -387,10 +375,7 @@ class TosaTensorGen:
# IFM dimensions are NHWC
ifm_shape = testGen.makeShape(rank)
-
- # Constrict the batch size?
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ ifm_shape = testGen.constrictBatchSize(ifm_shape)
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
@@ -405,7 +390,7 @@ class TosaTensorGen:
# Generate a random OFM depth, but don't let it get too big because
# the output depth is M * C
filter_m = (
- testGen.makeShape(1)[0] % (testGen.args.tensor_shape_range[1] // 4)
+ testGen.makeDimension() % (testGen.args.tensor_shape_range[1] // 4)
) + 1
# The filter dimensions are HWCM
@@ -484,9 +469,7 @@ class TosaTensorGen:
ifm_shape[1] += selected_inc[0]
ifm_shape[2] += selected_inc[1]
- # Constrict the batch size
- if testGen.args.max_batch_size:
- ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ ifm_shape = testGen.constrictBatchSize(ifm_shape)
return [ifm_shape]