From 264f7faa59709ffa8117541f5d55c99c5dba967d Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Fri, 21 Apr 2023 22:49:57 +0000 Subject: Add support for one dimension of size -1 in ReshapeOp Signed-off-by: Jerry Ge Signed-off-by: Jeremy Johnson Change-Id: I0ef7607f4266296a1204c5cccdb5be36f345b5ba --- reference_model/src/ops/data_layout.cc | 42 ++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) (limited to 'reference_model/src/ops') diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index fd19f96..86cd752 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -250,12 +250,50 @@ int OpReshape::checkTensorAttributes() return 1; } + // -1 shape inferencing + auto inferred_size = -1; + auto inferred_dim = -1; + auto total_size = getInputs()[0]->getElementCount(); + uint32_t accum_size = 1; + + for (int32_t d = 0; d < OutRank; d++) + { + auto curr_new_shape = attribute->new_shape()[d]; + if (curr_new_shape != -1) { + accum_size *= curr_new_shape; + } else { + ERROR_IF(inferred_dim != -1, "OpReshape: only 1 inferred dimension in output shape is supported"); + inferred_dim = d; + } + } + + ERROR_IF((total_size % accum_size) != 0, "OpReshape: shape inference failed, missing dimension would be non-integer"); + inferred_size = total_size / accum_size; + + if (inferred_dim != -1) { + getOutputs()[0]->setDimSize(inferred_dim, inferred_size); + + // Need to also edit the serializedTensor's shape at inferred_dim + TosaSerializationTensor* serializedTensor; + for (auto region : parent_sgt->getTsh()->GetRegions()) { + for (auto block : region->GetBlocks()) { + if (block->GetTensorByName(getOutputs()[0]->getName())) { + serializedTensor = block->GetTensorByName(getOutputs()[0]->getName()); + serializedTensor->SetDimSize(inferred_dim, inferred_size); + break; + } + } + } + + } + ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(), "Input tensor size does not match output tensor size"); for (uint32_t d = 0; d < OutRank; d++) { - ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d], + auto curr_new_shape = attribute->new_shape()[d]; + ERROR_IF(curr_new_shape != -1 && curr_new_shape != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape"); } @@ -270,7 +308,7 @@ int OpReshape::eval() { for (int32_t d = 0; d < OutRank; d++) { - array_shape[d] = attribute->new_shape()[OutRank - 1 - d]; + array_shape[d] = getOutputs()[0]->getShape()[OutRank - 1 - d]; out_reverser[d] = OutRank - 1 - d; } -- cgit v1.2.1