aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_verif_build_tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_verif_build_tests.py')
-rw-r--r--verif/generator/tosa_verif_build_tests.py246
1 files changed, 246 insertions, 0 deletions
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
new file mode 100644
index 0000000..09ee238
--- /dev/null
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -0,0 +1,246 @@
+# Copyright (c) 2020-2021, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import sys
+import re
+import os
+import subprocess
+import shlex
+import json
+import glob
+import math
+import queue
+import threading
+import traceback
+
+
+from enum import IntEnum, Enum, unique
+from datetime import datetime
+
+from generator.tosa_test_gen import TosaTestGen
+from serializer.tosa_serializer import dtype_str_to_val
+
+# Used for parsing a comma-separated list of integers in a string
+# to an actual list of integers
+def str_to_list(in_s):
+ """Converts a comma-separated list of string integers to a python list of ints"""
+ lst = in_s.split(",")
+ out_list = []
+ for i in lst:
+ out_list.append(int(i))
+ return out_list
+
+
+def auto_int(x):
+ """Converts hex/dec argument values to an int"""
+ return int(x, 0)
+
+
+def parseArgs():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-o", dest="output_dir", type=str, default="vtest", help="Test output directory"
+ )
+
+ parser.add_argument(
+ "--seed",
+ dest="random_seed",
+ default=42,
+ type=int,
+ help="Random seed for test generation",
+ )
+
+ parser.add_argument(
+ "--filter",
+ dest="filter",
+ default="",
+ type=str,
+ help="Filter operator test names by this expression",
+ )
+
+ parser.add_argument(
+ "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
+ )
+
+ # Constraints on tests
+ parser.add_argument(
+ "--tensor-dim-range",
+ dest="tensor_shape_range",
+ default="1,64",
+ type=lambda x: str_to_list(x),
+ help="Min,Max range of tensor shapes",
+ )
+
+ parser.add_argument(
+ "--max-batch-size",
+ dest="max_batch_size",
+ default=1,
+ type=int,
+ help="Maximum batch size for NHWC tests",
+ )
+
+ parser.add_argument(
+ "--max-conv-padding",
+ dest="max_conv_padding",
+ default=1,
+ type=int,
+ help="Maximum padding for Conv tests",
+ )
+
+ parser.add_argument(
+ "--max-conv-dilation",
+ dest="max_conv_dilation",
+ default=2,
+ type=int,
+ help="Maximum dilation for Conv tests",
+ )
+
+ parser.add_argument(
+ "--max-conv-stride",
+ dest="max_conv_stride",
+ default=2,
+ type=int,
+ help="Maximum stride for Conv tests",
+ )
+
+ parser.add_argument(
+ "--max-pooling-padding",
+ dest="max_pooling_padding",
+ default=1,
+ type=int,
+ help="Maximum padding for pooling tests",
+ )
+
+ parser.add_argument(
+ "--max-pooling-stride",
+ dest="max_pooling_stride",
+ default=2,
+ type=int,
+ help="Maximum stride for pooling tests",
+ )
+
+ parser.add_argument(
+ "--max-pooling-kernel",
+ dest="max_pooling_kernel",
+ default=2,
+ type=int,
+ help="Maximum padding for pooling tests",
+ )
+
+ parser.add_argument(
+ "--num-rand-permutations",
+ dest="num_rand_permutations",
+ default=6,
+ type=int,
+ help="Number of random permutations for a given shape/rank for randomly-sampled parameter spaces",
+ )
+
+ # Targetting a specific shape/rank/dtype
+ parser.add_argument(
+ "--target-shape",
+ dest="target_shapes",
+ action="append",
+ default=[],
+ type=lambda x: str_to_list(x),
+ help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)",
+ )
+
+ parser.add_argument(
+ "--target-rank",
+ dest="target_ranks",
+ action="append",
+ default=None,
+ type=lambda x: auto_int(x),
+ help="Create tests with a particular input tensor rank",
+ )
+
+ parser.add_argument(
+ "--target-dtype",
+ dest="target_dtypes",
+ action="append",
+ default=None,
+ type=lambda x: dtype_str_to_val(x),
+ help="Create test with a particular DType (may be repeated)",
+ )
+
+ parser.add_argument(
+ "--num-const-inputs-concat",
+ dest="num_const_inputs_concat",
+ default=0,
+ choices=[0, 1, 2, 3],
+ type=int,
+ help="Allow constant input tensors for concat operator",
+ )
+
+ parser.add_argument(
+ "--test-type",
+ dest="test_type",
+ choices=['positive', 'negative', 'both'],
+ default="positive",
+ type=str,
+ help="type of tests produced, postive, negative, or both",
+ )
+ args = parser.parse_args()
+
+ return args
+
+
+def main():
+
+ args = parseArgs()
+
+ ttg = TosaTestGen(args)
+
+ if args.test_type == 'both':
+ testType = ['positive', 'negative']
+ else:
+ testType = [args.test_type]
+ results = []
+ for test_type in testType:
+ testList = []
+ for op in ttg.TOSA_OP_LIST:
+ if re.match(args.filter + ".*", op):
+ testList.extend(
+ ttg.genOpTestList(
+ op,
+ shapeFilter=args.target_shapes,
+ rankFilter=args.target_ranks,
+ dtypeFilter=args.target_dtypes,
+ testType=test_type
+ )
+ )
+
+ print("{} matching {} tests".format(len(testList), test_type))
+
+ testStrings = []
+ for opName, testStr, dtype, error, shapeList, testArgs in testList:
+ # Check for and skip duplicate tests
+ if testStr in testStrings:
+ continue
+ else:
+ testStrings.append(testStr)
+
+ if args.verbose:
+ print(testStr)
+ results.append(ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs))
+
+ print(f"Done creating {len(results)} tests")
+
+
+
+if __name__ == "__main__":
+ exit(main())