aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-07-17 21:33:17 +0000
committerJerry Ge <jerry.ge@arm.com>2023-07-27 17:24:26 +0000
commite5cabbf7528849aac35b498ce0711a144c1a08d5 (patch)
treec4d4756a5961cd095a4f0faa677dbe6b47522f82
parent97b0027ca018ad9b5f91f41b413e843afb15c6d7 (diff)
downloadreference_model-e5cabbf7528849aac35b498ce0711a144c1a08d5.tar.gz
Enable lazy tensor allocation
- The previous ref_model was allocating the memory for all tensors in the graph upfront which is unnecessary and wasteful. - This patch changes to only allocate initial input tensors on the entry point using the allocateInputTensors() function. - The output tensors are ensured to have been allocated before executing a node. The output tensors are the inputs for the next node. - When a node's evaluation is finished, its input tensors will be freed if they will no longer be used by anyone else. Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: Ibb3e8c9e6344f6cd9eb20532a03b2097b93247f9
-rw-r--r--reference_model/src/main.cpp4
-rw-r--r--reference_model/src/model_runner_impl.cc4
-rw-r--r--reference_model/src/ops/control_flow.cc3
-rw-r--r--reference_model/src/subgraph_traverser.cc307
-rw-r--r--reference_model/src/subgraph_traverser.h3
5 files changed, 167 insertions, 154 deletions
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 0c86cbd..070eb33 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -109,9 +109,9 @@ int main(int argc, char** argv)
goto done;
}
- if (main_gt.allocateTensor())
+ if (main_gt.allocateInputTensors())
{
- WARNING("Failed to allocate tensor. Evaluation aborted.");
+ WARNING("Failed to allocate input tensors. Evaluation aborted.");
goto done;
}
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index ce548e9..be97644 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -327,9 +327,9 @@ GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb,
return _main_gt->getGraphStatus();
}
- if (_main_gt->allocateTensor())
+ if (_main_gt->allocateInputTensors())
{
- WARNING("Failed to allocate tensor.");
+ WARNING("Failed to allocate input tensors.");
return _main_gt->getGraphStatus();
}
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index 0afb7e2..6bbc587 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -43,7 +43,8 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
ERROR_IF(block_sgt.linkTensorsAndNodes(), "evalBlock(): Failed to link tensors and nodes for %s",
block_name.c_str());
ERROR_IF(block_sgt.validateGraph(), "evalBlock(): Failed to validate subgraph for %s", block_name.c_str());
- ERROR_IF(block_sgt.allocateTensor(), "evalBlock(): Failed to allocate tensor for %s", block_name.c_str());
+ ERROR_IF(block_sgt.allocateInputTensors(), "evalBlock(): Failed to allocate input tensors for %s",
+ block_name.c_str());
int num_input_tensors = block_sgt.getNumInputTensors();
int num_output_tensors = block_sgt.getNumOutputTensors();
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index a7ef5e9..5675be9 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -397,146 +397,156 @@ int SubgraphTraverser::initializeGraph()
return 0;
}
-int SubgraphTraverser::allocateTensor()
+int SubgraphTraverser::allocateInputTensors()
{
- for (auto ts : block->GetTensors())
+ auto input_tensor_names_vec = block->GetInputs();
+
+ for (auto input_tensor_name : input_tensor_names_vec)
+ {
+ this->allocateTensor(input_tensor_name);
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::allocateTensor(std::string name)
+{
+ auto ts = block->GetTensorByName(name);
+
+ // Bail out if tensor is used and any of its dimension is invalid.
+ auto got = used_tensor_name_set.find(ts->GetName());
+ if (got != used_tensor_name_set.end())
{
- // Bail out if tensor is used and any of its dimension is invalid.
- auto got = used_tensor_name_set.find(ts->GetName());
- if (got != used_tensor_name_set.end())
+ uint32_t elements = 1;
+ for (auto& dim : ts->GetShape())
{
- uint32_t elements = 1;
- for (auto& dim : ts->GetShape())
+ if (dim <= 0)
{
- if (dim <= 0)
- {
- DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(),
- dim);
- this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
- return 1;
- }
- if (dim > static_cast<int32_t>(TOSA_MAX_TENSOR_SIZE / elements))
- {
- // Size greather than maximum defined in spec
- DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str());
- this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
- return 1;
- }
+ DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim);
+ this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
+ return 1;
+ }
+ if (dim > static_cast<int32_t>(TOSA_MAX_TENSOR_SIZE / elements))
+ {
+ // Size greather than maximum defined in spec
+ DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str());
+ this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
+ return 1;
}
}
+ }
- TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
- SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
+ TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
+ SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
- DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
- if (tensor->allocate())
- {
- FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
- }
+ DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
+ if (tensor->allocate())
+ {
+ FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
+ }
- if (!ts->GetData().empty())
+ if (!ts->GetData().empty())
+ {
+ DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
+ auto serialization_dtype = ts->GetDtype();
+ switch (serialization_dtype)
{
- DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
- auto serialization_dtype = ts->GetDtype();
- switch (serialization_dtype)
- {
- case DType_INT4: {
- std::vector<int8_t> i4_data;
- TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
- std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
- tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
- }
- break;
- case DType_INT8: {
- std::vector<int8_t> i8_data;
- TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
- std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
- tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
- }
- break;
- case DType_INT16: {
- std::vector<int16_t> i16_data;
- TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
- std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
- tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ case DType_INT4: {
+ std::vector<int8_t> i4_data;
+ TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
+ std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
+ tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ }
+ break;
+ case DType_INT8: {
+ std::vector<int8_t> i8_data;
+ TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
+ std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
+ tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ }
+ break;
+ case DType_INT16: {
+ std::vector<int16_t> i16_data;
+ TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
+ std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
+ tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ }
+ break;
+ case DType_INT32: {
+ std::vector<int32_t> i32_data;
+ TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
+ tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ }
+ break;
+ case DType_INT48: {
+ std::vector<int64_t> i64_data;
+ TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
+ tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
+ }
+ break;
+ case DType_FP16: {
+ // Interpret f16 data as float
+ std::vector<float> f16_data;
+ TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(f16_data.begin(), f16_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
}
- break;
- case DType_INT32: {
- std::vector<int32_t> i32_data;
- TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
- tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
+ else
+ {
+ tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
}
- break;
- case DType_INT48: {
- std::vector<int64_t> i64_data;
- TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
- tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
+ }
+ break;
+ case DType_BF16: {
+ std::vector<float> fp32_data;
+ TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+ // Ensure valid bfloat16 stored in each float
+ for (auto f : fp32_data)
+ ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
}
- break;
- case DType_FP16: {
- // Interpret f16 data as float
- std::vector<float> f16_data;
- TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
- if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
- {
- std::vector<double> f64_data(f16_data.begin(), f16_data.end());
- tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
- }
- else
- {
- tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
- }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
}
- break;
- case DType_BF16: {
- std::vector<float> fp32_data;
- TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
- // Ensure valid bfloat16 stored in each float
- for (auto f : fp32_data)
- ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
- if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
- {
- std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
- tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
- }
- else
- {
- tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
- }
+ }
+ break;
+ case DType_FP32: {
+ std::vector<float> fp32_data;
+ TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
}
- break;
- case DType_FP32: {
- std::vector<float> fp32_data;
- TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
- if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
- {
- std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
- tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
- }
- else
- {
- tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
- }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
}
- break;
- case DType_BOOL: {
- std::vector<bool> bool_data;
- TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
-
- // std::vector<bool>::data() will return bit mask instead of array of bool array.
- // Need to translate manually.
- bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
- for (size_t i = 0; i < bool_data.size(); i++)
- {
- bool_array[i] = bool_data[i];
- }
- tensor->setTensorValueBool(bool_data.size(), bool_array);
+ }
+ break;
+ case DType_BOOL: {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
+
+ // std::vector<bool>::data() will return bit mask instead of array of bool array.
+ // Need to translate manually.
+ bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
+ for (size_t i = 0; i < bool_data.size(); i++)
+ {
+ bool_array[i] = bool_data[i];
}
- break;
- default:
- SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
- EnumNameDType(ts->GetDtype()));
+ tensor->setTensorValueBool(bool_data.size(), bool_array);
}
+ break;
+ default:
+ SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
+ EnumNameDType(ts->GetDtype()));
}
}
@@ -593,11 +603,13 @@ int SubgraphTraverser::evaluateNextNode()
for (auto tensor : currNode->getOutputs())
{
if (!tensor->is_allocated())
- if (tensor->allocate())
+ {
+ if (this->allocateTensor(tensor->getName()))
{
FATAL_ERROR("SubgraphTraverser::evaluateNextNode(): Failed to allocate Eigen tensor %s",
tensor->getName().c_str());
}
+ }
}
if (currNode->eval())
@@ -607,41 +619,40 @@ int SubgraphTraverser::evaluateNextNode()
}
// free input tensor if all of its consumers have all of their outputs ready and it's not block's output
- if (!currNode->getInMainBlock())
- { // we don't free it if the node is in main block and has nested blocks
- for (auto tensor : currNode->getInputs())
+ for (auto tensor : currNode->getInputs())
+ {
+ bool in_use = false;
+
+ auto tensor_check = findTensorByName(tensor->getName());
+ if (tensor_check->getIsParentGraphOutput())
{
- bool in_use = false;
+ // if it's parent's block output tensor, we can't free it
+ continue;
+ }
- auto tensor_check = findTensorByName(tensor->getName());
- if (tensor_check->getIsParentGraphOutput())
+ for (auto node : tensor->getConsumers())
+ {
+ // If the node is inside a loop, the input tensor is still needed
+ if (!node->hasAllOutputsReady())
{
- // if it's parent's block output tensor, we can't free it
- continue;
+ in_use = true;
}
+ }
- for (auto node : tensor->getConsumers())
- {
- // If the node is inside a loop, the input tensor is still needed
- if (!node->hasAllOutputsReady())
- {
- in_use = true;
- }
- }
- for (auto name : block->GetOutputs())
+ for (auto name : block->GetOutputs())
+ {
+ if (name == tensor->getName())
{
- if (name == tensor->getName())
- {
- in_use = true;
- }
+ in_use = true;
}
+ }
- if (!in_use)
- {
- tensor->deallocate();
- }
+ if (!in_use)
+ {
+ tensor->deallocate();
}
}
+
// Search the output tensors of this node to see if
// there are now new ready nodes available from completing this node
for (TosaReference::Tensor* tensor : currNode->getOutputs())
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 00989ee..ef6ea42 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -49,7 +49,8 @@ public:
int linkTensorsAndNodes();
int validateGraph();
- int allocateTensor();
+ int allocateInputTensors();
+ int allocateTensor(std::string name);
int dumpGraph(FILE* out) const;
int dumpNextNodeList(FILE* out) const;