aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src')
-rw-r--r--reference_model/src/ops/op_factory.cc4
-rw-r--r--reference_model/src/ops/template_types.h22
-rw-r--r--reference_model/src/ops/type_conversion.cc24
-rw-r--r--reference_model/src/quant_util.h2
-rw-r--r--reference_model/src/tensor.cc3
-rw-r--r--reference_model/src/tensor.h1
6 files changed, 50 insertions, 6 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 6edd63f..f7ded9a 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -396,7 +396,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);
break;
// custom
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 0fe9a41..2bc7e04 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-// Shorter aliase templates for common Eigen::Tensor types
+// Shorter alias templates for common Eigen::Tensor types
template <typename T>
using ETensor0 = Eigen::Tensor<T, 0>;
template <typename T>
@@ -89,6 +89,11 @@ struct GetEigenType<DType_UINT8>
using type = int32_t;
};
template <>
+struct GetEigenType<DType_UINT16>
+{
+ using type = int32_t;
+};
+template <>
struct GetEigenType<DType_INT4>
{
using type = int32_t;
@@ -121,6 +126,11 @@ struct GetNumBits<DType_UINT8>
static constexpr int32_t value = 8;
};
template <>
+struct GetNumBits<DType_UINT16>
+{
+ static constexpr int32_t value = 16;
+};
+template <>
struct GetNumBits<DType_INT4>
{
static constexpr int32_t value = 4;
@@ -158,6 +168,11 @@ struct GetQMin<DType_UINT8>
static constexpr int64_t value = 0L;
};
template <>
+struct GetQMin<DType_UINT16>
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
struct GetQMin<DType_INT4>
{
static constexpr int64_t value = -8L;
@@ -194,6 +209,11 @@ struct GetQMax<DType_UINT8>
static constexpr int64_t value = 255L;
};
template <>
+struct GetQMax<DType_UINT16>
+{
+ static constexpr int64_t value = 65535L;
+};
+template <>
struct GetQMax<DType_INT4>
{
static constexpr int64_t value = 7L;
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index e46ab38..7ee9692 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -64,15 +64,27 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
ASSERT_MEM(in && out);
- if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0))
+ if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0))
{
- printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0");
+ printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0))
+ if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0))
{
- printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0");
+ printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0");
+ return 1;
+ }
+
+ if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
+ {
+ printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768");
+ return 1;
+ }
+
+ if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
+ {
+ printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768");
return 1;
}
@@ -329,4 +341,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16);
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 8c1b391..3b7674d 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -114,7 +114,7 @@ public:
static bool is_integer(DType dtype)
{
if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 ||
- dtype == DType_INT32 || dtype == DType_INT48)
+ dtype == DType_UINT16 || dtype == DType_INT32 || dtype == DType_INT48)
{
return true;
}
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index f2a3a98..36ace48 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -102,6 +102,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
case DType_INT4:
case DType_INT8:
case DType_INT16:
+ case DType_UINT16:
i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
ASSERT_MEM(i32databuf);
@@ -157,6 +158,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
case DType_INT4:
case DType_INT8:
case DType_INT16:
+ case DType_UINT16:
if (setTensorValueInt32(elements, i32databuf))
{
free(i32databuf);
@@ -225,6 +227,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
case DType_INT4:
case DType_INT8:
case DType_INT16:
+ case DType_UINT16:
i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
ASSERT_MEM(i32databuf);
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index d857dc8..ede42a9 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -656,6 +656,7 @@ public:
case DType_INT4:
case DType_INT8:
case DType_INT16:
+ case DType_UINT16:
switch (rank)
{
case 0: