aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-03-04 15:15:03 -0800
committerKevin Cheng <kevin.cheng@arm.com>2021-04-30 10:04:40 -0700
commitad15dfab0430b72015d13d19b8a696bb9bacd0a6 (patch)
treea78f6e90c54451ca9da812c03349feb9ef87e071 /reference_model/src/ops/data_layout.cc
parent573ecd4373f4e8f5d1a219147b5a259125059cf0 (diff)
downloadreference_model-ad15dfab0430b72015d13d19b8a696bb9bacd0a6.tar.gz
Concat takes variadic inputs
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ic8fe6e1fd899b41d444fd4f477d0f515ce0e9cc9
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc34
1 files changed, 19 insertions, 15 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index c1a14b7..c66d64e 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -24,7 +24,7 @@ template <int Rank, DType Dtype>
OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
: GraphNode(Op_CONCAT, id_)
{
- setRequiredOperands(2, 1);
+ setRequiredOperands(-1, 1);
setRequiredRank(1, 6);
INIT_ATTRIBUTE(Axis);
@@ -43,24 +43,25 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes()
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ if (inputs.empty())
{
+ printNodeValidationError("Concat operator must have at least one input tensor");
return 1;
}
-
// output and input must be the same types and rank
- // inputs[0] and inputs[1] should also match type and rank
- if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]))
+ for (size_t i = 0; i < inputs.size(); i++)
{
- printNodeValidationError("Concat operator input ranks and types must match");
- return 1;
+ if (inputs[i]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Concat operator input ranks and types must match");
+ return 1;
+ }
+ ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
}
- lhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- rhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size())
+ if (attribute->axis() < 0 || (size_t)attribute->axis() >= inputs[0]->getShape().size())
{
printNodeValidationError("Axis is beyond input tensor rank");
return 1;
@@ -80,12 +81,15 @@ int OpConcat<Rank, Dtype>::eval()
reverser[d] = Rank - 1 - d;
}
- TIn lhs_reversed = lhs->getTensor().shuffle(reverser);
- TIn rhs_reversed = rhs->getTensor().shuffle(reverser);
+ TIn result = ins[0]->getTensor().shuffle(reverser);
- TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis);
- out->getTensor() = reversed_result.shuffle(reverser);
- // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis);
+ for (size_t i = 1; i < ins.size(); i++)
+ {
+ TIn in_reversed = ins[i]->getTensor().shuffle(reverser);
+ TIn temp = result.concatenate(in_reversed, reversed_axis);
+ result = temp;
+ }
+ out->getTensor() = result.shuffle(reverser);
return GraphNode::eval();
}