diff options
author | Jerry Ge <jerry.ge@arm.com> | 2023-05-01 18:36:43 +0000 |
---|---|---|
committer | Jerry Ge <jerry.ge@arm.com> | 2023-05-10 02:40:49 +0000 |
commit | 0bd4ec89d52cc1fd36e92dff2fb496b3550ee7f5 (patch) | |
tree | d2662a0e62aec08a648edf61da62ee789a481080 /reference_model/src/ops/data_layout.cc | |
parent | a4d748b08accce06fab93e2d2b96e499b35ae89b (diff) | |
download | reference_model-0bd4ec89d52cc1fd36e92dff2fb496b3550ee7f5.tar.gz |
Refactor ref_model rank checking and add level check to argmax
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Iad035b31d5e5e83040068e6311501490765bfff7
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 18 |
1 files changed, 6 insertions, 12 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 442cef8..fd19f96 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -27,7 +27,7 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_CONCAT, id_) { setRequiredOperands(-1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Axis); } @@ -131,7 +131,7 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_PAD, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Pad); } @@ -221,7 +221,6 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_RESHAPE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); INIT_ATTRIBUTE(Reshape); } @@ -244,11 +243,6 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) - { - return 1; - } - // output and input must be the same types if (inputs[0]->matchType(*outputs[0])) { @@ -321,7 +315,7 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_REVERSE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Axis); } @@ -392,7 +386,7 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_SLICE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Slice); } @@ -465,7 +459,7 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_TILE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Tile); } @@ -667,7 +661,7 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_TRANSPOSE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Transpose); } |