aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_verif_build_tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_verif_build_tests.py')
-rwxr-xr-xverif/tosa_verif_build_tests.py136
1 files changed, 136 insertions, 0 deletions
diff --git a/verif/tosa_verif_build_tests.py b/verif/tosa_verif_build_tests.py
new file mode 100755
index 0000000..19eb2f4
--- /dev/null
+++ b/verif/tosa_verif_build_tests.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import sys
+import re
+import os
+import subprocess
+import shlex
+import json
+import glob
+import math
+import queue
+import threading
+import traceback
+
+
+from enum import IntEnum, Enum, unique
+from datetime import datetime
+
+# Include the ../shared directory in PYTHONPATH
+parent_dir = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.join(parent_dir, '..', 'scripts'))
+sys.path.append(os.path.join(parent_dir, '..', 'scripts', 'xunit'))
+import xunit
+from tosa_serializer import *
+from tosa_test_gen import TosaTestGen
+import tosa
+
+# Used for parsing a comma-separated list of integers in a string
+# to an actual list of integers
+def str_to_list(in_s):
+ '''Converts a comma-separated list of string integers to a python list of ints'''
+ lst = in_s.split(',')
+ out_list = []
+ for i in lst:
+ out_list.append(int(i))
+ return out_list
+
+def auto_int(x):
+ '''Converts hex/dec argument values to an int'''
+ return int(x, 0)
+
+def parseArgs():
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-o', dest='output_dir', type=str, default='vtest',
+ help='Test output directory')
+
+ parser.add_argument('--seed', dest='random_seed', default=42, type=int,
+ help='Random seed for test generation')
+
+ parser.add_argument('--filter', dest='filter', default='', type=str,
+ help='Filter operator test names by this expression')
+
+ parser.add_argument('-v', '--verbose', dest='verbose', action='count',
+ help='Verbose operation')
+
+ # Constraints on tests
+ parser.add_argument('--tensor-dim-range', dest='tensor_shape_range', default='1,64',
+ type=lambda x: str_to_list(x),
+ help='Min,Max range of tensor shapes')
+
+ parser.add_argument('--max-batch-size', dest='max_batch_size', default=1, type=int,
+ help='Maximum batch size for NHWC tests')
+
+ parser.add_argument('--max-conv-padding', dest='max_conv_padding', default=1, type=int,
+ help='Maximum padding for Conv tests')
+
+ parser.add_argument('--max-conv-dilation', dest='max_conv_dilation', default=2, type=int,
+ help='Maximum dilation for Conv tests')
+
+ parser.add_argument('--max-conv-stride', dest='max_conv_stride', default=2, type=int,
+ help='Maximum stride for Conv tests')
+
+ parser.add_argument('--max-pooling-padding', dest='max_pooling_padding', default=1, type=int,
+ help='Maximum padding for pooling tests')
+
+ parser.add_argument('--max-pooling-stride', dest='max_pooling_stride', default=2, type=int,
+ help='Maximum stride for pooling tests')
+
+ parser.add_argument('--max-pooling-kernel', dest='max_pooling_kernel', default=2, type=int,
+ help='Maximum padding for pooling tests')
+
+ parser.add_argument('--num-rand-permutations', dest='num_rand_permutations', default=6, type=int,
+ help='Number of random permutations for a given shape/rank for randomly-sampled parameter spaces')
+
+ # Targetting a specific shape/rank/dtype
+ parser.add_argument('--target-shape', dest='target_shapes', action='append', default=[], type=lambda x: str_to_list(x),
+ help='Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)')
+
+ parser.add_argument('--target-rank', dest='target_ranks', action='append', default=None, type=lambda x: auto_int(x),
+ help='Create tests with a particular input tensor rank')
+
+ parser.add_argument('--target-dtype', dest='target_dtypes', action='append', default=None, type=lambda x: dtype_str_to_val(x),
+ help='Create test with a particular DType (may be repeated)')
+
+ args = parser.parse_args()
+
+ return args
+
+def main():
+
+
+ args = parseArgs()
+
+ ttg = TosaTestGen(args)
+
+ testList = []
+ for op in ttg.TOSA_OP_LIST:
+ if re.match(args.filter + '.*', op):
+ testList.extend(ttg.genOpTestList(op, shapeFilter=args.target_shapes, rankFilter=args.target_ranks, dtypeFilter=args.target_dtypes))
+
+ print('{} matching tests'.format(len(testList)))
+ for opName, testStr, dtype, shapeList, testArgs in testList:
+ print(testStr)
+ ttg.serializeTest(opName, testStr, dtype, shapeList, testArgs)
+ print('Done creating {} tests'.format(len(testList)))
+
+
+if __name__ == '__main__':
+ exit(main())