diff options
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 34 | ||||
-rw-r--r-- | reference_model/src/ops/data_layout.h | 3 |
2 files changed, 20 insertions, 17 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index c1a14b7..c66d64e 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -24,7 +24,7 @@ template <int Rank, DType Dtype> OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_CONCAT, id_) { - setRequiredOperands(2, 1); + setRequiredOperands(-1, 1); setRequiredRank(1, 6); INIT_ATTRIBUTE(Axis); @@ -43,24 +43,25 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + if (inputs.empty()) { + printNodeValidationError("Concat operator must have at least one input tensor"); return 1; } - // output and input must be the same types and rank - // inputs[0] and inputs[1] should also match type and rank - if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0])) + for (size_t i = 0; i < inputs.size(); i++) { - printNodeValidationError("Concat operator input ranks and types must match"); - return 1; + if (inputs[i]->matchRankType(*outputs[0])) + { + printNodeValidationError("Concat operator input ranks and types must match"); + return 1; + } + ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i])); } - lhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); - rhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size()) + if (attribute->axis() < 0 || (size_t)attribute->axis() >= inputs[0]->getShape().size()) { printNodeValidationError("Axis is beyond input tensor rank"); return 1; @@ -80,12 +81,15 @@ int OpConcat<Rank, Dtype>::eval() reverser[d] = Rank - 1 - d; } - TIn lhs_reversed = lhs->getTensor().shuffle(reverser); - TIn rhs_reversed = rhs->getTensor().shuffle(reverser); + TIn result = ins[0]->getTensor().shuffle(reverser); - TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis); - out->getTensor() = reversed_result.shuffle(reverser); - // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis); + for (size_t i = 1; i < ins.size(); i++) + { + TIn in_reversed = ins[i]->getTensor().shuffle(reverser); + TIn temp = result.concatenate(in_reversed, reversed_axis); + result = temp; + } + out->getTensor() = result.shuffle(reverser); return GraphNode::eval(); } diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index 100bd6b..b180b4f 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -40,8 +40,7 @@ public: protected: Eigen::array<int, Rank> reverser; - TosaReference::TensorTemplate<TIn>* lhs; - TosaReference::TensorTemplate<TIn>* rhs; + std::vector<TosaReference::TensorTemplate<TIn>*> ins; TosaAxisAttribute* attribute; TosaReference::TensorTemplate<TOut>* out; }; |