diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-10-14 17:09:57 -0700 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-10-18 18:50:08 +0000 |
commit | cc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2 (patch) | |
tree | 2d664f87e3fdd75de8c6794f6f6c8d6364ece6bb /reference_model/src/ops/tensor_ops.cc | |
parent | e807aae606a78d923a2565052f7c2179e3050650 (diff) | |
download | reference_model-cc61be36c3b0f5cd1ea719e129a54fd48a6ee9a2.tar.gz |
More ERROR_IF supports
- Also delay tensor allocation after operator being validated
ERROR_IF can be caught first before 0 or negative dimension set the graph_status to UNPREDICTABLE
- Rescale, Argmax, FullyConnected, Matmul, Pad, Reshape, Slice, Transpose, Clamp, Concat, Equal, Greater, GreaterEqual, Table
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I4e1b3e5794fe195ce1a37e28443ae584645a3b91
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 156 |
1 files changed, 143 insertions, 13 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 118d048..be4e4aa 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -115,7 +115,7 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_ARGMAX, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 4); INIT_ATTRIBUTE(Axis); } @@ -133,14 +133,60 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + if (validateRequiredRank(inputs[0])) + { + return 1; + } + + int32_t output_rank = inputs[0]->getRank() - 1; + if (output_rank != outputs[0]->getRank()) { + printNodeValidationError("OpArgMax: Output rank needs to be rank(input) - 1"); + return 1; + } + + if (outputs[0]->getDtype() != DType_INT32) + { + printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator"); return 1; } input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + if (attribute->axis() < 0 || attribute->axis() >= input->getRank()) + { + printNodeValidationError("OpArgMax: Axis needs to be within [0, rank(input)]"); + return 1; + } + + bool shape_check = true; + for (int32_t i = 0; i < input->getRank(); i++) + { + if (i < attribute->axis()) + { + if (input->getShape()[i] != output->getShape()[i]) + { + shape_check = false; + break; + } + } + else if (i > attribute->axis()) + { + if (input->getShape()[i] != output->getShape()[i - 1]) + { + shape_check = false; + break; + } + } + // No need to check i == axis + } + if (!shape_check) + { + printNodeValidationError("OpArgMax: Mismatch between output shape provided and expected output shape"); + return 1; + } + return 0; } @@ -411,6 +457,9 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes() printNodeValidationError("OpConv2d: bias tensor must be rank 1"); } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); @@ -434,6 +483,18 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes() return 1; } + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpConv2d: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv2d: zeropoint only for int8_t"); + } + } + return 0; } @@ -603,6 +664,9 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes() printNodeValidationError("OpConv3d: bias tensor must be rank 1"); } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); @@ -626,6 +690,18 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes() return 1; } + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpConv3d: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpConv3d: zeropoint only for int8_t"); + } + } + return 0; } @@ -798,6 +874,9 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes() printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1"); } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); @@ -821,6 +900,18 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes() return 1; } + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpDepthwiseConv2d: zeropoint only for int8_t"); + } + } + return 0; } @@ -987,8 +1078,23 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes() return 1; } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpFullyConnected: zeropoint only for int8_t"); + } + } + return 0; } @@ -1059,6 +1165,9 @@ int OpMatMul<Dtype>::checkTensorAttributes() return 1; } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); @@ -1101,6 +1210,12 @@ int OpMatMul<Dtype>::checkTensorAttributes() } W = b->getShape()[2]; + if (Dtype != DType_INT8) + { + ERROR_IF(this->qinfo->a_zp() != 0, "OpMatMul: zeropoint only for int8_t"); + ERROR_IF(this->qinfo->b_zp() != 0, "OpMatMul: zeropoint only for int8_t"); + } + return 0; } @@ -1291,11 +1406,11 @@ int OpMaxPool2d<Dtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType OutDtype> -OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - TosaQuantInfoBase* qinfo_, - uint64_t id_) +template <DType InDtype, DType WeightDtype> +OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1305,8 +1420,8 @@ OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, INIT_QINFO(Conv); } -template <DType InDtype, DType OutDtype> -OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d() +template <DType InDtype, DType WeightDtype> +OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d() { if (attribute) delete attribute; @@ -1314,8 +1429,8 @@ OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d() delete qinfo; } -template <DType InDtype, DType OutDtype> -int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes() +template <DType InDtype, DType WeightDtype> +int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1325,6 +1440,9 @@ int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes() return 1; } + ERROR_IF(outputs[0]->getDtype() != AccDtype, + "OpFullyConnected: Output data type not supported for this configuration of operator"); + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); @@ -1363,11 +1481,23 @@ int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes() } } + if (this->qinfo) + { + if (InDtype != DType_INT8) + { + ERROR_IF(this->qinfo->input_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); + } + if (WeightDtype != DType_INT8) + { + ERROR_IF(this->qinfo->weight_zp() != 0, "OpTransposeConv2d: zeropoint only for int8_t"); + } + } + return 0; } -template <DType InDtype, DType OutDtype> -int OpTransposeConv2d<InDtype, OutDtype>::eval() +template <DType InDtype, DType WeightDtype> +int OpTransposeConv2d<InDtype, WeightDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; |