Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
 reference_model/src/ops/tensor_ops.cc | 156 ++++++++++++++++++++++++++-----
 1 file changed, 143 insertions(+), 13 deletions(-)
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 118d048..be4e4aa 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -115,7 +115,7 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_ARGMAX, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 4);
INIT_ATTRIBUTE(Axis);
}
@@ -133,14 +133,60 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ if (validateRequiredRank(inputs[0]))
+ {
+ return 1;
+ }
+
+ int32_t output_rank = inputs[0]->getRank() - 1;
+ if (output_rank != outputs[0]->getRank())
{
+ printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1");
+ return 1;
+ }
+
+ if (outputs[0]->getDtype() != DType_INT32)
+ {
+ printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
return 1;
}
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ if (attribute->axis() < 0 || attribute->axis() >= input->getRank())
+ {
+ printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input))");
+ return 1;
+ }
+
+ bool shape_check = true;
+ for (int32_t i = 0; i < input->getRank(); i++)
+ {
+ if (i < attribute->axis())
+ {
+ if (input->getShape()[i] != output->getShape()[i])
+ {
+ shape_check = false;
+ break;
+ }
+ }
+ else if (i > attribute->axis())
+ {
+ if (input->getShape()[i] != output->getShape()[i - 1])
+ {
+ shape_check = false;
+ break;
+ }
+ }
+ // No need to check i == axis
+ }
+ if (!shape_check)
+ {
+ printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape");
+ return 1;
+ }
+
return 0;
}
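
Note on the shape check added above: it is equivalent to requiring that the output shape equal the input shape with the axis dimension removed. A minimal standalone sketch of that rule (the helper below is illustrative and not part of the reference model):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Illustrative helper: the expected ARGMAX output shape is the input
    // shape with the reduced axis removed.
    std::vector<int32_t> argmaxOutputShape(const std::vector<int32_t>& in_shape, int32_t axis)
    {
        assert(axis >= 0 && axis < static_cast<int32_t>(in_shape.size()));
        std::vector<int32_t> out_shape;
        for (int32_t i = 0; i < static_cast<int32_t>(in_shape.size()); i++)
        {
            if (i != axis)
                out_shape.push_back(in_shape[i]);
        }
        return out_shape;
    }

    // Example: a [2, 3, 5] input reduced along axis 1 gives a [2, 5] output,
    // which is what the element-wise comparison in the loop above verifies.
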
@@ -411,6 +457,9 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
printNodeValidationError("OpConv2d: bias tensor must be rank 1");
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
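
Note on the new output dtype check: the output tensor must carry the accumulator type rather than the input type. A rough sketch of the pairing assumed here, per the TOSA rules this version targets (the helper name and header path are assumptions, not the reference model's own code):

    #include "tosa_generated.h"   // assumed location of the serialization library's DType enum

    // Illustrative only: int8 inputs accumulate in int32, int16 in int48,
    // and floating-point inputs accumulate in fp32.
    tosa::DType expectedConvAccDtype(tosa::DType in_dtype)
    {
        switch (in_dtype)
        {
            case tosa::DType_INT8:
                return tosa::DType_INT32;
            case tosa::DType_INT16:
                return tosa::DType_INT48;
            case tosa::DType_FLOAT:
                return tosa::DType_FLOAT;
            default:
                return tosa::DType_UNKNOWN;
        }
    }
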
@@ -434,6 +483,18 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
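
The qinfo block above encodes the rule that is repeated below for CONV3D, DEPTHWISE_CONV2D, FULLY_CONNECTED, MATMUL and TRANSPOSE_CONV2D: a non-zero zero point is only legal when the corresponding tensor is int8. Restated as a standalone predicate (helper name and header path are illustrative):

    #include <cstdint>
    #include "tosa_generated.h"   // assumed location of the DType enum

    // Any zero point is acceptable for int8 tensors; every other dtype must
    // use a zero point of exactly 0.
    static bool zeroPointAllowed(tosa::DType dtype, int32_t zp)
    {
        return dtype == tosa::DType_INT8 || zp == 0;
    }
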
@@ -603,6 +664,9 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
printNodeValidationError("OpConv3d: bias tensor must be rank 1");
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -626,6 +690,18 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
@@ -798,6 +874,9 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -821,6 +900,18 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
@@ -987,8 +1078,23 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
@@ -1059,6 +1165,9 @@ int OpMatMul<Dtype>::checkTensorAttributes()
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
@@ -1101,6 +1210,12 @@ int OpMatMul<Dtype>::checkTensorAttributes()
}
W = b->getShape()[2];
+ if (this->qinfo && Dtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t");
+ ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t");
+ }
+
return 0;
}
@@ -1291,11 +1406,11 @@ int OpMaxPool2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- TosaQuantInfoBase* qinfo_,
- uint64_t id_)
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1305,8 +1420,8 @@ OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
INIT_QINFO(Conv);
}
-template <DType InDtype, DType OutDtype>
-OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype>
+OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
@@ -1314,8 +1429,8 @@ OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d()
delete qinfo;
}
-template <DType InDtype, DType OutDtype>
-int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1325,6 +1440,9 @@ int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
+ ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
+
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
@@ -1363,11 +1481,23 @@ int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
}
}
+ if (this->qinfo)
+ {
+ if (InDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
+ }
+ if (WeightDtype != DType_INT8)
+ {
+ ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t");
+ }
+ }
+
return 0;
}
-template <DType InDtype, DType OutDtype>
-int OpTransposeConv2d<InDtype, OutDtype>::eval()
+template <DType InDtype, DType WeightDtype>
+int OpTransposeConv2d<InDtype, WeightDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];