diff options
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/subgraph_traverser.cc | 67 |
1 files changed, 58 insertions, 9 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index b615feb..d64cb38 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -98,7 +98,6 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin int SubgraphTraverser::initializeGraph() { - char tensor_fullname[1000]; int idx = 0; for (auto op : block->GetOperators()) { @@ -227,20 +226,70 @@ int SubgraphTraverser::initializeGraph() TosaReference::Tensor* tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); - if (!ts->GetNpyFilePtr().empty()) + if (!ts->GetData().empty()) { if (tensor->allocate()) { - FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str()); + WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); + return 1; } - bzero(tensor_fullname, sizeof(tensor_fullname)); - snprintf(tensor_fullname, sizeof(tensor_fullname), "%s/%s", g_func_config.flatbuffer_dir, - ts->GetNpyFilePtr().c_str()); - if (tensor->readFromNpyFile(tensor_fullname)) + switch (ts->GetDtype()) { - FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", tensor->getName().c_str(), - block->GetName().c_str()); + case DType_INT8: + { + std::vector<int8_t> i8_data; + TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data); + std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end()); + tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; + case DType_INT16: + { + std::vector<int16_t> i16_data; + TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data); + std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end()); + tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; + case DType_INT32: + { + std::vector<int32_t> i32_data; + TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data); + tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; + case DType_INT48: + { + std::vector<int64_t> i64_data; + TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data); + tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); + } + break; + case DType_FLOAT: + { + std::vector<float> fp32_data; + TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } + break; + case DType_BOOL: + { + std::vector<bool> bool_data; + TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data); + + // std::vector<bool>::data() will return bit mask instead of array of bool array. + // Need to translate manually. + bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool)); + for (size_t i = 0; i < bool_data.size(); i++) + { + bool_array[i] = bool_data[i]; + } + tensor->setTensorValueBool(bool_data.size(), bool_array); + } + break; + default: + FATAL_ERROR("Unsupported tensor type %s.", EnumNamesDType()[ts->GetDtype()]); } } |