aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-05-25 15:26:38 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2022-05-26 09:53:44 +0100
commitf7f78ae236e623a57919f9450e8b2043e681ddb3 (patch)
tree0456c0006fbce5efbebe93818398b3a0cd7cd76c
parent0e6218e22f25901aa208fbec44c9b14e14a68ba7 (diff)
downloadreference_model-f7f78ae236e623a57919f9450e8b2043e681ddb3.tar.gz
Add support for uint16_t to RESCALE
Update ref-model RESCALE op to support UINT16 conversions Add testing for RESCALE UINT16 and ERROR_IFs Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ic6e6e53de1f0b054bedb9e6ba3856e7475498aba
-rw-r--r--reference_model/src/ops/op_factory.cc4
-rw-r--r--reference_model/src/ops/template_types.h22
-rw-r--r--reference_model/src/ops/type_conversion.cc24
-rw-r--r--reference_model/src/quant_util.h2
-rw-r--r--reference_model/src/tensor.cc3
-rw-r--r--reference_model/src/tensor.h1
m---------thirdparty/serialization_lib0
-rw-r--r--verif/generator/tosa_arg_gen.py49
-rw-r--r--verif/generator/tosa_error_if.py138
-rw-r--r--verif/generator/tosa_test_gen.py49
-rw-r--r--verif/generator/tosa_utils.py10
11 files changed, 225 insertions, 77 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 6edd63f..f7ded9a 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -396,7 +396,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);
break;
// custom
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 0fe9a41..2bc7e04 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-// Shorter aliase templates for common Eigen::Tensor types
+// Shorter alias templates for common Eigen::Tensor types
template <typename T>
using ETensor0 = Eigen::Tensor<T, 0>;
template <typename T>
@@ -89,6 +89,11 @@ struct GetEigenType<DType_UINT8>
using type = int32_t;
};
template <>
+struct GetEigenType<DType_UINT16>
+{
+ using type = int32_t;
+};
+template <>
struct GetEigenType<DType_INT4>
{
using type = int32_t;
@@ -121,6 +126,11 @@ struct GetNumBits<DType_UINT8>
static constexpr int32_t value = 8;
};
template <>
+struct GetNumBits<DType_UINT16>
+{
+ static constexpr int32_t value = 16;
+};
+template <>
struct GetNumBits<DType_INT4>
{
static constexpr int32_t value = 4;
@@ -158,6 +168,11 @@ struct GetQMin<DType_UINT8>
static constexpr int64_t value = 0L;
};
template <>
+struct GetQMin<DType_UINT16>
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
struct GetQMin<DType_INT4>
{
static constexpr int64_t value = -8L;
@@ -194,6 +209,11 @@ struct GetQMax<DType_UINT8>
static constexpr int64_t value = 255L;
};
template <>
+struct GetQMax<DType_UINT16>
+{
+ static constexpr int64_t value = 65535L;
+};
+template <>
struct GetQMax<DType_INT4>
{
static constexpr int64_t value = 7L;
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index e46ab38..7ee9692 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -64,15 +64,27 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
ASSERT_MEM(in && out);
- if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0))
+ if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0))
{
- printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0");
+ printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0))
+ if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0))
{
- printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0");
+ printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0");
+ return 1;
+ }
+
+ if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
+ {
+ printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768");
+ return 1;
+ }
+
+ if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
+ {
+ printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768");
return 1;
}
@@ -329,4 +341,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 8c1b391..3b7674d 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -114,7 +114,7 @@ public:
static bool is_integer(DType dtype)
{
if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 ||
- dtype == DType_INT32 || dtype == DType_INT48)
+ dtype == DType_UINT16 || dtype == DType_INT32 || dtype == DType_INT48)
{
return true;
}
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index f2a3a98..36ace48 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -102,6 +102,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
case DType_INT4:
case DType_INT8:
case DType_INT16:
+ case DType_UINT16:
i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
ASSERT_MEM(i32databuf);
@@ -157,6 +158,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
case DType_INT4:
case DType_INT8:
case DType_INT16:
+ case DType_UINT16:
if (setTensorValueInt32(elements, i32databuf))
{
free(i32databuf);
@@ -225,6 +227,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
case DType_INT4:
case DType_INT8:
case DType_INT16:
+ case DType_UINT16:
i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
ASSERT_MEM(i32databuf);
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index d857dc8..ede42a9 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -656,6 +656,7 @@ public:
case DType_INT4:
case DType_INT8:
case DType_INT16:
+ case DType_UINT16:
switch (rank)
{
case 0:
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 9b22517ba0cd6f767123583ce56e864f50e9d75
+Subproject 4102773d83e236448130b43b1747621ace00160
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index b1f8942..a741efb 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1349,29 +1349,58 @@ class TosaArgGen:
arg_list = []
# Enumerate the output types here
- for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ for outDtype in [
+ DType.UINT8,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.UINT16,
+ ]:
if (
- dtype in [DType.UINT8, DType.INT8]
+ outDtype in [DType.UINT8, DType.INT8, DType.UINT16]
and error_name == ErrorIf.OutputZeroPointNotZero
):
continue
if (
+ outDtype != DType.UINT16
+ and error_name == ErrorIf.U16OutputZeroPointNotValid
+ ) or (
+ inDtype != DType.UINT16
+ and error_name == ErrorIf.U16InputZeroPointNotValid
+ ):
+ # ErrorIfs only valid with UINT16
+ continue
+ if (
inDtype == DType.UINT8
- and dtype != DType.INT8
+ and outDtype not in [DType.INT8, DType.INT16]
+ and error_name != ErrorIf.WrongOutputType
+ ):
+ # The only output dtypes for UINT8 are INT8/INT16, skip all others
+ continue
+ if (
+ inDtype not in [DType.INT8, DType.INT16]
+ and outDtype == DType.UINT8
+ and error_name != ErrorIf.WrongOutputType
+ ):
+ # The only input dtypes for UINT8 are INT8/INT16, skip all others
+ continue
+ if (
+ inDtype == DType.UINT16
+ and outDtype != DType.INT16
and error_name != ErrorIf.WrongOutputType
):
- # The only output dtype for UINT8 is INT8, skip all other combinations
+ # The only output dtype for UINT16 is INT16, skip all others
continue
if (
- inDtype != DType.INT8
- and dtype == DType.UINT8
+ inDtype != DType.INT16
+ and outDtype == DType.UINT16
and error_name != ErrorIf.WrongOutputType
):
- # The only input dtype for UINT8 is INT8, skip all other combinations
+ # The only input dtype for UINT16 is INT16, skip all others
continue
if (
error_name == ErrorIf.WrongOutputType
- and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype)
+ and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, outDtype)
):
continue
@@ -1403,12 +1432,12 @@ class TosaArgGen:
arg_list.append(
(
"out{}_sc{}_dr{}_pc{}".format(
- DTypeNames[dtype],
+ DTypeNames[outDtype],
int(scale32),
int(double_round),
int(per_channel),
),
- [dtype, scale32, double_round, per_channel],
+ [outDtype, scale32, double_round, per_channel],
)
)
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e7e758f..1900d8a 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -68,6 +68,8 @@ class ErrorIf(object):
InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
+ U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
+ U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
class TosaErrorIfArgGen:
@@ -227,14 +229,26 @@ class TosaErrorIfArgGen:
if input_dtype == DType.INT8:
if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
return True
- if input_dtype in [DType.INT16, DType.INT32]:
+ elif input_dtype == DType.INT16:
+ if output_dtype not in [
+ DType.UINT8,
+ DType.INT8,
+ DType.UINT16,
+ DType.INT16,
+ DType.INT32,
+ ]:
+ return True
+ elif input_dtype == DType.INT32:
if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
return True
elif input_dtype == DType.INT48:
if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
return True
elif input_dtype == DType.UINT8:
- if output_dtype != DType.INT8:
+ if output_dtype not in [DType.INT8, DType.INT16]:
+ return True
+ elif input_dtype == DType.UINT16:
+ if output_dtype != DType.INT16:
return True
return False
@@ -418,23 +432,9 @@ class TosaErrorValidator:
error_result = True
elif op["op"] == Op.RESCALE:
- if input_dtype == DType.INT8:
- if output_dtype not in [
- DType.UINT8,
- DType.INT8,
- DType.INT16,
- DType.INT32,
- ]:
- error_result = True
- if input_dtype in [DType.INT16, DType.INT32]:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- error_result = True
- elif input_dtype == DType.INT48:
- if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
- error_result = True
- elif input_dtype == DType.UINT8:
- if output_dtype != DType.INT8:
- error_result = True
+ error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
+ input_dtype, output_dtype
+ )
elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
if (
@@ -998,12 +998,25 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def _getZeroPoint(qinfo, index):
+ """Return zero point value from quantization info.
+
+ Generally input_zp is index 0, output_zp is index 1
+ """
+ if isinstance(qinfo, tuple):
+ zero_point = qinfo[index]
+ else:
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ zero_point = qinfo.ints[index][1]
+ return zero_point
+
+ @staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
op = kwargs["op"]
error_result = False
# Quantizable types
- qTypes = (DType.INT8, DType.UINT8)
+ qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
# This does not apply to quantizable types
inputDtypes = [
@@ -1015,19 +1028,12 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs["input_dtype"]
- if isinstance(kwargs["qinfo"], tuple):
- qinfo = kwargs["qinfo"]
- input_zero_point = qinfo[0]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs["qinfo"].ints
- input_zero_point = qinfo[0][1]
-
+ input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
if op["op"] == Op.MATMUL:
- qinfo = kwargs["qinfo"].ints
+ input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
for dtype, zp in (
- (kwargs["input_dtype"], qinfo[0][1]),
- (kwargs["input2_dtype"], qinfo[1][1]),
+ (kwargs["input_dtype"], input_zero_point),
+ (kwargs["input2_dtype"], input2_zero_point),
):
if dtype not in qTypes and zp != 0:
error_result = True
@@ -1059,9 +1065,7 @@ class TosaErrorValidator:
if check:
weight_dtype = kwargs["weight_dtype"]
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
- qinfo = kwargs["qinfo"].ints
- weight_zero_point = qinfo[1][1]
+ weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
if weight_dtype != DType.INT8 and weight_zero_point != 0:
error_result = True
@@ -1076,11 +1080,9 @@ class TosaErrorValidator:
@staticmethod
def evOutputZeroPointNotZero(check=False, **kwargs):
op = kwargs["op"]
- inputDtypes = op["types"].copy()
- if DType.INT8 in inputDtypes:
- inputDtypes.remove(DType.INT8)
- if DType.UINT8 in inputDtypes:
- inputDtypes.remove(DType.UINT8)
+ inputDtypes = [
+ t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
+ ]
error_name = ErrorIf.OutputZeroPointNotZero
param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
@@ -1090,18 +1092,13 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs["input_dtype"]
output_dtype = kwargs["output_dtype"]
- if isinstance(kwargs["qinfo"], tuple):
- qinfo = kwargs["qinfo"]
- output_zero_point = qinfo[1]
- else:
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs["qinfo"].ints
- output_zero_point = qinfo[1][1]
+ output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
if op["op"] == Op.AVG_POOL2D:
if input_dtype != DType.INT8 and output_zero_point != 0:
error_result = True
elif (
- output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
+ output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
+ and output_zero_point != 0
):
error_result = True
@@ -1114,6 +1111,53 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evU16InputZeroPointNotValid(check=False, **kwargs):
+ error_name = ErrorIf.U16InputZeroPointNotValid
+ param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
+ error_result = False
+ error_reason = "Input DType is UINT16 and zero point not 0 or 32678"
+
+ if check:
+ input_dtype = kwargs["input_dtype"]
+ input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
+ error_result = input_dtype == DType.UINT16 and input_zero_point not in [
+ 0,
+ 32768,
+ ]
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evU16OutputZeroPointNotValid(check=False, **kwargs):
+ error_name = ErrorIf.U16OutputZeroPointNotValid
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Output DType is UINT16 and zero point not 0 or 32678"
+
+ if check:
+ output_dtype = kwargs["output_dtype"]
+ output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
+
+ error_result = output_dtype == DType.UINT16 and output_zero_point not in [
+ 0,
+ 32768,
+ ]
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
def evAxisSmallerZero(check=False, **kwargs):
error_name = ErrorIf.AxisSmallerZero
param_reqs = {"rank": None, "dtype": None, "shape": None}
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7c2b9de..c9c6d7e 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -70,6 +70,8 @@ class TosaTestGen:
return np.int32(self.rng.integers(low=0, high=256, size=shape))
elif dtype == DType.INT16:
return np.int32(self.rng.integers(low=-32768, high=32768, size=shape))
+ elif dtype == DType.UINT16:
+ return np.int32(self.rng.integers(low=0, high=65536, size=shape))
elif dtype == DType.INT32:
return np.int32(
self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape)
@@ -169,6 +171,8 @@ class TosaTestGen:
return "u8"
elif t == DType.INT16:
return "i16"
+ elif t == DType.UINT16:
+ return "u16"
elif t == DType.INT32:
return "i32"
elif t == DType.INT48:
@@ -188,6 +192,8 @@ class TosaTestGen:
return 8
elif t == DType.INT16:
return 16
+ elif t == DType.UINT16:
+ return 16
elif t == DType.INT32:
return 32
elif t == DType.INT48:
@@ -1575,29 +1581,43 @@ class TosaTestGen:
if val.dtype == DType.INT8:
input_zp = self.randInt(-128, 128)
- in_type_width = in_type_width + 1
+ in_type_width += 1
elif val.dtype == DType.UINT8:
input_zp = self.randInt(0, 256)
- in_type_width = in_type_width + 1
- elif error_name == ErrorIf.InputZeroPointNotZero:
+ in_type_width += 1
+ elif error_name in [
+ ErrorIf.InputZeroPointNotZero,
+ ErrorIf.U16InputZeroPointNotValid,
+ ]:
input_zp = self.randInt(-128, 128)
if input_zp == 0:
input_zp = input_zp + self.rng.integers(1, 10)
- in_type_width = in_type_width + 1
+ in_type_width += 1
+ elif val.dtype == DType.UINT16:
+ # Must come after ErrorIf.U16InputZeroPointNotValid check
+ input_zp = self.rng.choice([0, 32768])
+ in_type_width += 1
else:
input_zp = 0
if out_dtype == DType.INT8:
output_zp = self.randInt(-128, 128)
- out_type_width = out_type_width + 1
+ out_type_width += 1
elif out_dtype == DType.UINT8:
output_zp = self.randInt(0, 256)
- out_type_width = out_type_width + 1
- elif error_name == ErrorIf.OutputZeroPointNotZero:
+ out_type_width += 1
+ elif error_name in [
+ ErrorIf.OutputZeroPointNotZero,
+ ErrorIf.U16OutputZeroPointNotValid,
+ ]:
output_zp = self.randInt(-128, 128)
if output_zp == 0:
output_zp = output_zp + self.rng.integers(1, 10)
- out_type_width = out_type_width + 1
+ out_type_width += 1
+ elif out_dtype == DType.UINT16:
+ # Must come after ErrorIf.U16OutputZeroPointNotValid check
+ output_zp = self.rng.choice([0, 32768])
+ out_type_width += 1
else:
output_zp = 0
@@ -1631,7 +1651,7 @@ class TosaTestGen:
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
if scale32 and error_name is None:
- # Make sure random values are within apply_scale_32 speicification
+ # Make sure random values are within apply_scale_32 specification
# REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2))
assert val.placeholderFilename
values = np.load(
@@ -3642,10 +3662,19 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agRescale,
),
- "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
+ "types": [
+ DType.UINT8,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.UINT16,
+ ],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evOutputZeroPointNotZero,
+ TosaErrorValidator.evU16InputZeroPointNotValid,
+ TosaErrorValidator.evU16OutputZeroPointNotValid,
TosaErrorValidator.evScaleTrue,
TosaErrorValidator.evScaleNotTrue,
TosaErrorValidator.evWrongInputType,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index ca115a2..a4ef31a 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -59,9 +59,11 @@ def allDTypes(*, excludes=None):
def usableDTypes(*, excludes=None):
"""Get a set of usable DType values, optionally excluding some values.
- Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
- specified by the caller, as the serializer lib does not support them.
- If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.
+ Excludes uncommon types (DType.UNKNOWN, DType.UINT16, DType.UINT8) in
+ addition to the excludes specified by the caller, as the serializer lib
+ does not support them.
+ If you wish to include 'UNKNOWN', 'UINT8' or 'UINT16' use allDTypes
+ instead.
Args:
excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])
@@ -69,7 +71,7 @@ def usableDTypes(*, excludes=None):
Returns:
A set of DType values
"""
- omit = {DType.UNKNOWN, DType.UINT8}
+ omit = {DType.UNKNOWN, DType.UINT8, DType.UINT16}
omit.update(excludes if excludes else ())
return allDTypes(excludes=omit)