aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-12-18 20:40:24 +0000
committerTai Ly <tai.ly@arm.com>2024-01-18 23:50:04 +0000
commit8690a0873fac28acccbb0acb15c16e8337e14df1 (patch)
treea13d5e195d8b7becffc23da98fde7449e91c96e4 /reference_model/src/subgraph_traverser.cc
parent9f5febe05901bfbd3919ef17f2caea8087cd9ccf (diff)
downloadreference_model-8690a0873fac28acccbb0acb15c16e8337e14df1.tar.gz
[reference model] Add shape operators
- fixed up reshape conformance tests to use shape input instead of attribute - fixed up tile conformance tests to use shape input instead of attribute - fixed output and output rank of dim op - allow rank 0 and rank 1 tensors for tosa.shape values (for shape = {}) - added initialization of rank 0 const_shape tensors (for shape = {}) - Update conformance tests to use new rescale attributes Signed-off-by: Tai Ly <tai.ly@arm.com> Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: I6cce0d2a9ab066fe20a2abf9d2cfde3eb3d8c18b
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc19
1 files changed, 17 insertions, 2 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 745213e..fae0b30 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -510,6 +510,12 @@ int SubgraphTraverser::allocateTensor(std::string name)
FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
}
+ // set valid for constant tensors:
+ if ((ts->GetShape().empty() && ts->GetDtype() == DType_SHAPE))
+ {
+ // corner case: const_shape {} has no data
+ tensor->setIsValid();
+ }
if (!ts->GetData().empty())
{
if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy)
@@ -545,13 +551,18 @@ int SubgraphTraverser::allocateTensor(std::string name)
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT48:
- case DType_SHAPE: {
+ case DType_INT48: {
std::vector<int64_t> i64_data;
TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
}
break;
+ case DType_SHAPE: {
+ std::vector<int64_t> i64_data;
+ TosaSerializationHandler::ConvertU8toI64(ts->GetData(), tensor->getElementCount(), i64_data);
+ tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
+ }
+ break;
case DType_FP16: {
// Interpret f16 data as float
std::vector<float> f16_data;
@@ -617,6 +628,10 @@ int SubgraphTraverser::allocateTensor(std::string name)
EnumNameDType(ts->GetDtype()));
}
tensor->setIsValid();
+ }
+
+ if (tensor->getIsValid())
+ {
// Push ready consumers to the next node list
for (auto gn : tensor->getConsumers())
{