aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--reference_model/src/ops/activation_funcs.cc8
-rw-r--r--reference_model/src/ops/data_layout.cc93
-rw-r--r--reference_model/src/ops/data_layout.h1
-rw-r--r--reference_model/src/ops/ewise_binary.cc20
-rw-r--r--reference_model/src/ops/tensor_ops.cc156
-rw-r--r--reference_model/src/ops/type_conversion.cc26
6 files changed, 245 insertions, 59 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 21677d5..c344bcb 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -25,14 +25,15 @@ using namespace tosa;
template <int Rank, DType Dtype>
int OpClamp<Rank, Dtype>::register_fcn()
{
-
switch (Dtype)
{
case DType_FLOAT:
{
InEigenType min = (InEigenType)attribute->min_fp();
InEigenType max = (InEigenType)attribute->max_fp();
- this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ ERROR_IF(max < min, "OpClamp: max smaller than min");
+
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
}
break;
case DType_INT8:
@@ -40,7 +41,8 @@ int OpClamp<Rank, Dtype>::register_fcn()
{
InEigenType min = (InEigenType)attribute->min_int();
InEigenType max = (InEigenType)attribute->max_int();
- this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ ERROR_IF(max < min, "OpClamp: max smaller than min");
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
}
break;
default:
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 86326f5..f3e80f3 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -51,25 +51,49 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes()
printNodeValidationError("Concat operator must have at least one input tensor");
return 1;
}
+
+ int32_t num_inputs = inputs.size();
+
// output and input must be the same types and rank
- for (size_t i = 0; i < inputs.size(); i++)
+ for (int32_t i = 0; i < num_inputs; i++)
{
if (inputs[i]->matchRankType(*outputs[0]))
{
- printNodeValidationError("Concat operator input ranks and types must match");
+ printNodeValidationError("OpConcat: input ranks and types must match");
return 1;
}
ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
}
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
-
- if (attribute->axis() < 0 || (size_t)attribute->axis() >= inputs[0]->getShape().size())
+ if (attribute->axis() < 0 || (size_t)attribute->axis() >= Rank)
{
- printNodeValidationError("Axis is beyond input tensor rank");
+ printNodeValidationError("OpConcat: axis is beyond output tensor rank");
return 1;
}
+ int32_t output_dim_on_axis = 0;
+ for (int32_t j = 0; j < num_inputs; j++)
+ {
+ for (int32_t i = 0; i < Rank; i++)
+ {
+ int32_t input_dim = inputs[j]->getShape()[i];
+ if (i == attribute->axis())
+ {
+ output_dim_on_axis += input_dim;
+ }
+ else if (input_dim != outputs[0]->getShape()[i])
+ {
+ printNodeValidationError("OpConcat: input dimension not matching output dimension");
+ return 1;
+ }
+ }
+ }
+
+ ERROR_IF(output_dim_on_axis == outputs[0]->getShape()[attribute->axis()],
+ "OpConcat: sum of input dimension on axis not equal to output dimension on axis");
+
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
return 0;
}
@@ -135,14 +159,13 @@ int OpPad<Rank, Dtype>::checkTensorAttributes()
return 1;
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- TosaReference::TensorTemplate<ETensor2<int32_t>>* paddings =
- dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ paddings = dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
- for (int i = 0; i < Rank; i++)
+ if (this->qinfo && Dtype != DType_INT8)
{
- paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpPad: zeropoint should be 0");
}
return 0;
@@ -151,6 +174,14 @@ int OpPad<Rank, Dtype>::checkTensorAttributes()
template <int Rank, DType Dtype>
int OpPad<Rank, Dtype>::eval()
{
+ // Move this to
+ for (int i = 0; i < Rank; i++)
+ {
+ ERROR_IF((paddings->getTensor()(i, 0) < 0) || (paddings->getTensor()(i, 1) < 0),
+ "OpPad: padding can't be smaller than 0");
+ paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+ }
+
InEigenType pad_value = 0;
if (this->qinfo)
{
@@ -202,12 +233,20 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 1;
}
+ ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
+ "Input tensor size does not match output tensor size");
+
for (uint32_t d = 0; d < OutRank; d++)
{
if (attribute->shape()[d] == -1)
{
minusOneCount++;
}
+ else
+ {
+ ERROR_IF(attribute->shape()[d] != outputs[0]->getShape()[d],
+ "OpReshape: new_shape doesn't match output shape");
+ }
}
if (minusOneCount > 1)
@@ -358,7 +397,7 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_SLICE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 4);
INIT_ATTRIBUTE(Slice);
}
@@ -391,23 +430,20 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- for (size_t i = 0; i < attribute->begin().size(); i++)
- {
- begin_array[i] = attribute->begin()[i];
- }
+ ERROR_IF((int32_t)attribute->begin().size() != in->getRank(),
+ "OpSlice: begin array length needs to be rank(input)");
+ ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
- for (size_t i = 0; i < attribute->size().size(); i++)
+ for (int32_t i = 0; i < in->getRank(); i++)
{
- if (attribute->size()[i] != 0)
- {
- size_array[i] = attribute->size()[i];
- }
- else
- {
- // Tensorflow assigns a zero size to dimensions that are kept
- // Eigen expects size to be the full size of the dimension
- size_array[i] = in->getTensor().dimension(0);
- }
+ int32_t b = attribute->begin()[i];
+ int32_t s = attribute->size()[i];
+ ERROR_IF(b < 0 || b >= in->getShape()[i], "OpSlice: start out of boundary");
+ ERROR_IF((b + s) < 0 || (b + s) > in->getShape()[i], "OpSlice: (start+size) out of boundary");
+ ERROR_IF(s <= 0, "OpSlice: output must be positive");
+ ERROR_IF(s != out->getShape()[i], "OpSlice: size doesn't match output tensor dimension");
+ begin_array[i] = b;
+ size_array[i] = s;
}
return 0;
@@ -611,6 +647,7 @@ int OpTranspose<Rank, Dtype>::eval()
for (int32_t d = 0; d < Rank; d++)
{
perm_array[d] = this->perm_tensor->getTensor().data()[d];
+ ERROR_IF(perm_array[d] < 0 or perm_array[d] >= Rank, "OpTranspose: index out of boundary");
}
out->getTensor() = in->getTensor().shuffle(perm_array);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index c9c2602..9f44fc7 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -63,6 +63,7 @@ protected:
Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
+ TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 2>>* paddings;
TosaPadQuantInfo* qinfo;
};
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 023158c..6808604 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -60,26 +60,16 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- // In some ops, only rank of input and output tensor needs to match
- if (nodeType == Op_MUL || nodeType == Op_GREATER || nodeType == Op_EQUAL || nodeType == Op_GREATER_EQUAL)
- {
- if (inputs[0]->matchRank(*outputs[0]))
- {
- std::string err =
- "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
- printNodeValidationError(err.c_str());
- return 1;
- }
- }
- // Otherwise both rand/type of input and output must match
- else if (inputs[0]->matchRankType(*outputs[0]))
+ if (inputs[0]->matchRank(*outputs[0]))
{
std::string err =
- "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank and type must match";
+ "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match";
printNodeValidationError(err.c_str());
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != OutDtype, "Binary operator type doesn't match");
+
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -532,6 +522,7 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
printNodeValidationError("OpTable: Table must be INT8[256] if input is INT8");
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != DType_INT8, "OpTable: output tensor must be INT8");
}
else if (inputs[0]->getDtype() == DType_INT16)
{
@@ -540,6 +531,7 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
printNodeValidationError("OpTable: Table must be INT16[513] if input is INT16");
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != DType_INT32, "OpTable: output tensor must be INT32");
}
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 118d048..be4e4aa 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -115,7 +115,7 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_ARGMAX, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 4);
INIT_ATTRIBUTE(Axis);
}
@@ -133,14 +133,60 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ if (validateRequiredRank(inputs[0]))
+ {
+ return 1;
+ }
+
+ int32_t output_rank = inputs[0]->getRank() - 1;
+ if (output_rank != outputs[0]->getRank())
{
+ printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
+ return 1;
+ }
+
+ if (outputs[0]->getDtype() != DType_INT32)
+ {
+ printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
return 1;
}
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
+ {
+ printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]");
+ return 1;
+ }
+
+ bool shape_check = true;
+ for (int32_t i = 0; i < input->getRank(); i++)
+ {
+ if (i < attribute->axis())
+ {
+ if (input->getShape()[i] != output->getShape()[i])
+ {
+ shape_check = false;
+ break;
+ }
+ }
+ else if (i > attribute->axis())
+ {
+ if (input->getShape()[i] != output->getShape()[i - 1])
+ {
+ shape_check = false;
+ break;
+ }
+ }
+ // No need to check i == axis
+ }
+ if (!shape_check)
+ {
+ printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
+ return 1;
+ }
+
return 0;
}
@@ -411,6 +457,9 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
printNodeValidationError("OpConv2d: bias tensor must be rank 1");
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -434,6 +483,18 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
@@ -603,6 +664,9 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
printNodeValidationError("OpConv3d: bias tensor must be rank 1");
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -626,6 +690,18 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
@@ -798,6 +874,9 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -821,6 +900,18 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
@@ -987,8 +1078,23 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
@@ -1059,6 +1165,9 @@ int OpMatMul<Dtype>::checkTensorAttributes()
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
@@ -1101,6 +1210,12 @@ int OpMatMul<Dtype>::checkTensorAttributes()
}
W = b->getShape()[2];
+ if (Dtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t");
+ ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t");
+ }
+
return 0;
}
@@ -1291,11 +1406,11 @@ int OpMaxPool2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
- uint64_t id_)
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1305,8 +1420,8 @@ OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
INIT_QINFO(Conv);
}
-template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
@@ -1314,8 +1429,8 @@ OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d()
delete qinfo;
}
-template <DType InDtype, DType OutDtype>
-int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1325,6 +1440,9 @@ int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -1363,11 +1481,23 @@ int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
}
}
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
-template <DType InDtype, DType OutDtype>
-int OpTransposeConv2d<InDtype, OutDtype>::eval()
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 657eebf..e46ab38 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -30,7 +30,7 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_RESCALE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(0, 4);
INIT_ATTRIBUTE(Rescale);
}
@@ -64,6 +64,30 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
ASSERT_MEM(in && out);
+ if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (attribute->input_zp() != 0))
+ {
+ printNodeValidationError("OpRescale: Input DType not INT8/UINT8 and zero point not 0");
+ return 1;
+ }
+
+ if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (attribute->output_zp() != 0))
+ {
+ printNodeValidationError("OpRescale: Output DType not INT8/UINT8 and zero point not 0");
+ return 1;
+ }
+
+ if (attribute->scale32() && (InDtype == DType_INT48))
+ {
+ printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
+ return 1;
+ }
+
+ if ((!attribute->scale32()) && attribute->double_round())
+ {
+ printNodeValidationError("OpRescale: Scale set to false but double round set to true");
+ return 1;
+ }
+
return 0;
}