aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/subgraph_traverser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r--reference_model/src/subgraph_traverser.cc90
1 files changed, 53 insertions, 37 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index e7641ba..4508291 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -138,9 +138,9 @@ int SubgraphTraverser::initializeGraph()
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
- DType input_dtype = DType_UNKNOWN;
- DType output_dtype = DType_UNKNOWN;
- DType weight_dtype = DType_UNKNOWN;
+ TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN;
+ TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN;
+ TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN;
uint32_t input_rank = 0;
uint32_t output_rank = 0;
uint32_t weight_rank = 0;
@@ -185,7 +185,7 @@ int SubgraphTraverser::initializeGraph()
!input_tensor,
"SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
input_name.c_str());
- input_dtype = input_tensor->GetDtype();
+ input_dtype = ConvertDType(input_tensor->GetDtype());
input_rank = input_tensor->GetShape().size();
}
@@ -207,7 +207,7 @@ int SubgraphTraverser::initializeGraph()
!weight_tensor,
"SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
weight_name.c_str());
- weight_dtype = weight_tensor->GetDtype();
+ weight_dtype = ConvertDType(weight_tensor->GetDtype());
weight_rank = weight_tensor->GetShape().size();
}
@@ -220,7 +220,7 @@ int SubgraphTraverser::initializeGraph()
!output_tensor,
"SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler",
output_name.c_str());
- output_dtype = output_tensor->GetDtype();
+ output_dtype = ConvertDType(output_tensor->GetDtype());
output_rank = output_tensor->GetShape().size();
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
@@ -246,16 +246,16 @@ int SubgraphTraverser::initializeGraph()
fprintf(g_func_debug.func_debug_file,
"SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) "
"-> (%s rank %d)",
- EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
- EnumNamesDType()[output_dtype], output_rank);
+ EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
+ EnumNameTOSAREFTYPE(output_dtype), output_rank);
}
else
{
fprintf(g_func_debug.func_debug_file,
"SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), "
"weight=(%s rank %d) -> (%s rank %d)",
- EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
- EnumNamesDType()[weight_dtype], weight_rank, EnumNamesDType()[output_dtype], output_rank);
+ EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
+ EnumNameTOSAREFTYPE(weight_dtype), weight_rank, EnumNameTOSAREFTYPE(output_dtype), output_rank);
}
for (auto& ts : op->GetInputTensorNames())
@@ -309,7 +309,7 @@ int SubgraphTraverser::initializeGraph()
TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
- ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
+ ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
addTensor(tensor);
}
@@ -411,73 +411,89 @@ int SubgraphTraverser::allocateTensor()
if (!ts->GetData().empty())
{
DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
- switch (ts->GetDtype())
+ auto serialization_dtype = ts->GetDtype();
+ switch (serialization_dtype)
{
- case DType_INT4:
- {
+ case DType_INT4: {
std::vector<int8_t> i4_data;
TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT8:
- {
+ case DType_INT8: {
std::vector<int8_t> i8_data;
TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT16:
- {
+ case DType_INT16: {
std::vector<int16_t> i16_data;
TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT32:
- {
+ case DType_INT32: {
std::vector<int32_t> i32_data;
TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT48:
- {
+ case DType_INT48: {
std::vector<int64_t> i64_data;
TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
}
break;
- case DType_FP16:
- {
+ case DType_FP16: {
// Interpret f16 data as float
std::vector<float> f16_data;
TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
- tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(f16_data.begin(), f16_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
+ }
}
break;
- case DType_BF16:
- {
+ case DType_BF16: {
std::vector<float> fp32_data;
TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
// Ensure valid bfloat16 stored in each float
for (auto f : fp32_data)
ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
- tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
}
break;
- case DType_FP32:
- {
+ case DType_FP32: {
std::vector<float> fp32_data;
TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
- tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
}
break;
- case DType_BOOL:
- {
+ case DType_BOOL: {
std::vector<bool> bool_data;
TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
@@ -493,7 +509,7 @@ int SubgraphTraverser::allocateTensor()
break;
default:
SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
- EnumNamesDType()[ts->GetDtype()]);
+ EnumNameDType(ts->GetDtype()));
}
}
}
@@ -802,14 +818,14 @@ int SubgraphTraverser::validateGraph()
if (g_func_config.tosa_profile == 0)
{
- DType dtype = currTensor->getDtype();
+ TOSA_REF_TYPE dtype = currTensor->getDtype();
// Float-point disallowed
- if (dtype == DType_FP32 || dtype == DType_FP16)
+ if (dtype == TOSA_REF_TYPE_FP32 || dtype == TOSA_REF_TYPE_FP16)
{
WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
"disabled, but %s tensor %s found\n",
- EnumNamesDType()[dtype], currTensor->getName().c_str());
+ EnumNameTOSAREFTYPE(dtype), currTensor->getName().c_str());
return 1;
}
}