diff options
Diffstat (limited to 'reference_model/src/subgraph_traverser.cc')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 90 |
1 files changed, 53 insertions, 37 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index e7641ba..4508291 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -138,9 +138,9 @@ int SubgraphTraverser::initializeGraph() for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode - DType input_dtype = DType_UNKNOWN; - DType output_dtype = DType_UNKNOWN; - DType weight_dtype = DType_UNKNOWN; + TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN; + TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN; + TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN; uint32_t input_rank = 0; uint32_t output_rank = 0; uint32_t weight_rank = 0; @@ -185,7 +185,7 @@ int SubgraphTraverser::initializeGraph() !input_tensor, "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler", input_name.c_str()); - input_dtype = input_tensor->GetDtype(); + input_dtype = ConvertDType(input_tensor->GetDtype()); input_rank = input_tensor->GetShape().size(); } @@ -207,7 +207,7 @@ int SubgraphTraverser::initializeGraph() !weight_tensor, "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler", weight_name.c_str()); - weight_dtype = weight_tensor->GetDtype(); + weight_dtype = ConvertDType(weight_tensor->GetDtype()); weight_rank = weight_tensor->GetShape().size(); } @@ -220,7 +220,7 @@ int SubgraphTraverser::initializeGraph() !output_tensor, "SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler", output_name.c_str()); - output_dtype = output_tensor->GetDtype(); + output_dtype = ConvertDType(output_tensor->GetDtype()); output_rank = output_tensor->GetShape().size(); DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, @@ -246,16 +246,16 @@ int SubgraphTraverser::initializeGraph() fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) " "-> (%s rank %d)", - EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank, - EnumNamesDType()[output_dtype], output_rank); + EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank, + EnumNameTOSAREFTYPE(output_dtype), output_rank); } else { fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), " "weight=(%s rank %d) -> (%s rank %d)", - EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank, - EnumNamesDType()[weight_dtype], weight_rank, EnumNamesDType()[output_dtype], output_rank); + EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank, + EnumNameTOSAREFTYPE(weight_dtype), weight_rank, EnumNameTOSAREFTYPE(output_dtype), output_rank); } for (auto& ts : op->GetInputTensorNames()) @@ -309,7 +309,7 @@ int SubgraphTraverser::initializeGraph() TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", - ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); + ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size()); addTensor(tensor); } @@ -411,73 +411,89 @@ int SubgraphTraverser::allocateTensor() if (!ts->GetData().empty()) { DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str()); - switch (ts->GetDtype()) + auto serialization_dtype = ts->GetDtype(); + switch (serialization_dtype) { - case DType_INT4: - { + case DType_INT4: { std::vector<int8_t> i4_data; TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data); std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end()); tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT8: - { + case DType_INT8: { std::vector<int8_t> i8_data; TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data); std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end()); tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT16: - { + case DType_INT16: { std::vector<int16_t> i16_data; TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data); std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end()); tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT32: - { + case DType_INT32: { std::vector<int32_t> i32_data; TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data); tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT48: - { + case DType_INT48: { std::vector<int64_t> i64_data; TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data); tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); } break; - case DType_FP16: - { + case DType_FP16: { // Interpret f16 data as float std::vector<float> f16_data; TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data); - tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector<double> f64_data(f16_data.begin(), f16_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); + } } break; - case DType_BF16: - { + case DType_BF16: { std::vector<float> fp32_data; TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); // Ensure valid bfloat16 stored in each float for (auto f : fp32_data) ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f); - tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector<double> f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } } break; - case DType_FP32: - { + case DType_FP32: { std::vector<float> fp32_data; TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); - tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector<double> f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } } break; - case DType_BOOL: - { + case DType_BOOL: { std::vector<bool> bool_data; TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data); @@ -493,7 +509,7 @@ int SubgraphTraverser::allocateTensor() break; default: SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.", - EnumNamesDType()[ts->GetDtype()]); + EnumNameDType(ts->GetDtype())); } } } @@ -802,14 +818,14 @@ int SubgraphTraverser::validateGraph() if (g_func_config.tosa_profile == 0) { - DType dtype = currTensor->getDtype(); + TOSA_REF_TYPE dtype = currTensor->getDtype(); // Float-point disallowed - if (dtype == DType_FP32 || dtype == DType_FP16) + if (dtype == TOSA_REF_TYPE_FP32 || dtype == TOSA_REF_TYPE_FP16) { WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point " "disabled, but %s tensor %s found\n", - EnumNamesDType()[dtype], currTensor->getName().c_str()); + EnumNameTOSAREFTYPE(dtype), currTensor->getName().c_str()); return 1; } } |