aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/type_conversion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r--reference_model/src/ops/type_conversion.cc26
1 files changed, 25 insertions, 1 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 657eebf..e46ab38 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -30,7 +30,7 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_RESCALE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(0, 4);
INIT_ATTRIBUTE(Rescale);
}
@@ -64,6 +64,30 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
ASSERT_MEM(in && out);
+ if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0))
+ {
+ printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0");
+ return 1;
+ }
+
+ if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0))
+ {
+ printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0");
+ return 1;
+ }
+
+ if (attribute->scale32() && (InDtype == DType_INT48))
+ {
+ printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
+ return 1;
+ }
+
+ if ((!attribute->scale32()) && attribute->double_round())
+ {
+ printNodeValidationError("OpRescale: Scale set to false but double round set to true");
+ return 1;
+ }
+
return 0;
}