diff options
author | Jared Smolens <jared.smolens@arm.com> | 2021-03-04 11:18:54 -0800 |
---|---|---|
committer | Jared Smolens <jared.smolens@arm.com> | 2021-03-04 13:21:03 -0800 |
commit | 2a76ad2368f4684a8391fe69f51e52356524bf15 (patch) | |
tree | c05d26b964e2006c4c88d00facab17ee110f7f2c | |
parent | df8626976df6c779bb30df9c5ceef689462109c0 (diff) | |
download | reference_model-2a76ad2368f4684a8391fe69f51e52356524bf15.tar.gz |
Update DTypes for TOSA ops, test rig fixes
- Updated DTypes and expected failures for TOSA ops,
particularly missing int8/int16 tests for Conv, FullyConnected,
MatMul
- Fixed a bug where unexpected failures were incorrectly
categorized as passes
Change-Id: I2763626317cedad9f3723f748986bb59a32f2e42
Signed-off-by: Jared Smolens <jared.smolens@arm.com>
-rw-r--r-- | verif/tosa_ref_run.py | 4 | ||||
-rw-r--r-- | verif/tosa_test_gen.py | 24 |
2 files changed, 20 insertions, 8 deletions
diff --git a/verif/tosa_ref_run.py b/verif/tosa_ref_run.py index 99f504b..2035147 100644 --- a/verif/tosa_ref_run.py +++ b/verif/tosa_ref_run.py @@ -1,6 +1,6 @@ import os -# Copyright (c) 2020, ARM Limited. +# Copyright (c) 2020-2021, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -61,6 +61,6 @@ class TosaRefRunner(TosaTestRunner): if expectedFailure: result = TosaTestRunner.Result.EXPECTED_FAILURE else: - result = TosaTestRunner.Result.EXPECTED_PASS + result = TosaTestRunner.Result.UNEXPECTED_FAILURE return result diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index 57030e7..ae1a5c6 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -50,7 +50,7 @@ class TosaQuantGen: @staticmethod def needsQinfo(op, dtype): - if dtype == DType.INT8: + if dtype == DType.INT8 or dtype == DType.INT16: return True return False @@ -1754,7 +1754,7 @@ class TosaTestGen: { 'op': Op.ARGMAX, 'operands': (1, 0), 'build_fcn': (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis), - 'types': TYPE_FP }, + 'types': TYPE_NARROW_INT_FP }, # Templated operator. Filled in by createDynamicOpLists 'conv2d_TEMPLATE': @@ -1763,7 +1763,7 @@ class TosaTestGen: 'rank': (4, 4), 'build_fcn': (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D), 'qgen': TosaQuantGen.qgConv, - 'types': TYPE_FP, + 'types': TYPE_NARROW_INT_FP, 'template': True }, # Templated operator. Filled in by createDynamicOpLists @@ -1774,7 +1774,7 @@ class TosaTestGen: 'rank': (4, 4), 'build_fcn': (build_depthwise_conv2d, TosaTensorGen.tgDepthwiseConv2D, TosaArgGen.agConv2D), 'qgen': TosaQuantGen.qgConv, - 'types': TYPE_FP, + 'types': TYPE_NARROW_INT_FP, 'template': True }, # Templated operator. Filled in by createDynamicOpLists @@ -1784,7 +1784,7 @@ class TosaTestGen: 'rank': (4, 4), 'build_fcn': (build_transpose_conv2d, TosaTensorGen.tgTransposeConv2D, TosaArgGen.agTransposeConv2D), 'qgen': TosaQuantGen.qgConv, - 'types': TYPE_FP, + 'types': TYPE_NARROW_INT_FP, 'template': True }, 'fully_connected': @@ -1793,7 +1793,7 @@ class TosaTestGen: 'rank': (2, 2), 'build_fcn': (build_fully_connected, TosaTensorGen.tgFullyConnected, None), 'qgen': TosaQuantGen.qgConv, - 'types': TYPE_FP }, + 'types': TYPE_NARROW_INT_FP }, 'matmul': { 'op': Op.MATMUL, @@ -2239,6 +2239,9 @@ class OutputShaper: else: raise Exception('Unsupported input dtype: {}'.format(ifm.dtype)) + if ifm.dtype == DType.INT16: + ser.setExpectedFailure(True, "INT16 support is in progress") + return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat) @staticmethod @@ -2269,6 +2272,9 @@ class OutputShaper: else: raise Exception('Unsupported input dtype: {}'.format(ifm.dtype)) + if ifm.dtype == DType.INT16: + ser.setExpectedFailure(True, "INT16 support is in progress") + return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat) @@ -2304,6 +2310,9 @@ class OutputShaper: else: raise Exception('Unsupported input dtype: {}'.format(input.dtype)) + if input.dtype == DType.INT16: + ser.setExpectedFailure(True, "INT16 support is in progress") + return ser.addOutput(output_shape, out_dtype, input.usage, input.dformat) @staticmethod @@ -2480,4 +2489,7 @@ class OutputShaper: if output_shape[1] <= 0 or output_shape[2] <= 0: ser.setExpectedFailure(True, 'Negative output shape') + if ifm.dtype == DType.INT16: + ser.setExpectedFailure(True, "INT16 support is in progress") + return ser.addOutput(output_shape, out_dtype, ifm.usage, ifm.dformat) |