author     Jerry Ge <jerry.ge@arm.com>  2024-04-01 17:05:10 +0000
committer  Jerry Ge <jerry.ge@arm.com>  2024-04-02 21:39:09 +0000
commit     12159fc6fb776908f48fbda9c74cf34980540e4f (patch)
tree       54408e7a10502a347fc7afa25b665c60d696b4d1
parent     9a97eb6cd6aab5eb58eb7860faa9fea305e37c07 (diff)
download   reference_model-12159fc6fb776908f48fbda9c74cf34980540e4f.tar.gz
Show actual runtime value of shapeType tensors
* Enable showing actual runtime shapeType tensor value when the
  --dump_intermediates=1 flag is on

Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Ibd5aa8aa27505364fbbf9d1addd0bdef0deda885
-rw-r--r--  reference_model/src/ops/data_layout.cc |  9
-rw-r--r--  reference_model/src/ops/shape.cc       | 30
-rw-r--r--  reference_model/src/tensor.cc          | 13
-rw-r--r--  reference_model/src/tensor.h           | 26
4 files changed, 76 insertions, 2 deletions
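
The change follows one pattern in every shape-producing op: once the rank-1 output tensor has been evaluated, its elements are copied into a plain std::vector<int> and stored on the output Tensor via setShapeValue(), so that dumpTensorParams() can print them later. A minimal standalone sketch of that copy-and-format step, using Eigen's Tensor module directly rather than the reference model's TensorTemplate wrapper (the values and the main() harness are illustrative, not part of the patch):

    #include <unsupported/Eigen/CXX11/Tensor>
    #include <cstdio>
    #include <string>
    #include <vector>

    int main()
    {
        // Stand-in for an evaluated shapeType output: a rank-1 Eigen tensor.
        Eigen::Tensor<int32_t, 1> out(3);
        out.setValues({ 2, 4, 8 });

        // Copy the runtime values element by element, as the patch does.
        std::vector<int> shapeValue;
        for (int i = 0; i < out.size(); ++i)
        {
            shapeValue.push_back(out(i));
        }

        // Format them the same way the new getShapeValueAsString() does.
        std::string shape_str("[");
        for (auto& dim : shapeValue)
        {
            shape_str += (std::to_string(dim) + ", ");
        }
        shape_str.append("]");

        printf("ShapeValue=%s\n", shape_str.c_str());    // prints: ShapeValue=[2, 4, 8, ]
        return 0;
    }
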
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 4c17e78..e264284 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -270,6 +270,15 @@ int OpDim<Rank, Dtype>::eval()
this->out->getTensor().setValues({ out_val });
+ // set the shapeValue given the actual tensor value
+ std::vector<int> shapeValue;
+ for (int i = 0; i < out->getTensor().size(); ++i)
+ {
+ shapeValue.push_back(out->getTensor()(i));
+ }
+
+ this->getOutputs()[0]->setShapeValue(shapeValue);
+
return GraphNode::eval();
}
diff --git a/reference_model/src/ops/shape.cc b/reference_model/src/ops/shape.cc
index b087dd8..425dfc2 100644
--- a/reference_model/src/ops/shape.cc
+++ b/reference_model/src/ops/shape.cc
@@ -37,6 +37,18 @@ int OpConstShape::checkTensorAttributes()
int OpConstShape::eval()
{
+ // set the shapeValue given the actual tensor value
+ using EigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type;
+ auto out = dynamic_cast<TosaReference::TensorTemplate<Eigen::Tensor<EigenType, 1>>*>(this->getOutputs()[0]);
+
+ std::vector<int> shapeValue;
+ for (int i = 0; out != nullptr && i < out->getTensor().size(); ++i)
+ {
+ shapeValue.push_back(out->getTensor()(i));
+ }
+
+ this->getOutputs()[0]->setShapeValue(shapeValue);
+
for (auto ct : getOutputs())
{
if (!ct->getIsValid())
@@ -106,6 +118,15 @@ int OpConcatShape::eval()
}
}
out->getTensor() = out_tensor;
+
+ // set the shapeValue given the actual tensor value
+ std::vector<int> shapeValue;
+ for (int i = 0; i < out->getTensor().size(); ++i)
+ {
+ shapeValue.push_back(out->getTensor()(i));
+ }
+ this->getOutputs()[0]->setShapeValue(shapeValue);
+
return GraphNode::eval();
}
@@ -168,6 +189,15 @@ int ShapeBinaryNodeBase::eval()
}
result->getTensor() = out_tens;
+
+ // set the shapeValue given the actual tensor value
+ std::vector<int> shapeValue;
+ for (int i = 0; i < result->getTensor().size(); ++i)
+ {
+ shapeValue.push_back(result->getTensor()(i));
+ }
+ this->getOutputs()[0]->setShapeValue(shapeValue);
+
return GraphNode::eval();
}
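
The loop added to OpConstShape, OpConcatShape and ShapeBinaryNodeBase above is the same one added to OpDim in data_layout.cc. A hypothetical refactoring sketch (not part of this patch) showing how the repeated copy could be expressed once for any rank-1 Eigen tensor:

    #include <vector>

    // Hypothetical helper, not in the patch: collects the runtime elements of a
    // rank-1 tensor into the std::vector<int> expected by Tensor::setShapeValue().
    template <typename TensorT>
    std::vector<int> extractShapeValue(const TensorT& t)
    {
        std::vector<int> shapeValue;
        shapeValue.reserve(static_cast<size_t>(t.size()));
        for (int i = 0; i < t.size(); ++i)
        {
            shapeValue.push_back(static_cast<int>(t(i)));
        }
        return shapeValue;
    }

Each call site would then reduce to two lines, since setShapeValue() takes a non-const reference:

    std::vector<int> shapeValue = extractShapeValue(out->getTensor());
    this->getOutputs()[0]->setShapeValue(shapeValue);
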
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 1417fed..088c327 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -85,8 +85,17 @@ int TosaReference::Tensor::addConsumer(GraphNode* node)
int TosaReference::Tensor::dumpTensorParams(FILE* out) const
{
- fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), EnumNameTOSAREFTYPE(getDtype()),
- getIsValid(), getRank(), getShapeAsString().c_str());
+ if (this->getShapeValueSize() > 0)
+ {
+ fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s ShapeValue=%s\n", tensorName.c_str(),
+ EnumNameTOSAREFTYPE(getDtype()), getIsValid(), getRank(), getShapeAsString().c_str(),
+ getShapeValueAsString().c_str());
+ }
+ else
+ {
+ fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(),
+ EnumNameTOSAREFTYPE(getDtype()), getIsValid(), getRank(), getShapeAsString().c_str());
+ }
return 0;
}
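
With getShapeValueSize() > 0, a dumped shapeType tensor now carries an extra ShapeValue field. A tiny sketch that replays the new fprintf branch with stand-in values (the tensor name, DType string, Shape and ShapeValue strings below are illustrative, not taken from a real run; only the format string comes from the patch):

    #include <cstdio>

    int main()
    {
        // Illustrative values only: "const_shape_0" and the DType/Shape/ShapeValue
        // strings are made up; the format string is the one added above.
        fprintf(stdout, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s ShapeValue=%s\n",
                "const_shape_0", "SHAPE", 1, 1, "[3]", "[2, 4, 8, ]");
        return 0;
    }
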
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 26c6aa7..f13de0e 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -109,6 +109,31 @@ public:
return;
}
+ void setShapeValue(std::vector<int>& shapeValue)
+ {
+ for (auto dim : shapeValue)
+ {
+ this->shapeValue.push_back(dim);
+ }
+ return;
+ }
+
+ int getShapeValueSize() const
+ {
+ return this->shapeValue.size();
+ }
+
+ std::string getShapeValueAsString() const
+ {
+ std::string shape_str("[");
+ for (auto& dim : shapeValue)
+ {
+ shape_str += (std::to_string(dim) + ", ");
+ }
+ shape_str.append("]");
+ return shape_str;
+ }
+
std::string getShapeAsString() const
{
std::string shape_str("[");
@@ -297,6 +322,7 @@ protected:
const std::string tensorName;
const DType serializationDtype;
std::vector<int> shape;
+ std::vector<int> shapeValue;
const TOSA_REF_TYPE tensorDtype;
bool isValid;
bool isSubgraphInput;