From 9c0a5075d9e184f6b92762b3bc903e021b700e65 Mon Sep 17 00:00:00 2001 From: Eric Kunze Date: Tue, 19 Mar 2024 21:14:07 +0000 Subject: Modify Rescale signedness check to look at attributes Also simplify the check to align the pesudo code structure. Signed-off-by: Eric Kunze Signed-off-by: TatWai Chong Change-Id: I6023046026d2784dedd963b2b4d34a1117d45c23 --- reference_model/src/ops/type_conversion.cc | 23 +++++++++++++++-------- reference_model/src/ops/type_conversion.h | 10 ++++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 835b656..85f8c58 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -82,29 +82,36 @@ int OpRescale::checkTensorAttributes() ASSERT_MEM(multiplierI16); } - if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) && - (attribute->input_zp() != 0)) + auto input_zp = attribute->input_zp(); + auto output_zp = attribute->output_zp(); + auto input_unsigned = attribute->input_unsigned(); + auto output_unsigned = attribute->output_unsigned(); + + // Note that how rescale op interprets signedness of the tensor depends on + // the value of input_unsigned and output_unsigned attributes, and doesn't + // care about the type of tensor itself. + + if (!isI8(InDtype) && (!isI16(InDtype) || input_unsigned == false) && (input_zp != 0)) { printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) && - (attribute->output_zp() != 0)) + if (!isI8(OutDtype) && (!isI16(OutDtype) || output_unsigned == false) && (output_zp != 0)) { printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) + if (isI16(InDtype) && (input_unsigned == true) && (input_zp != 0) && (input_zp != 32768)) { - printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768"); + printNodeValidationError("OpRescale: Input unsigned int16 and zero point not 0 or 32768"); return 1; } - if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) + if (isI16(OutDtype) && (output_unsigned == true) && (output_zp != 0) && (output_zp != 32768)) { - printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768"); + printNodeValidationError("OpRescale: Output unsigned int16 and zero point not 0 or 32768"); return 1; } diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index da5537e..cf95f16 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -43,6 +43,16 @@ public: using TMultiplierI32 = Eigen::Tensor; using TShift = Eigen::Tensor; + bool isI8(TOSA_REF_TYPE Dtype) + { + return Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8; + } + + bool isI16(TOSA_REF_TYPE Dtype) + { + return Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16; + } + protected: TosaRescaleAttribute* attribute; TosaReference::TensorTemplate* in; -- cgit v1.2.1