aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJared Smolens <jared.smolens@arm.com>2021-03-04 11:18:54 -0800
committerJared Smolens <jared.smolens@arm.com>2021-03-04 13:21:03 -0800
commit2a76ad2368f4684a8391fe69f51e52356524bf15 (patch)
treec05d26b964e2006c4c88d00facab17ee110f7f2c
parentdf8626976df6c779bb30df9c5ceef689462109c0 (diff)
downloadreference_model-2a76ad2368f4684a8391fe69f51e52356524bf15.tar.gz
Update DTypes for TOSA ops, test rig fixes
- Updated DTypes and expected failures for TOSA ops, particularly missing int8/int16 tests for Conv, FullyConnected, MatMul - Fixed a bug where unexpected failures were incorrectly categorized as passes Change-Id: I2763626317cedad9f3723f748986bb59a32f2e42 Signed-off-by: Jared Smolens <jared.smolens@arm.com>
-rw-r--r--verif/tosa_ref_run.py4
-rw-r--r--verif/tosa_test_gen.py24
2 files changed, 20 insertions, 8 deletions
diff --git a/verif/tosa_ref_run.py b/verif/tosa_ref_run.py
index 99f504b..2035147 100644
--- a/verif/tosa_ref_run.py
+++ b/verif/tosa_ref_run.py
@@ -1,6 +1,6 @@
import os
-# Copyright (c) 2020, ARM Limited.
+# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,6 +61,6 @@ class TosaRefRunner(TosaTestRunner):
if expectedFailure:
result = TosaTestRunner.Result.EXPECTED_FAILURE
else:
- result = TosaTestRunner.Result.EXPECTED_PASS
+ result = TosaTestRunner.Result.UNEXPECTED_FAILURE
return result
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 57030e7..ae1a5c6 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -50,7 +50,7 @@ class TosaQuantGen:
@staticmethod
def needsQinfo(op, dtype):
- if dtype == DType.INT8:
+ if dtype == DType.INT8 or dtype == DType.INT16:
return True
return False
@@ -1754,7 +1754,7 @@ class TosaTestGen:
{ 'op': Op.ARGMAX,
'operands': (1, 0),
'build_fcn': (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
- 'types': TYPE_FP },
+ 'types': TYPE_NARROW_INT_FP },
# Templated operator. Filled in by createDynamicOpLists
'conv2d_TEMPLATE':
@@ -1763,7 +1763,7 @@ class TosaTestGen:
'rank': (4, 4),
'build_fcn': (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
'qgen': TosaQuantGen.qgConv,
- 'types': TYPE_FP,
+ 'types': TYPE_NARROW_INT_FP,
'template': True },
# Templated operator. Filled in by createDynamicOpLists
@@ -1774,7 +1774,7 @@ class TosaTestGen:
'rank': (4, 4),
'build_fcn': (build_depthwise_conv2d, TosaTensorGen.tgDepthwiseConv2D, TosaArgGen.agConv2D),
'qgen': TosaQuantGen.qgConv,
- 'types': TYPE_FP,
+ 'types': TYPE_NARROW_INT_FP,
'template': True },
# Templated operator. Filled in by createDynamicOpLists
@@ -1784,7 +1784,7 @@ class TosaTestGen:
'rank': (4, 4),
'build_fcn': (build_transpose_conv2d, TosaTensorGen.tgTransposeConv2D, TosaArgGen.agTransposeConv2D),
'qgen': TosaQuantGen.qgConv,
- 'types': TYPE_FP,
+ 'types': TYPE_NARROW_INT_FP,
'template': True },
'fully_connected':
@@ -1793,7 +1793,7 @@ class TosaTestGen:
'rank': (2, 2),
'build_fcn': (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
'qgen': TosaQuantGen.qgConv,
- 'types': TYPE_FP },
+ 'types': TYPE_NARROW_INT_FP },
'matmul':
{ 'op': Op.MATMUL,
@@ -2239,6 +2239,9 @@ class OutputShaper:
else:
raise Exception('Unsupported input dtype: {}'.format(ifm.dtype))
+ if ifm.dtype == DType.INT16:
+ ser.setExpectedFailure(True, "INT16 support is in progress")
+
return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat)
@staticmethod
@@ -2269,6 +2272,9 @@ class OutputShaper:
else:
raise Exception('Unsupported input dtype: {}'.format(ifm.dtype))
+ if ifm.dtype == DType.INT16:
+ ser.setExpectedFailure(True, "INT16 support is in progress")
+
return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat)
@@ -2304,6 +2310,9 @@ class OutputShaper:
else:
raise Exception('Unsupported input dtype: {}'.format(input.dtype))
+ if input.dtype == DType.INT16:
+ ser.setExpectedFailure(True, "INT16 support is in progress")
+
return ser.addOutput(output_shape, out_dtype, input.usage, input.dformat)
@staticmethod
@@ -2480,4 +2489,7 @@ class OutputShaper:
if output_shape[1] <= 0 or output_shape[2] <= 0:
ser.setExpectedFailure(True, 'Negative output shape')
+ if ifm.dtype == DType.INT16:
+ ser.setExpectedFailure(True, "INT16 support is in progress")
+
return ser.addOutput(output_shape, out_dtype, ifm.usage, ifm.dformat)