aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/tosa_ref_run.py4
-rw-r--r--verif/tosa_test_gen.py24
2 files changed, 20 insertions, 8 deletions
diff --git a/verif/tosa_ref_run.py b/verif/tosa_ref_run.py
index 99f504b..2035147 100644
--- a/verif/tosa_ref_run.py
+++ b/verif/tosa_ref_run.py
@@ -1,6 +1,6 @@
import os
-# Copyright (c) 2020, ARM Limited.
+# Copyright (c) 2020-2021, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,6 +61,6 @@ class TosaRefRunner(TosaTestRunner):
if expectedFailure:
result = TosaTestRunner.Result.EXPECTED_FAILURE
else:
- result = TosaTestRunner.Result.EXPECTED_PASS
+ result = TosaTestRunner.Result.UNEXPECTED_FAILURE
return result
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 57030e7..ae1a5c6 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -50,7 +50,7 @@ class TosaQuantGen:
@staticmethod
def needsQinfo(op, dtype):
- if dtype == DType.INT8:
+ if dtype == DType.INT8 or dtype == DType.INT16:
return True
return False
@@ -1754,7 +1754,7 @@ class TosaTestGen:
{ 'op': Op.ARGMAX,
'operands': (1, 0),
'build_fcn': (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
- 'types': TYPE_FP },
+ 'types': TYPE_NARROW_INT_FP },
# Templated operator. Filled in by createDynamicOpLists
'conv2d_TEMPLATE':
@@ -1763,7 +1763,7 @@ class TosaTestGen:
'rank': (4, 4),
'build_fcn': (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
'qgen': TosaQuantGen.qgConv,
- 'types': TYPE_FP,
+ 'types': TYPE_NARROW_INT_FP,
'template': True },
# Templated operator. Filled in by createDynamicOpLists
@@ -1774,7 +1774,7 @@ class TosaTestGen:
'rank': (4, 4),
'build_fcn': (build_depthwise_conv2d, TosaTensorGen.tgDepthwiseConv2D, TosaArgGen.agConv2D),
'qgen': TosaQuantGen.qgConv,
- 'types': TYPE_FP,
+ 'types': TYPE_NARROW_INT_FP,
'template': True },
# Templated operator. Filled in by createDynamicOpLists
@@ -1784,7 +1784,7 @@ class TosaTestGen:
'rank': (4, 4),
'build_fcn': (build_transpose_conv2d, TosaTensorGen.tgTransposeConv2D, TosaArgGen.agTransposeConv2D),
'qgen': TosaQuantGen.qgConv,
- 'types': TYPE_FP,
+ 'types': TYPE_NARROW_INT_FP,
'template': True },
'fully_connected':
@@ -1793,7 +1793,7 @@ class TosaTestGen:
'rank': (2, 2),
'build_fcn': (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
'qgen': TosaQuantGen.qgConv,
- 'types': TYPE_FP },
+ 'types': TYPE_NARROW_INT_FP },
'matmul':
{ 'op': Op.MATMUL,
@@ -2239,6 +2239,9 @@ class OutputShaper:
else:
raise Exception('Unsupported input dtype: {}'.format(ifm.dtype))
+ if ifm.dtype == DType.INT16:
+ ser.setExpectedFailure(True, "INT16 support is in progress")
+
return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat)
@staticmethod
@@ -2269,6 +2272,9 @@ class OutputShaper:
else:
raise Exception('Unsupported input dtype: {}'.format(ifm.dtype))
+ if ifm.dtype == DType.INT16:
+ ser.setExpectedFailure(True, "INT16 support is in progress")
+
return ser.addOutput(ofm_shape, out_dtype, ifm.usage, ifm.dformat)
@@ -2304,6 +2310,9 @@ class OutputShaper:
else:
raise Exception('Unsupported input dtype: {}'.format(input.dtype))
+ if input.dtype == DType.INT16:
+ ser.setExpectedFailure(True, "INT16 support is in progress")
+
return ser.addOutput(output_shape, out_dtype, input.usage, input.dformat)
@staticmethod
@@ -2480,4 +2489,7 @@ class OutputShaper:
if output_shape[1] <= 0 or output_shape[2] <= 0:
ser.setExpectedFailure(True, 'Negative output shape')
+ if ifm.dtype == DType.INT16:
+ ser.setExpectedFailure(True, "INT16 support is in progress")
+
return ser.addOutput(output_shape, out_dtype, ifm.usage, ifm.dformat)