diff options
author | Jerry Ge <jerry.ge@arm.com> | 2024-04-01 17:05:10 +0000 |
---|---|---|
committer | Jerry Ge <jerry.ge@arm.com> | 2024-04-02 21:39:09 +0000 |
commit | 12159fc6fb776908f48fbda9c74cf34980540e4f (patch) | |
tree | 54408e7a10502a347fc7afa25b665c60d696b4d1 /reference_model/src/ops/shape.cc | |
parent | 9a97eb6cd6aab5eb58eb7860faa9fea305e37c07 (diff) | |
download | reference_model-12159fc6fb776908f48fbda9c74cf34980540e4f.tar.gz |
Show actual runtime value of shapeType tensors
* Enable showing actual runtime shapeType tensor value when the
--dump_intermediates=1 flag is on
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Ibd5aa8aa27505364fbbf9d1addd0bdef0deda885
Diffstat (limited to 'reference_model/src/ops/shape.cc')
-rw-r--r-- | reference_model/src/ops/shape.cc | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/reference_model/src/ops/shape.cc b/reference_model/src/ops/shape.cc index b087dd8..425dfc2 100644 --- a/reference_model/src/ops/shape.cc +++ b/reference_model/src/ops/shape.cc @@ -37,6 +37,18 @@ int OpConstShape::checkTensorAttributes() int OpConstShape::eval() { + // set the shapeValue given the actual tensor value + using EigenType = typename GetEigenType<TOSA_REF_TYPE_SHAPE>::type; + auto out = dynamic_cast<TosaReference::TensorTemplate<Eigen::Tensor<EigenType, 1>>*>(this->getOutputs()[0]); + + std::vector<int> shapeValue; + for (int i = 0; out != nullptr && i < out->getTensor().size(); ++i) + { + shapeValue.push_back(out->getTensor()(i)); + } + + this->getOutputs()[0]->setShapeValue(shapeValue); + for (auto ct : getOutputs()) { if (!ct->getIsValid()) @@ -106,6 +118,15 @@ int OpConcatShape::eval() } } out->getTensor() = out_tensor; + + // set the shapeValue given the actual tensor value + std::vector<int> shapeValue; + for (int i = 0; i < out->getTensor().size(); ++i) + { + shapeValue.push_back(out->getTensor()(i)); + } + this->getOutputs()[0]->setShapeValue(shapeValue); + return GraphNode::eval(); } @@ -168,6 +189,15 @@ int ShapeBinaryNodeBase::eval() } result->getTensor() = out_tens; + + // set the shapeValue given the actual tensor value + std::vector<int> shapeValue; + for (int i = 0; i < result->getTensor().size(); ++i) + { + shapeValue.push_back(result->getTensor()(i)); + } + this->getOutputs()[0]->setShapeValue(shapeValue); + return GraphNode::eval(); } |