From ad15dfab0430b72015d13d19b8a696bb9bacd0a6 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Thu, 4 Mar 2021 15:15:03 -0800 Subject: Concat takes variadic inputs Signed-off-by: Kevin Cheng Change-Id: Ic8fe6e1fd899b41d444fd4f477d0f515ce0e9cc9 --- reference_model/src/ops/data_layout.cc | 34 +++++++++++++++++++--------------- reference_model/src/ops/data_layout.h | 3 +-- 2 files changed, 20 insertions(+), 17 deletions(-) (limited to 'reference_model/src/ops') diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index c1a14b7..c66d64e 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -24,7 +24,7 @@ template OpConcat::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) : GraphNode(Op_CONCAT, id_) { - setRequiredOperands(2, 1); + setRequiredOperands(-1, 1); setRequiredRank(1, 6); INIT_ATTRIBUTE(Axis); @@ -43,24 +43,25 @@ int OpConcat::checkTensorAttributes() if (validateRequiredOperands()) return 1; - if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + if (inputs.empty()) { + printNodeValidationError("Concat operator must have at least one input tensor"); return 1; } - // output and input must be the same types and rank - // inputs[0] and inputs[1] should also match type and rank - if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0])) + for (size_t i = 0; i < inputs.size(); i++) { - printNodeValidationError("Concat operator input ranks and types must match"); - return 1; + if (inputs[i]->matchRankType(*outputs[0])) + { + printNodeValidationError("Concat operator input ranks and types must match"); + return 1; + } + ins.push_back(dynamic_cast*>(inputs[i])); } - lhs = dynamic_cast*>(inputs[0]); - rhs = dynamic_cast*>(inputs[1]); out = dynamic_cast*>(outputs[0]); - if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size()) + if (attribute->axis() < 0 || (size_t)attribute->axis() >= inputs[0]->getShape().size()) { printNodeValidationError("Axis is beyond input tensor rank"); return 1; @@ -80,12 +81,15 @@ int OpConcat::eval() reverser[d] = Rank - 1 - d; } - TIn lhs_reversed = lhs->getTensor().shuffle(reverser); - TIn rhs_reversed = rhs->getTensor().shuffle(reverser); + TIn result = ins[0]->getTensor().shuffle(reverser); - TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis); - out->getTensor() = reversed_result.shuffle(reverser); - // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis); + for (size_t i = 1; i < ins.size(); i++) + { + TIn in_reversed = ins[i]->getTensor().shuffle(reverser); + TIn temp = result.concatenate(in_reversed, reversed_axis); + result = temp; + } + out->getTensor() = result.shuffle(reverser); return GraphNode::eval(); } diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index 100bd6b..b180b4f 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -40,8 +40,7 @@ public: protected: Eigen::array reverser; - TosaReference::TensorTemplate* lhs; - TosaReference::TensorTemplate* rhs; + std::vector*> ins; TosaAxisAttribute* attribute; TosaReference::TensorTemplate* out; }; -- cgit v1.2.1