aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-06-17 16:01:59 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-06-24 16:37:58 -0700
commit82507d77056dd5510547438ba2064c1ee8bebc2c (patch)
treeca5c4bd49029430b1153091f30680866e97ccd2f /reference_model
parent2d60f0063eb91f6514b20a1817663ce0ddd3ff4a (diff)
downloadreference_model-82507d77056dd5510547438ba2064c1ee8bebc2c.tar.gz
Update to use new serialization_lib API.
- Constant tensors are now initialized from embedded u8 array instead from numpy. - Python unit test generator and built-in test hasn't been updated. Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: I5cb86f8e5ec8f23fee5dcbf257874a0f204ede04
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/subgraph_traverser.cc67
1 files changed, 58 insertions, 9 deletions
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index b615feb..d64cb38 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -98,7 +98,6 @@ TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::strin
int SubgraphTraverser::initializeGraph()
{
- char tensor_fullname[1000];
int idx = 0;
for (auto op : block->GetOperators())
{
@@ -227,20 +226,70 @@ int SubgraphTraverser::initializeGraph()
TosaReference::Tensor* tensor =
TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
- if (!ts->GetNpyFilePtr().empty())
+ if (!ts->GetData().empty())
{
if (tensor->allocate())
{
- FATAL_ERROR("Fail to allocate Eigen tensor %s", tensor->getName().c_str());
+ WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+ return 1;
}
- bzero(tensor_fullname, sizeof(tensor_fullname));
- snprintf(tensor_fullname, sizeof(tensor_fullname), "%s/%s", g_func_config.flatbuffer_dir,
- ts->GetNpyFilePtr().c_str());
- if (tensor->readFromNpyFile(tensor_fullname))
+ switch (ts->GetDtype())
{
- FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", tensor->getName().c_str(),
- block->GetName().c_str());
+ case DType_INT8:
+ {
+ std::vector<int8_t> i8_data;
+ TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
+ std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
+ tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ }
+ break;
+ case DType_INT16:
+ {
+ std::vector<int16_t> i16_data;
+ TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
+ std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
+ tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ }
+ break;
+ case DType_INT32:
+ {
+ std::vector<int32_t> i32_data;
+ TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
+ tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ }
+ break;
+ case DType_INT48:
+ {
+ std::vector<int64_t> i64_data;
+ TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
+ tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
+ }
+ break;
+ case DType_FLOAT:
+ {
+ std::vector<float> fp32_data;
+ TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
+ break;
+ case DType_BOOL:
+ {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
+
+ // std::vector<bool>::data() will return bit mask instead of array of bool array.
+ // Need to translate manually.
+ bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
+ for (size_t i = 0; i < bool_data.size(); i++)
+ {
+ bool_array[i] = bool_data[i];
+ }
+ tensor->setTensorValueBool(bool_data.size(), bool_array);
+ }
+ break;
+ default:
+ FATAL_ERROR("Unsupported tensor type %s.", EnumNamesDType()[ts->GetDtype()]);
}
}