From e5e2676409a936431f87d31fb74d825257b20804 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 13 Oct 2020 16:11:07 -0700 Subject: Initial checkin of TOSA reference_model and tests Change-Id: I2f8e7fa63e2ae40203e57d2cc8814bde3b312cb6 Signed-off-by: Eric Kunze --- verif/tosa_verif_build_tests.py | 136 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100755 verif/tosa_verif_build_tests.py (limited to 'verif/tosa_verif_build_tests.py') diff --git a/verif/tosa_verif_build_tests.py b/verif/tosa_verif_build_tests.py new file mode 100755 index 0000000..19eb2f4 --- /dev/null +++ b/verif/tosa_verif_build_tests.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020, ARM Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import sys +import re +import os +import subprocess +import shlex +import json +import glob +import math +import queue +import threading +import traceback + + +from enum import IntEnum, Enum, unique +from datetime import datetime + +# Include the ../shared directory in PYTHONPATH +parent_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(os.path.join(parent_dir, '..', 'scripts')) +sys.path.append(os.path.join(parent_dir, '..', 'scripts', 'xunit')) +import xunit +from tosa_serializer import * +from tosa_test_gen import TosaTestGen +import tosa + +# Used for parsing a comma-separated list of integers in a string +# to an actual list of integers +def str_to_list(in_s): + '''Converts a comma-separated list of string integers to a python list of ints''' + lst = in_s.split(',') + out_list = [] + for i in lst: + out_list.append(int(i)) + return out_list + +def auto_int(x): + '''Converts hex/dec argument values to an int''' + return int(x, 0) + +def parseArgs(): + + parser = argparse.ArgumentParser() + parser.add_argument('-o', dest='output_dir', type=str, default='vtest', + help='Test output directory') + + parser.add_argument('--seed', dest='random_seed', default=42, type=int, + help='Random seed for test generation') + + parser.add_argument('--filter', dest='filter', default='', type=str, + help='Filter operator test names by this expression') + + parser.add_argument('-v', '--verbose', dest='verbose', action='count', + help='Verbose operation') + + # Constraints on tests + parser.add_argument('--tensor-dim-range', dest='tensor_shape_range', default='1,64', + type=lambda x: str_to_list(x), + help='Min,Max range of tensor shapes') + + parser.add_argument('--max-batch-size', dest='max_batch_size', default=1, type=int, + help='Maximum batch size for NHWC tests') + + parser.add_argument('--max-conv-padding', dest='max_conv_padding', default=1, type=int, + help='Maximum padding for Conv tests') + + parser.add_argument('--max-conv-dilation', dest='max_conv_dilation', default=2, type=int, + help='Maximum dilation for Conv tests') + + parser.add_argument('--max-conv-stride', dest='max_conv_stride', default=2, type=int, + help='Maximum stride for Conv tests') + + parser.add_argument('--max-pooling-padding', dest='max_pooling_padding', default=1, type=int, + help='Maximum padding for pooling tests') + + parser.add_argument('--max-pooling-stride', dest='max_pooling_stride', default=2, type=int, + help='Maximum stride for pooling tests') + + parser.add_argument('--max-pooling-kernel', dest='max_pooling_kernel', default=2, type=int, + help='Maximum padding for pooling tests') + + parser.add_argument('--num-rand-permutations', dest='num_rand_permutations', default=6, type=int, + help='Number of random permutations for a given shape/rank for randomly-sampled parameter spaces') + + # Targetting a specific shape/rank/dtype + parser.add_argument('--target-shape', dest='target_shapes', action='append', default=[], type=lambda x: str_to_list(x), + help='Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)') + + parser.add_argument('--target-rank', dest='target_ranks', action='append', default=None, type=lambda x: auto_int(x), + help='Create tests with a particular input tensor rank') + + parser.add_argument('--target-dtype', dest='target_dtypes', action='append', default=None, type=lambda x: dtype_str_to_val(x), + help='Create test with a particular DType (may be repeated)') + + args = parser.parse_args() + + return args + +def main(): + + + args = parseArgs() + + ttg = TosaTestGen(args) + + testList = [] + for op in ttg.TOSA_OP_LIST: + if re.match(args.filter + '.*', op): + testList.extend(ttg.genOpTestList(op, shapeFilter=args.target_shapes, rankFilter=args.target_ranks, dtypeFilter=args.target_dtypes)) + + print('{} matching tests'.format(len(testList))) + for opName, testStr, dtype, shapeList, testArgs in testList: + print(testStr) + ttg.serializeTest(opName, testStr, dtype, shapeList, testArgs) + print('Done creating {} tests'.format(len(testList))) + + +if __name__ == '__main__': + exit(main()) -- cgit v1.2.1