From 12159fc6fb776908f48fbda9c74cf34980540e4f Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Mon, 1 Apr 2024 17:05:10 +0000 Subject: Show actual runtime value of shapeType tensors * Enable showing actual runtime shapeType tensor value when the --dump_intermediates=1 flag is on Signed-off-by: Jerry Ge Change-Id: Ibd5aa8aa27505364fbbf9d1addd0bdef0deda885 --- reference_model/src/ops/data_layout.cc | 9 +++++++++ reference_model/src/ops/shape.cc | 30 ++++++++++++++++++++++++++++++ reference_model/src/tensor.cc | 13 +++++++++++-- reference_model/src/tensor.h | 26 ++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 2 deletions(-) diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 4c17e78..e264284 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -270,6 +270,15 @@ int OpDim::eval() this->out->getTensor().setValues({ out_val }); + // set the shapeValue given the actual tensor value + std::vector shapeValue; + for (int i = 0; i < out->getTensor().size(); ++i) + { + shapeValue.push_back(out->getTensor()(i)); + } + + this->getOutputs()[0]->setShapeValue(shapeValue); + return GraphNode::eval(); } diff --git a/reference_model/src/ops/shape.cc b/reference_model/src/ops/shape.cc index b087dd8..425dfc2 100644 --- a/reference_model/src/ops/shape.cc +++ b/reference_model/src/ops/shape.cc @@ -37,6 +37,18 @@ int OpConstShape::checkTensorAttributes() int OpConstShape::eval() { + // set the shapeValue given the actual tensor value + using EigenType = typename GetEigenType::type; + auto out = dynamic_cast>*>(this->getOutputs()[0]); + + std::vector shapeValue; + for (int i = 0; out != nullptr && i < out->getTensor().size(); ++i) + { + shapeValue.push_back(out->getTensor()(i)); + } + + this->getOutputs()[0]->setShapeValue(shapeValue); + for (auto ct : getOutputs()) { if (!ct->getIsValid()) @@ -106,6 +118,15 @@ int OpConcatShape::eval() } } out->getTensor() = out_tensor; + + // set the shapeValue given the actual tensor value + std::vector shapeValue; + for (int i = 0; i < out->getTensor().size(); ++i) + { + shapeValue.push_back(out->getTensor()(i)); + } + this->getOutputs()[0]->setShapeValue(shapeValue); + return GraphNode::eval(); } @@ -168,6 +189,15 @@ int ShapeBinaryNodeBase::eval() } result->getTensor() = out_tens; + + // set the shapeValue given the actual tensor value + std::vector shapeValue; + for (int i = 0; i < result->getTensor().size(); ++i) + { + shapeValue.push_back(result->getTensor()(i)); + } + this->getOutputs()[0]->setShapeValue(shapeValue); + return GraphNode::eval(); } diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 1417fed..088c327 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -85,8 +85,17 @@ int TosaReference::Tensor::addConsumer(GraphNode* node) int TosaReference::Tensor::dumpTensorParams(FILE* out) const { - fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), EnumNameTOSAREFTYPE(getDtype()), - getIsValid(), getRank(), getShapeAsString().c_str()); + if (this->getShapeValueSize() > 0) + { + fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s ShapeValue=%s\n", tensorName.c_str(), + EnumNameTOSAREFTYPE(getDtype()), getIsValid(), getRank(), getShapeAsString().c_str(), + getShapeValueAsString().c_str()); + } + else + { + fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), + EnumNameTOSAREFTYPE(getDtype()), getIsValid(), getRank(), getShapeAsString().c_str()); + } return 0; } diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 26c6aa7..f13de0e 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -109,6 +109,31 @@ public: return; } + void setShapeValue(std::vector& shapeValue) + { + for (auto dim : shapeValue) + { + this->shapeValue.push_back(dim); + } + return; + } + + int getShapeValueSize() const + { + return this->shapeValue.size(); + } + + std::string getShapeValueAsString() const + { + std::string shape_str("["); + for (auto& dim : shapeValue) + { + shape_str += (std::to_string(dim) + ", "); + } + shape_str.append("]"); + return shape_str; + } + std::string getShapeAsString() const { std::string shape_str("["); @@ -297,6 +322,7 @@ protected: const std::string tensorName; const DType serializationDtype; std::vector shape; + std::vector shapeValue; const TOSA_REF_TYPE tensorDtype; bool isValid; bool isSubgraphInput; -- cgit v1.2.1