From f7f78ae236e623a57919f9450e8b2043e681ddb3 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 25 May 2022 15:26:38 +0100 Subject: Add support for uint16_t to RESCALE Update ref-model RESCALE op to support UINT16 conversions Add testing for RESCALE UINT16 and ERROR_IFs Signed-off-by: Jeremy Johnson Change-Id: Ic6e6e53de1f0b054bedb9e6ba3856e7475498aba --- reference_model/src/ops/type_conversion.cc | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) (limited to 'reference_model/src/ops/type_conversion.cc') diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index e46ab38..7ee9692 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -64,15 +64,27 @@ int OpRescale::checkTensorAttributes() ASSERT_MEM(in && out); - if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0)) + if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0)) { - printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0"); + printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0)) + if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0)) { - printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0"); + printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0"); + return 1; + } + + if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) + { + printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768"); + return 1; + } + + if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) + { + printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768"); return 1; } @@ -329,4 +341,8 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, UINT16); -- cgit v1.2.1