aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2024-03-19 21:14:07 +0000
committerEric Kunze <eric.kunze@arm.com>2024-04-08 17:06:08 +0000
commit9c0a5075d9e184f6b92762b3bc903e021b700e65 (patch)
tree7c8f2f69ccb2383bb5e99748a6a1ce73bc91cd1e
parentad8e1e25e805f6face5fcf0b3906cd06db46e1d7 (diff)
downloadreference_model-9c0a5075d9e184f6b92762b3bc903e021b700e65.tar.gz
Modify Rescale signedness check to look at attributes
Also simplify the check to align the pesudo code structure. Signed-off-by: Eric Kunze <eric.kunze@arm.com> Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: I6023046026d2784dedd963b2b4d34a1117d45c23
-rw-r--r--reference_model/src/ops/type_conversion.cc23
-rw-r--r--reference_model/src/ops/type_conversion.h10
2 files changed, 25 insertions, 8 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 835b656..85f8c58 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -82,29 +82,36 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
ASSERT_MEM(multiplierI16);
}
- if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
- (attribute->input_zp() != 0))
+ auto input_zp = attribute->input_zp();
+ auto output_zp = attribute->output_zp();
+ auto input_unsigned = attribute->input_unsigned();
+ auto output_unsigned = attribute->output_unsigned();
+
+ // Note that how rescale op interprets signedness of the tensor depends on
+ // the value of input_unsigned and output_unsigned attributes, and doesn't
+ // care about the type of tensor itself.
+
+ if (!isI8(InDtype) && (!isI16(InDtype) || input_unsigned == false) && (input_zp != 0))
{
printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
- (attribute->output_zp() != 0))
+ if (!isI8(OutDtype) && (!isI16(OutDtype) || output_unsigned == false) && (output_zp != 0))
{
printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
+ if (isI16(InDtype) && (input_unsigned == true) && (input_zp != 0) && (input_zp != 32768))
{
- printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
+ printNodeValidationError("OpRescale: Input unsigned int16 and zero point not 0 or 32768");
return 1;
}
- if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
+ if (isI16(OutDtype) && (output_unsigned == true) && (output_zp != 0) && (output_zp != 32768))
{
- printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
+ printNodeValidationError("OpRescale: Output unsigned int16 and zero point not 0 or 32768");
return 1;
}
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index da5537e..cf95f16 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -43,6 +43,16 @@ public:
using TMultiplierI32 = Eigen::Tensor<I32EigenType, 1>;
using TShift = Eigen::Tensor<I8EigenType, 1>;
+ bool isI8(TOSA_REF_TYPE Dtype)
+ {
+ return Dtype == TOSA_REF_TYPE_INT8 || Dtype == TOSA_REF_TYPE_UINT8;
+ }
+
+ bool isI16(TOSA_REF_TYPE Dtype)
+ {
+ return Dtype == TOSA_REF_TYPE_INT16 || Dtype == TOSA_REF_TYPE_UINT16;
+ }
+
protected:
TosaRescaleAttribute* attribute;
TosaReference::TensorTemplate<TIn>* in;