aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-04-21 22:49:57 +0000
committerJerry Ge <jerry.ge@arm.com>2023-05-17 22:46:57 +0000
commit264f7faa59709ffa8117541f5d55c99c5dba967d (patch)
treeae767b3e4375ab87d4323f18b63239a84ac857db /reference_model/src/ops/data_layout.cc
parent7e5968166a5105da30bc11c9241f271cb3dc1da9 (diff)
downloadreference_model-264f7faa59709ffa8117541f5d55c99c5dba967d.tar.gz
Add support for one dimension of size -1 in ReshapeOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I0ef7607f4266296a1204c5cccdb5be36f345b5ba
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc42
1 files changed, 40 insertions, 2 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index fd19f96..86cd752 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -250,12 +250,50 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 1;
}
+ // -1 shape inferencing
+ auto inferred_size = -1;
+ auto inferred_dim = -1;
+ auto total_size = getInputs()[0]->getElementCount();
+ uint32_t accum_size = 1;
+
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ auto curr_new_shape = attribute->new_shape()[d];
+ if (curr_new_shape != -1) {
+ accum_size *= curr_new_shape;
+ } else {
+ ERROR_IF(inferred_dim != -1, "OpReshape: only 1 inferred dimension in output shape is supported");
+ inferred_dim = d;
+ }
+ }
+
+ ERROR_IF((total_size % accum_size) != 0, "OpReshape: shape inference failed, missing dimension would be non-integer");
+ inferred_size = total_size / accum_size;
+
+ if (inferred_dim != -1) {
+ getOutputs()[0]->setDimSize(inferred_dim, inferred_size);
+
+ // Need to also edit the serializedTensor's shape at inferred_dim
+ TosaSerializationTensor* serializedTensor;
+ for (auto region : parent_sgt->getTsh()->GetRegions()) {
+ for (auto block : region->GetBlocks()) {
+ if (block->GetTensorByName(getOutputs()[0]->getName())) {
+ serializedTensor = block->GetTensorByName(getOutputs()[0]->getName());
+ serializedTensor->SetDimSize(inferred_dim, inferred_size);
+ break;
+ }
+ }
+ }
+
+ }
+
ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
"Input tensor size does not match output tensor size");
for (uint32_t d = 0; d < OutRank; d++)
{
- ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d],
+ auto curr_new_shape = attribute->new_shape()[d];
+ ERROR_IF(curr_new_shape != -1 && curr_new_shape != outputs[0]->getShape()[d],
"OpReshape: new_shape doesn't match output shape");
}
@@ -270,7 +308,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
{
for (int32_t d = 0; d < OutRank; d++)
{
- array_shape[d] = attribute->new_shape()[OutRank - 1 - d];
+ array_shape[d] = getOutputs()[0]->getShape()[OutRank - 1 - d];
out_reverser[d] = OutRank - 1 - d;
}