aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/ops/data_layout.cc42
-rw-r--r--reference_model/src/subgraph_traverser.h4
-rw-r--r--reference_model/src/tensor.h8
3 files changed, 51 insertions, 3 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index fd19f96..86cd752 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -250,12 +250,50 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 1;
}
+ // -1 shape inferencing
+ auto inferred_size = -1;
+ auto inferred_dim = -1;
+ auto total_size = getInputs()[0]->getElementCount();
+ uint32_t accum_size = 1;
+
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ auto curr_new_shape = attribute->new_shape()[d];
+ if (curr_new_shape != -1) {
+ accum_size *= curr_new_shape;
+ } else {
+ ERROR_IF(inferred_dim != -1, "OpReshape: only 1 inferred dimension in output shape is supported");
+ inferred_dim = d;
+ }
+ }
+
+ ERROR_IF((total_size % accum_size) != 0, "OpReshape: shape inference failed, missing dimension would be non-integer");
+ inferred_size = total_size / accum_size;
+
+ if (inferred_dim != -1) {
+ getOutputs()[0]->setDimSize(inferred_dim, inferred_size);
+
+ // Need to also edit the serializedTensor's shape at inferred_dim
+ TosaSerializationTensor* serializedTensor;
+ for (auto region : parent_sgt->getTsh()->GetRegions()) {
+ for (auto block : region->GetBlocks()) {
+ if (block->GetTensorByName(getOutputs()[0]->getName())) {
+ serializedTensor = block->GetTensorByName(getOutputs()[0]->getName());
+ serializedTensor->SetDimSize(inferred_dim, inferred_size);
+ break;
+ }
+ }
+ }
+
+ }
+
ERROR_IF(inputs[0]->getElementCount() != outputs[0]->getElementCount(),
"Input tensor size does not match output tensor size");
for (uint32_t d = 0; d < OutRank; d++)
{
- ERROR_IF(attribute->new_shape()[d] != outputs[0]->getShape()[d],
+ auto curr_new_shape = attribute->new_shape()[d];
+ ERROR_IF(curr_new_shape != -1 && curr_new_shape != outputs[0]->getShape()[d],
"OpReshape: new_shape doesn't match output shape");
}
@@ -270,7 +308,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
{
for (int32_t d = 0; d < OutRank; d++)
{
- array_shape[d] = attribute->new_shape()[OutRank - 1 - d];
+ array_shape[d] = getOutputs()[0]->getShape()[OutRank - 1 - d];
out_reverser[d] = OutRank - 1 - d;
}
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 543b008..1cf582e 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -63,6 +63,10 @@ public:
{
return block->GetRegionName();
}
+ TosaSerializationHandler* getTsh() const
+ {
+ return tsh;
+ }
int getNumInputTensors() const;
Tensor* getInputTensor(const unsigned int idx) const;
Tensor* getInputTensorByName(const std::string name) const;
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index b68a9b6..21cf148 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -96,6 +96,12 @@ public:
return shape;
}
+ void setDimSize(size_t dim, uint32_t new_size)
+ {
+ this->shape[dim] = new_size;
+ return;
+ }
+
std::string getShapeAsString() const
{
std::string shape_str("[");
@@ -269,7 +275,7 @@ public:
protected:
const std::string tensorName;
const DType serializationDtype;
- const std::vector<int> shape;
+ std::vector<int> shape;
const TOSA_REF_TYPE tensorDtype;
int isValid;
int isSubgraphInput;