From f7f78ae236e623a57919f9450e8b2043e681ddb3 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 25 May 2022 15:26:38 +0100 Subject: Add support for uint16_t to RESCALE Update ref-model RESCALE op to support UINT16 conversions Add testing for RESCALE UINT16 and ERROR_IFs Signed-off-by: Jeremy Johnson Change-Id: Ic6e6e53de1f0b054bedb9e6ba3856e7475498aba --- reference_model/src/ops/op_factory.cc | 4 ++++ reference_model/src/ops/template_types.h | 22 +++++++++++++++++++++- reference_model/src/ops/type_conversion.cc | 24 ++++++++++++++++++++---- reference_model/src/quant_util.h | 2 +- reference_model/src/tensor.cc | 3 +++ reference_model/src/tensor.h | 1 + 6 files changed, 50 insertions(+), 6 deletions(-) (limited to 'reference_model') diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 6edd63f..f7ded9a 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -396,7 +396,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16); break; // custom diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 0fe9a41..2bc7e04 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -// Shorter aliase templates for common Eigen::Tensor types +// Shorter alias templates for common Eigen::Tensor types template using ETensor0 = Eigen::Tensor; template @@ -89,6 +89,11 @@ struct GetEigenType using type = int32_t; }; template <> +struct GetEigenType +{ + using type = int32_t; +}; +template <> struct GetEigenType { using type = int32_t; @@ -121,6 +126,11 @@ struct GetNumBits static constexpr int32_t value = 8; }; template <> +struct GetNumBits +{ + static constexpr int32_t value = 16; +}; +template <> struct GetNumBits { static constexpr int32_t value = 4; @@ -158,6 +168,11 @@ struct GetQMin static constexpr int64_t value = 0L; }; template <> +struct GetQMin +{ + static constexpr int64_t value = 0L; +}; +template <> struct GetQMin { static constexpr int64_t value = -8L; @@ -194,6 +209,11 @@ struct GetQMax static constexpr int64_t value = 255L; }; template <> +struct GetQMax +{ + static constexpr int64_t value = 65535L; +}; +template <> struct GetQMax { static constexpr int64_t value = 7L; diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index e46ab38..7ee9692 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -64,15 +64,27 @@ int OpRescale::checkTensorAttributes() ASSERT_MEM(in && out); - if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0)) + if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0)) { - printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0"); + printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0)) + if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0)) { - printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0"); + printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0"); + return 1; + } + + if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) + { + printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768"); + return 1; + } + + if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) + { + printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768"); return 1; } @@ -329,4 +341,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16); diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 8c1b391..3b7674d 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -114,7 +114,7 @@ public: static bool is_integer(DType dtype) { if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_UINT8 || dtype == DType_INT16 || - dtype == DType_INT32 || dtype == DType_INT48) + dtype == DType_UINT16 || dtype == DType_INT32 || dtype == DType_INT48) { return true; } diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index f2a3a98..36ace48 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -102,6 +102,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) case DType_INT4: case DType_INT8: case DType_INT16: + case DType_UINT16: i32databuf = (int32_t*)calloc(sizeof(int32_t), elements); ASSERT_MEM(i32databuf); @@ -157,6 +158,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) case DType_INT4: case DType_INT8: case DType_INT16: + case DType_UINT16: if (setTensorValueInt32(elements, i32databuf)) { free(i32databuf); @@ -225,6 +227,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const case DType_INT4: case DType_INT8: case DType_INT16: + case DType_UINT16: i32databuf = (int32_t*)calloc(sizeof(int32_t), elements); ASSERT_MEM(i32databuf); diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index d857dc8..ede42a9 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -656,6 +656,7 @@ public: case DType_INT4: case DType_INT8: case DType_INT16: + case DType_UINT16: switch (rank) { case 0: -- cgit v1.2.1