aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-05-01 18:36:43 +0000
committerJerry Ge <jerry.ge@arm.com>2023-05-10 02:40:49 +0000
commit0bd4ec89d52cc1fd36e92dff2fb496b3550ee7f5 (patch)
treed2662a0e62aec08a648edf61da62ee789a481080 /reference_model/src/ops/tensor_ops.cc
parenta4d748b08accce06fab93e2d2b96e499b35ae89b (diff)
downloadreference_model-0bd4ec89d52cc1fd36e92dff2fb496b3550ee7f5.tar.gz
Refactor ref_model rank checking and add level check to argmax
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: Iad035b31d5e5e83040068e6311501490765bfff7
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc26
1 files changed, 15 insertions, 11 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index f8fd323..a60819d 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -327,7 +327,7 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_ARGMAX, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 4);
+ setRequiredRank(1);
INIT_ATTRIBUTE(Axis);
}
@@ -405,6 +405,10 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
+ // Check Tosa Level
+ auto tosa_level = g_func_config.tosa_level;
+ LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
+
Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
@@ -419,7 +423,7 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_AVG_POOL2D, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Pool);
}
@@ -645,7 +649,7 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Conv);
}
@@ -839,7 +843,7 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_CONV3D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(5);
+ setRequiredRank(5, 5);
INIT_ATTRIBUTE(Conv);
}
@@ -1042,7 +1046,7 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Conv);
}
@@ -1227,7 +1231,7 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(2);
+ setRequiredRank(2, 2);
INIT_ATTRIBUTE(FullyConnected);
}
@@ -1322,7 +1326,7 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_MATMUL, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
INIT_ATTRIBUTE(MatMul);
}
@@ -1460,7 +1464,7 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_MAX_POOL2D, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(Pool);
}
@@ -1601,7 +1605,7 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_,
: GraphNode(sgt_, Op_FFT2D, id_)
{
setRequiredOperands(2, 2);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
INIT_ATTRIBUTE(FFT);
}
@@ -1724,7 +1728,7 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_RFFT2D, id_)
{
setRequiredOperands(1, 2);
- setRequiredRank(3);
+ setRequiredRank(3, 3);
}
template <TOSA_REF_TYPE Dtype>
@@ -1830,7 +1834,7 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
- setRequiredRank(4);
+ setRequiredRank(4, 4);
INIT_ATTRIBUTE(TransposeConv);
}