diff options
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index f8fd323..a60819d 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -327,7 +327,7 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_ARGMAX, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 4); + setRequiredRank(1); INIT_ATTRIBUTE(Axis); } @@ -405,6 +405,10 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes() template <int Rank, TOSA_REF_TYPE Dtype> int OpArgMax<Rank, Dtype>::eval() { + // Check Tosa Level + auto tosa_level = g_func_config.tosa_level; + LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK"); + Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis()); this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; }); @@ -419,7 +423,7 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_AVG_POOL2D, id_) { setRequiredOperands(1, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(Pool); } @@ -645,7 +649,7 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_CONV2D, id_) { setRequiredOperands(3, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(Conv); } @@ -839,7 +843,7 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_CONV3D, id_) { setRequiredOperands(3, 1); - setRequiredRank(5); + setRequiredRank(5, 5); INIT_ATTRIBUTE(Conv); } @@ -1042,7 +1046,7 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { setRequiredOperands(3, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(Conv); } @@ -1227,7 +1231,7 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); - setRequiredRank(2); + setRequiredRank(2, 2); INIT_ATTRIBUTE(FullyConnected); } @@ -1322,7 +1326,7 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_MATMUL, id_) { setRequiredOperands(2, 1); - setRequiredRank(3); + setRequiredRank(3, 3); INIT_ATTRIBUTE(MatMul); } @@ -1460,7 +1464,7 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_MAX_POOL2D, id_) { setRequiredOperands(1, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(Pool); } @@ -1601,7 +1605,7 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, : GraphNode(sgt_, Op_FFT2D, id_) { setRequiredOperands(2, 2); - setRequiredRank(3); + setRequiredRank(3, 3); INIT_ATTRIBUTE(FFT); } @@ -1724,7 +1728,7 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_RFFT2D, id_) { setRequiredOperands(1, 2); - setRequiredRank(3); + setRequiredRank(3, 3); } template <TOSA_REF_TYPE Dtype> @@ -1830,7 +1834,7 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(TransposeConv); } |