diff options
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/graph_node.h | 18 | ||||
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 18 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 6 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_ternary.cc | 7 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_unary.cc | 6 | ||||
-rw-r--r-- | reference_model/src/ops/image.cc | 10 | ||||
-rw-r--r-- | reference_model/src/ops/reduction.cc | 2 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 26 | ||||
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 6 | ||||
-rw-r--r-- | reference_model/src/tensor.h | 4 |
10 files changed, 36 insertions, 67 deletions
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index 3433192..aafc07f 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -270,20 +270,16 @@ protected: int setRequiredRank(const int min, const int max = -1) { - if (max == -1) - { - requiredRankMin = requiredRankMax = min; - } - else + requiredRankMin = min; + requiredRankMax = max; + + if (requiredRankMin >= 0 && requiredRankMax >= 0) { - requiredRankMin = min; - requiredRankMax = max; + ASSERT_MSG(requiredRankMin <= requiredRankMax, + "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin, + requiredRankMax); } - ASSERT_MSG(requiredRankMin <= requiredRankMax, - "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin, - requiredRankMax); - return 0; } diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 442cef8..fd19f96 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -27,7 +27,7 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_CONCAT, id_) { setRequiredOperands(-1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Axis); } @@ -131,7 +131,7 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_PAD, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Pad); } @@ -221,7 +221,6 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_RESHAPE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); INIT_ATTRIBUTE(Reshape); } @@ -244,11 +243,6 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) - { - return 1; - } - // output and input must be the same types if (inputs[0]->matchType(*outputs[0])) { @@ -321,7 +315,7 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_REVERSE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Axis); } @@ -392,7 +386,7 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_SLICE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Slice); } @@ -465,7 +459,7 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_TILE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Tile); } @@ -667,7 +661,7 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_TRANSPOSE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 6); + setRequiredRank(1); INIT_ATTRIBUTE(Transpose); } diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index c5801e7..1e873e7 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -29,7 +29,6 @@ BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, : GraphNode(sgt_, op_, id_) { setRequiredOperands(2, 1); - setRequiredRank(0, 6); a = b = nullptr; result = nullptr; @@ -51,11 +50,6 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) - { - return 1; - } - // A & B must be the same rank and types if (inputs[0]->matchRankType(*inputs[1])) { diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 090ce29..16554b5 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -26,7 +26,6 @@ OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_SELECT, id_) { setRequiredOperands(3, 1); - setRequiredRank(0, 6); } template <int Rank, TOSA_REF_TYPE Dtype> @@ -43,12 +42,6 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) || - validateRequiredRank(outputs[0])) - { - return 1; - } - // output and input must be the same types if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) || inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) || diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 514cb84..e6e870e 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -27,7 +27,6 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64 : GraphNode(sgt_, op_, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); fcn = [](InEigenType a) -> OutEigenType { ASSERT_MSG(0, "In default UnaryNode function, missing function registration"); @@ -49,11 +48,6 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) - { - return 1; - } - // output and input must be the same types if (inputs[0]->matchRankTypeShape(*outputs[0])) { diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index ca12cfe..575a500 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -113,11 +113,6 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() int16_t border_y = border[0]; int16_t border_x = border[1]; - // Check Tosa Level - auto tosa_level = g_func_config.tosa_level; - LEVEL_CHECK(scale_y_n / scale_y_d <= tosa_level.MAX_SCALE, "scale_y_n / scale_y_d should be smaller than or equal to MAX_SCALE"); - LEVEL_CHECK(scale_x_n / scale_x_d <= tosa_level.MAX_SCALE, "scale_x_n / scale_x_d should be smaller than or equal to MAX_SCALE"); - ERROR_IF(std::max<int>({ in_height, in_width, out_height, out_width }) >= 16384, "OpResize: exceeds maximum dimension"); ERROR_IF(in_batch != out_batch, "OpResize: output tensor batch mismatch"); @@ -137,6 +132,11 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() ERROR_IF((border_x < -16 * scale_x_n || border_x >= scale_x_n), "OpResize: invalid attribute border width dimension"); + // Check Tosa Level + auto tosa_level = g_func_config.tosa_level; + LEVEL_CHECK(scale_y_n / scale_y_d <= tosa_level.MAX_SCALE, "scale_y_n / scale_y_d should be smaller than or equal to MAX_SCALE"); + LEVEL_CHECK(scale_x_n / scale_x_d <= tosa_level.MAX_SCALE, "scale_x_n / scale_x_d should be smaller than or equal to MAX_SCALE"); + int32_t res_height = 0; int32_t res_width = 0; diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index bf8ba57..fd48472 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -25,7 +25,7 @@ ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, Tosa : GraphNode(sgt_, op_, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 4); + setRequiredRank(1, 4); INIT_ATTRIBUTE(Axis); } diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index f8fd323..a60819d 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -327,7 +327,7 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_ARGMAX, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 4); + setRequiredRank(1); INIT_ATTRIBUTE(Axis); } @@ -405,6 +405,10 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes() template <int Rank, TOSA_REF_TYPE Dtype> int OpArgMax<Rank, Dtype>::eval() { + // Check Tosa Level + auto tosa_level = g_func_config.tosa_level; + LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK"); + Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis()); this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; }); @@ -419,7 +423,7 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_AVG_POOL2D, id_) { setRequiredOperands(1, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(Pool); } @@ -645,7 +649,7 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_CONV2D, id_) { setRequiredOperands(3, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(Conv); } @@ -839,7 +843,7 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_CONV3D, id_) { setRequiredOperands(3, 1); - setRequiredRank(5); + setRequiredRank(5, 5); INIT_ATTRIBUTE(Conv); } @@ -1042,7 +1046,7 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { setRequiredOperands(3, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(Conv); } @@ -1227,7 +1231,7 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); - setRequiredRank(2); + setRequiredRank(2, 2); INIT_ATTRIBUTE(FullyConnected); } @@ -1322,7 +1326,7 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_MATMUL, id_) { setRequiredOperands(2, 1); - setRequiredRank(3); + setRequiredRank(3, 3); INIT_ATTRIBUTE(MatMul); } @@ -1460,7 +1464,7 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_MAX_POOL2D, id_) { setRequiredOperands(1, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(Pool); } @@ -1601,7 +1605,7 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, : GraphNode(sgt_, Op_FFT2D, id_) { setRequiredOperands(2, 2); - setRequiredRank(3); + setRequiredRank(3, 3); INIT_ATTRIBUTE(FFT); } @@ -1724,7 +1728,7 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_RFFT2D, id_) { setRequiredOperands(1, 2); - setRequiredRank(3); + setRequiredRank(3, 3); } template <TOSA_REF_TYPE Dtype> @@ -1830,7 +1834,7 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); - setRequiredRank(4); + setRequiredRank(4, 4); INIT_ATTRIBUTE(TransposeConv); } diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 68ffb1f..fce8e7c 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -31,7 +31,6 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_RESCALE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); INIT_ATTRIBUTE(Rescale); } @@ -52,11 +51,6 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) - { - return 1; - } - // output and input must be the same rank and size if (inputs[0]->matchRankSize(*outputs[0])) { diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 08ee8bf..b68a9b6 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -197,9 +197,9 @@ public: } // Unary check to make sure rank matches - const int checkRequiredRank(const int exactRank) const + const int checkRequiredRank(const int minRank) const { - return (shape.size() == (size_t)exactRank) ? 0 : 1; + return (shape.size() >= (size_t)minRank) ? 0 : 1; } const int checkRequiredRank(const int minRank, const int maxRank) const |