From 2a76ad2368f4684a8391fe69f51e52356524bf15 Mon Sep 17 00:00:00 2001 From: Jared Smolens Date: Thu, 4 Mar 2021 11:18:54 -0800 Subject: Update DTypes for TOSA ops, test rig fixes - Updated DTypes and expected failures for TOSA ops, particularly missing int8/int16 tests for Conv, FullyConnected, MatMul - Fixed a bug where unexpected failures were incorrectly categorized as passes Change-Id: I2763626317cedad9f3723f748986bb59a32f2e42 Signed-off-by: Jared Smolens --- verif/tosa_ref_run.py | 4 ++-- verif/tosa_test_gen.py | 24 ++++++++++++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/verif/tosa_ref_run.py b/verif/tosa_ref_run.py index 99f504b..2035147 100644 --- a/verif/tosa_ref_run.py +++ b/verif/tosa_ref_run.py @@ -1,6 +1,6 @@ import os -# Copyright (c) 2020, ARM Limited. +# Copyright (c) 2020-2021, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -61,6 +61,6 @@ class TosaRefRunner(TosaTestRunner): if expectedFailure: result = TosaTestRunner.Result.EXPECTED_FAILURE else: - result = TosaTestRunner.Result.EXPECTED_PASS + result = TosaTestRunner.Result.UNEXPECTED_FAILURE return result diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 57030e7..ae1a5c6 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -50,7 +50,7 @@ class TosaQuantGen: @staticmethod def needsQinfo(op, dtype): - if dtype == DType.INT8: + if dtype == DType.INT8 or dtype == DType.INT16: return True return False @@ -1754,7 +1754,7 @@ class TosaTestGen: { 'op': Op.ARGMAX, 'operands': (1, 0), 'build_fcn': (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis), - 'types': TYPE_FP }, + 'types': TYPE_NARROW_INT_FP }, # Templated operator. Filled in by createDynamicOpLists 'conv2d_TEMPLATE': @@ -1763,7 +1763,7 @@ class TosaTestGen: 'rank': (4, 4), 'build_fcn': (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D), 'qgen': TosaQuantGen.qgConv, - 'types': TYPE_FP, + 'types': TYPE_NARROW_INT_FP, 'template': True }, # Templated operator. Filled in by createDynamicOpLists @@ -1774,7 +1774,7 @@ class TosaTestGen: 'rank': (4, 4), 'build_fcn': (build_depthwise_conv2d, TosaTensorGen.tgDepthwiseConv2D, TosaArgGen.agConv2D), 'qgen': TosaQuantGen.qgConv, - 'types': TYPE_FP, + 'types': TYPE_NARROW_INT_FP, 'template': True }, # Templated operator. Filled in by createDynamicOpLists @@ -1784,7 +1784,7 @@ class TosaTestGen: 'rank': (4, 4), 'build_fcn': (build_transpose_conv2d, TosaTensorGen.tgTransposeConv2D, TosaArgGen.agTransposeConv2D), 'qgen': TosaQuantGen.qgConv, - 'types': TYPE_FP, + 'types': TYPE_NARROW_INT_FP, 'template': True }, 'fully_connected': @@ -1793,7 +1793,7 @@ class TosaTestGen: 'rank': (2, 2), 'build_fcn': (build_fully_connected, TosaTensorGen.tgFullyConnected, None), 'qgen': TosaQuantGen.qgConv, - 'types': TYPE_FP }, + 'types': TYPE_NARROW_INT_FP }, 'matmul': { 'op': Op.MATMUL, @@ -2239,6 +2239,9 @@ class OutputShaper: else: raise Exception('Unsupported input dtype: {}'.format(ifm.dtype)) + if ifm.dtype == DType.INT16: + ser.setExpectedFailure(True, "INT16 support is in progress") + return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat) @staticmethod @@ -2269,6 +2272,9 @@ class OutputShaper: else: raise Exception('Unsupported input dtype: {}'.format(ifm.dtype)) + if ifm.dtype == DType.INT16: + ser.setExpectedFailure(True, "INT16 support is in progress") + return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat) @@ -2304,6 +2310,9 @@ class OutputShaper: else: raise Exception('Unsupported input dtype: {}'.format(input.dtype)) + if input.dtype == DType.INT16: + ser.setExpectedFailure(True, "INT16 support is in progress") + return ser.addOutput(output_shape, out_dtype, input.usage, input.dformat) @staticmethod @@ -2480,4 +2489,7 @@ class OutputShaper: if output_shape[1] <= 0 or output_shape[2] <= 0: ser.setExpectedFailure(True, 'Negative output shape') + if ifm.dtype == DType.INT16: + ser.setExpectedFailure(True, "INT16 support is in progress") + return ser.addOutput(output_shape, out_dtype, ifm.usage, ifm.dformat) -- cgit v1.2.1