aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc42
1 files changed, 40 insertions, 2 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index fd19f96..86cd752 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -250,12 +250,50 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 1;
}
+ // -1 shape inferencing
+ auto inferred_size = -1;
+ auto inferred_dim = -1;
+ auto total_size = getInputs()[0]->getElementCount();
+ uint32_t accum_size = 1;
+
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ auto curr_new_shape = attribute->new_shape()[d];
+ if (curr_new_shape != -1) {
+ accum_size *= curr_new_shape;
+ } else {
+ ERROR_IF(inferred_dim != -1, "OpReshape: only 1 inferred dimension in output shape is supported");
+ inferred_dim = d;
+ }
+ }
+
+ ERROR_IF((total_size % accum_size) != 0, "OpReshape: shape inference failed, missing dimension would be non-integer");
+ inferred_size = total_size / accum_size;
+
+ if (inferred_dim != -1) {
+ getOutputs()[0]->setDimSize(inferred_dim, inferred_size);
+
+ // Need to also edit the serializedTensor's shape at inferred_dim
+ TosaSerializationTensor* serializedTensor;
+ for (auto region : parent_sgt->getTsh()->GetRegions()) {
+ for (auto block : region->GetBlocks()) {
+ if (block->GetTensorByName(getOutputs()[0]->getName())) {
+ serializedTensor = block->GetTensorByName(getOutputs()[0]->getName());
+ serializedTensor->SetDimSize(inferred_dim, inferred_size);
+ break;
+ }
+ }
+ }
+
+ }
+
ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
"Input tensor size does not match output tensor size");
for (uint32_t d = 0; d < OutRank; d++)
{
- ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d],
+ auto curr_new_shape = attribute->new_shape()[d];
+ ERROR_IF(curr_new_shape != -1 && curr_new_shape != outputs[0]->getShape()[d],
"OpReshape: new_shape doesn't match output shape");
}
@@ -270,7 +308,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
{
for (int32_t d = 0; d < OutRank; d++)
{
- array_shape[d] = attribute->new_shape()[OutRank - 1 - d];
+ array_shape[d] = getOutputs()[0]->getShape()[OutRank - 1 - d];
out_reverser[d] = OutRank - 1 - d;
}