diff options
author | Jerry Ge <jerry.ge@arm.com> | 2023-04-21 22:49:57 +0000 |
---|---|---|
committer | Jerry Ge <jerry.ge@arm.com> | 2023-05-17 22:46:57 +0000 |
commit | 264f7faa59709ffa8117541f5d55c99c5dba967d (patch) | |
tree | ae767b3e4375ab87d4323f18b63239a84ac857db /reference_model/src/ops/data_layout.cc | |
parent | 7e5968166a5105da30bc11c9241f271cb3dc1da9 (diff) | |
download | reference_model-264f7faa59709ffa8117541f5d55c99c5dba967d.tar.gz |
Add support for one dimension of size -1 in ReshapeOp
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I0ef7607f4266296a1204c5cccdb5be36f345b5ba
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 42 |
1 files changed, 40 insertions, 2 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index fd19f96..86cd752 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -250,12 +250,50 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() return 1; } + // -1 shape inferencing + auto inferred_size = -1; + auto inferred_dim = -1; + auto total_size = getInputs()[0]->getElementCount(); + uint32_t accum_size = 1; + + for (int32_t d = 0; d < OutRank; d++) + { + auto curr_new_shape = attribute->new_shape()[d]; + if (curr_new_shape != -1) { + accum_size *= curr_new_shape; + } else { + ERROR_IF(inferred_dim != -1, "OpReshape: only 1 inferred dimension in output shape is supported"); + inferred_dim = d; + } + } + + ERROR_IF((total_size % accum_size) != 0, "OpReshape: shape inference failed, missing dimension would be non-integer"); + inferred_size = total_size / accum_size; + + if (inferred_dim != -1) { + getOutputs()[0]->setDimSize(inferred_dim, inferred_size); + + // Need to also edit the serializedTensor's shape at inferred_dim + TosaSerializationTensor* serializedTensor; + for (auto region : parent_sgt->getTsh()->GetRegions()) { + for (auto block : region->GetBlocks()) { + if (block->GetTensorByName(getOutputs()[0]->getName())) { + serializedTensor = block->GetTensorByName(getOutputs()[0]->getName()); + serializedTensor->SetDimSize(inferred_dim, inferred_size); + break; + } + } + } + + } + ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(), "Input tensor size does not match output tensor size"); for (uint32_t d = 0; d < OutRank; d++) { - ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d], + auto curr_new_shape = attribute->new_shape()[d]; + ERROR_IF(curr_new_shape != -1 && curr_new_shape != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape"); } @@ -270,7 +308,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval() { for (int32_t d = 0; d < OutRank; d++) { - array_shape[d] = attribute->new_shape()[OutRank - 1 - d]; + array_shape[d] = getOutputs()[0]->getShape()[OutRank - 1 - d]; out_reverser[d] = OutRank - 1 - d; } |