aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc48
1 files changed, 5 insertions, 43 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 2d1fdb0..fa99d21 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -297,48 +297,11 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 1;
}
- // -1 shape inferencing
- auto inferred_size = -1;
- auto inferred_dim = -1;
- auto total_size = getInputs()[0]->getElementCount();
- uint32_t accum_size = 1;
-
+ // Check for unsupported -1 shape inferencing
for (int32_t d = 0; d < OutRank; d++)
{
- auto curr_new_shape = attribute->new_shape()[d];
- if (curr_new_shape != -1)
- {
- accum_size *= curr_new_shape;
- }
- else
- {
- ERROR_IF(inferred_dim != -1, "OpReshape: only 1 inferred dimension in output shape is supported");
- inferred_dim = d;
- }
- }
-
- ERROR_IF((total_size % accum_size) != 0,
- "OpReshape: shape inference failed, missing dimension would be non-integer");
- inferred_size = total_size / accum_size;
-
- if (inferred_dim != -1)
- {
- getOutputs()[0]->setDimSize(inferred_dim, inferred_size);
-
- // Need to also edit the serializedTensor's shape at inferred_dim
- TosaSerializationTensor* serializedTensor;
- for (auto region : parent_sgt->getTsh()->GetRegions())
- {
- for (auto block : region->GetBlocks())
- {
- if (block->GetTensorByName(getOutputs()[0]->getName()))
- {
- serializedTensor = block->GetTensorByName(getOutputs()[0]->getName());
- serializedTensor->SetDimSize(inferred_dim, inferred_size);
- break;
- }
- }
- }
+ auto curr_new_dim = attribute->new_shape()[d];
+ ERROR_IF(curr_new_dim == -1, "OpReshape: inferred dimensions in output shape are unsupported")
}
ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
@@ -346,9 +309,8 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
for (uint32_t d = 0; d < OutRank; d++)
{
- auto curr_new_shape = attribute->new_shape()[d];
- ERROR_IF(curr_new_shape != -1 && curr_new_shape != outputs[0]->getShape()[d],
- "OpReshape: new_shape doesn't match output shape");
+ auto curr_new_dim = attribute->new_shape()[d];
+ ERROR_IF(curr_new_dim != outputs[0]->getShape()[d], "OpReshape: new_shape doesn't match output shape");
}
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);