diff options
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 48 |
1 files changed, 5 insertions, 43 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 2d1fdb0..fa99d21 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -297,48 +297,11 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() return 1; } - // -1 shape inferencing - auto inferred_size = -1; - auto inferred_dim = -1; - auto total_size = getInputs()[0]->getElementCount(); - uint32_t accum_size = 1; - + // Check for unsupported -1 shape inferencing for (int32_t d = 0; d < OutRank; d++) { - auto curr_new_shape = attribute->new_shape()[d]; - if (curr_new_shape != -1) - { - accum_size *= curr_new_shape; - } - else - { - ERROR_IF(inferred_dim != -1, "OpReshape: only 1 inferred dimension in output shape is supported"); - inferred_dim = d; - } - } - - ERROR_IF((total_size % accum_size) != 0, - "OpReshape: shape inference failed, missing dimension would be non-integer"); - inferred_size = total_size / accum_size; - - if (inferred_dim != -1) - { - getOutputs()[0]->setDimSize(inferred_dim, inferred_size); - - // Need to also edit the serializedTensor's shape at inferred_dim - TosaSerializationTensor* serializedTensor; - for (auto region : parent_sgt->getTsh()->GetRegions()) - { - for (auto block : region->GetBlocks()) - { - if (block->GetTensorByName(getOutputs()[0]->getName())) - { - serializedTensor = block->GetTensorByName(getOutputs()[0]->getName()); - serializedTensor->SetDimSize(inferred_dim, inferred_size); - break; - } - } - } + auto curr_new_dim = attribute->new_shape()[d]; + ERROR_IF(curr_new_dim == -1, "OpReshape: inferred dimensions in output shape are unsupported") } ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(), @@ -346,9 +309,8 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() for (uint32_t d = 0; d < OutRank; d++) { - auto curr_new_shape = attribute->new_shape()[d]; - ERROR_IF(curr_new_shape != -1 && curr_new_shape != outputs[0]->getShape()[d], - "OpReshape: new_shape doesn't match output shape"); + auto curr_new_dim = attribute->new_shape()[d]; + ERROR_IF(curr_new_dim != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape"); } in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); |