From 2ec3494060ffdafec072fe1b2099a8177b8eca6a Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 14 Dec 2021 16:34:05 +0000 Subject: Reorganize verif and create packages Split generator and runner scripts Add package setup Add py-dev-env.sh/.bash to allow editing source files during dev Update README.md with installation info Signed-off-by: Jeremy Johnson Change-Id: I172fe426d99e2e9aeeacedc8b8f3b6a79c8bd39d --- verif/generator/tosa_verif_build_tests.py | 246 ++++++++++++++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 verif/generator/tosa_verif_build_tests.py (limited to 'verif/generator/tosa_verif_build_tests.py') diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py new file mode 100644 index 0000000..09ee238 --- /dev/null +++ b/verif/generator/tosa_verif_build_tests.py @@ -0,0 +1,246 @@ +# Copyright (c) 2020-2021, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import sys +import re +import os +import subprocess +import shlex +import json +import glob +import math +import queue +import threading +import traceback + + +from enum import IntEnum, Enum, unique +from datetime import datetime + +from generator.tosa_test_gen import TosaTestGen +from serializer.tosa_serializer import dtype_str_to_val + +# Used for parsing a comma-separated list of integers in a string +# to an actual list of integers +def str_to_list(in_s): + """Converts a comma-separated list of string integers to a python list of ints""" + lst = in_s.split(",") + out_list = [] + for i in lst: + out_list.append(int(i)) + return out_list + + +def auto_int(x): + """Converts hex/dec argument values to an int""" + return int(x, 0) + + +def parseArgs(): + + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", dest="output_dir", type=str, default="vtest", help="Test output directory" + ) + + parser.add_argument( + "--seed", + dest="random_seed", + default=42, + type=int, + help="Random seed for test generation", + ) + + parser.add_argument( + "--filter", + dest="filter", + default="", + type=str, + help="Filter operator test names by this expression", + ) + + parser.add_argument( + "-v", "--verbose", dest="verbose", action="count", help="Verbose operation" + ) + + # Constraints on tests + parser.add_argument( + "--tensor-dim-range", + dest="tensor_shape_range", + default="1,64", + type=lambda x: str_to_list(x), + help="Min,Max range of tensor shapes", + ) + + parser.add_argument( + "--max-batch-size", + dest="max_batch_size", + default=1, + type=int, + help="Maximum batch size for NHWC tests", + ) + + parser.add_argument( + "--max-conv-padding", + dest="max_conv_padding", + default=1, + type=int, + help="Maximum padding for Conv tests", + ) + + parser.add_argument( + "--max-conv-dilation", + dest="max_conv_dilation", + default=2, + type=int, + help="Maximum dilation for Conv tests", + ) + + parser.add_argument( + "--max-conv-stride", + dest="max_conv_stride", + default=2, + type=int, + help="Maximum stride for Conv tests", + ) + + parser.add_argument( + "--max-pooling-padding", + dest="max_pooling_padding", + default=1, + type=int, + help="Maximum padding for pooling tests", + ) + + parser.add_argument( + "--max-pooling-stride", + dest="max_pooling_stride", + default=2, + type=int, + help="Maximum stride for pooling tests", + ) + + parser.add_argument( + "--max-pooling-kernel", + dest="max_pooling_kernel", + default=2, + type=int, + help="Maximum padding for pooling tests", + ) + + parser.add_argument( + "--num-rand-permutations", + dest="num_rand_permutations", + default=6, + type=int, + help="Number of random permutations for a given shape/rank for randomly-sampled parameter spaces", + ) + + # Targetting a specific shape/rank/dtype + parser.add_argument( + "--target-shape", + dest="target_shapes", + action="append", + default=[], + type=lambda x: str_to_list(x), + help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)", + ) + + parser.add_argument( + "--target-rank", + dest="target_ranks", + action="append", + default=None, + type=lambda x: auto_int(x), + help="Create tests with a particular input tensor rank", + ) + + parser.add_argument( + "--target-dtype", + dest="target_dtypes", + action="append", + default=None, + type=lambda x: dtype_str_to_val(x), + help="Create test with a particular DType (may be repeated)", + ) + + parser.add_argument( + "--num-const-inputs-concat", + dest="num_const_inputs_concat", + default=0, + choices=[0, 1, 2, 3], + type=int, + help="Allow constant input tensors for concat operator", + ) + + parser.add_argument( + "--test-type", + dest="test_type", + choices=['positive', 'negative', 'both'], + default="positive", + type=str, + help="type of tests produced, postive, negative, or both", + ) + args = parser.parse_args() + + return args + + +def main(): + + args = parseArgs() + + ttg = TosaTestGen(args) + + if args.test_type == 'both': + testType = ['positive', 'negative'] + else: + testType = [args.test_type] + results = [] + for test_type in testType: + testList = [] + for op in ttg.TOSA_OP_LIST: + if re.match(args.filter + ".*", op): + testList.extend( + ttg.genOpTestList( + op, + shapeFilter=args.target_shapes, + rankFilter=args.target_ranks, + dtypeFilter=args.target_dtypes, + testType=test_type + ) + ) + + print("{} matching {} tests".format(len(testList), test_type)) + + testStrings = [] + for opName, testStr, dtype, error, shapeList, testArgs in testList: + # Check for and skip duplicate tests + if testStr in testStrings: + continue + else: + testStrings.append(testStr) + + if args.verbose: + print(testStr) + results.append(ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs)) + + print(f"Done creating {len(results)} tests") + + + +if __name__ == "__main__": + exit(main()) -- cgit v1.2.1