From e5cabbf7528849aac35b498ce0711a144c1a08d5 Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Mon, 17 Jul 2023 21:33:17 +0000 Subject: Enable lazy tensor allocation - The previous ref_model was allocating the memory for all tensors in the graph upfront which is unnecessary and wasteful. - This patch changes to only allocate initial input tensors on the entry point using the allocateInputTensors() function. - The output tensors are ensured to have been allocated before executing a node. The output tensors are the inputs for the next node. - When a node's evaluation is finished, its input tensors will be freed if they will no longer be used by anyone else. Signed-off-by: Jerry Ge Change-Id: Ibb3e8c9e6344f6cd9eb20532a03b2097b93247f9 --- reference_model/src/main.cpp | 4 +- reference_model/src/model_runner_impl.cc | 4 +- reference_model/src/ops/control_flow.cc | 3 +- reference_model/src/subgraph_traverser.cc | 307 ++++++++++++++++-------------- reference_model/src/subgraph_traverser.h | 3 +- 5 files changed, 167 insertions(+), 154 deletions(-) diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 0c86cbd..070eb33 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -109,9 +109,9 @@ int main(int argc, char** argv) goto done; } - if (main_gt.allocateTensor()) + if (main_gt.allocateInputTensors()) { - WARNING("Failed to allocate tensor. Evaluation aborted."); + WARNING("Failed to allocate input tensors. 
Evaluation aborted."); goto done; } diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc index ce548e9..be97644 100644 --- a/reference_model/src/model_runner_impl.cc +++ b/reference_model/src/model_runner_impl.cc @@ -327,9 +327,9 @@ GraphStatus ModelRunnerImpl::initialize(TosaSerializationBasicBlock* bb, return _main_gt->getGraphStatus(); } - if (_main_gt->allocateTensor()) + if (_main_gt->allocateInputTensors()) { - WARNING("Failed to allocate tensor."); + WARNING("Failed to allocate input tensors."); return _main_gt->getGraphStatus(); } diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index 0afb7e2..6bbc587 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -43,7 +43,8 @@ int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block, ERROR_IF(block_sgt.linkTensorsAndNodes(), "evalBlock(): Failed to link tensors and nodes for %s", block_name.c_str()); ERROR_IF(block_sgt.validateGraph(), "evalBlock(): Failed to validate subgraph for %s", block_name.c_str()); - ERROR_IF(block_sgt.allocateTensor(), "evalBlock(): Failed to allocate tensor for %s", block_name.c_str()); + ERROR_IF(block_sgt.allocateInputTensors(), "evalBlock(): Failed to allocate input tensors for %s", + block_name.c_str()); int num_input_tensors = block_sgt.getNumInputTensors(); int num_output_tensors = block_sgt.getNumOutputTensors(); diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index a7ef5e9..5675be9 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -397,146 +397,156 @@ int SubgraphTraverser::initializeGraph() return 0; } -int SubgraphTraverser::allocateTensor() +int SubgraphTraverser::allocateInputTensors() { - for (auto ts : block->GetTensors()) + auto input_tensor_names_vec = block->GetInputs(); + + for (auto input_tensor_name : 
input_tensor_names_vec) + { + this->allocateTensor(input_tensor_name); + } + + return 0; +} + +int SubgraphTraverser::allocateTensor(std::string name) +{ + auto ts = block->GetTensorByName(name); + + // Bail out if tensor is used and any of its dimension is invalid. + auto got = used_tensor_name_set.find(ts->GetName()); + if (got != used_tensor_name_set.end()) { - // Bail out if tensor is used and any of its dimension is invalid. - auto got = used_tensor_name_set.find(ts->GetName()); - if (got != used_tensor_name_set.end()) + uint32_t elements = 1; + for (auto& dim : ts->GetShape()) { - uint32_t elements = 1; - for (auto& dim : ts->GetShape()) + if (dim <= 0) { - if (dim <= 0) - { - DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), - dim); - this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); - return 1; - } - if (dim > static_cast(TOSA_MAX_TENSOR_SIZE / elements)) - { - // Size greather than maximum defined in spec - DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str()); - this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); - return 1; - } + DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim); + this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); + return 1; + } + if (dim > static_cast(TOSA_MAX_TENSOR_SIZE / elements)) + { + // Size greather than maximum defined in spec + DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str()); + this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); + return 1; } } + } - TosaReference::Tensor* tensor = findTensorByName(ts->GetName()); - SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str()); + TosaReference::Tensor* tensor = findTensorByName(ts->GetName()); + SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str()); - DEBUG_INFO(GT, "Allocating 
tensor %s", tensor->getName().c_str()); - if (tensor->allocate()) - { - FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); - } + DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str()); + if (tensor->allocate()) + { + FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str()); + } - if (!ts->GetData().empty()) + if (!ts->GetData().empty()) + { + DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str()); + auto serialization_dtype = ts->GetDtype(); + switch (serialization_dtype) { - DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str()); - auto serialization_dtype = ts->GetDtype(); - switch (serialization_dtype) - { - case DType_INT4: { - std::vector i4_data; - TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data); - std::vector i32_data(i4_data.begin(), i4_data.end()); - tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); - } - break; - case DType_INT8: { - std::vector i8_data; - TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data); - std::vector i32_data(i8_data.begin(), i8_data.end()); - tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); - } - break; - case DType_INT16: { - std::vector i16_data; - TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data); - std::vector i32_data(i16_data.begin(), i16_data.end()); - tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + case DType_INT4: { + std::vector i4_data; + TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data); + std::vector i32_data(i4_data.begin(), i4_data.end()); + tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; + case DType_INT8: { + std::vector i8_data; + TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data); + std::vector i32_data(i8_data.begin(), i8_data.end()); + 
tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; + case DType_INT16: { + std::vector i16_data; + TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data); + std::vector i32_data(i16_data.begin(), i16_data.end()); + tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; + case DType_INT32: { + std::vector i32_data; + TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data); + tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + } + break; + case DType_INT48: { + std::vector i64_data; + TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data); + tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); + } + break; + case DType_FP16: { + // Interpret f16 data as float + std::vector f16_data; + TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector f64_data(f16_data.begin(), f16_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); } - break; - case DType_INT32: { - std::vector i32_data; - TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data); - tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); + else + { + tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); } - break; - case DType_INT48: { - std::vector i64_data; - TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data); - tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); + } + break; + case DType_BF16: { + std::vector fp32_data; + TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + // Ensure valid bfloat16 stored in each float + for (auto f : fp32_data) + ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f); + if (tensor->getDtype() == 
TOSA_REF_TYPE_FP64) + { + std::vector f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); } - break; - case DType_FP16: { - // Interpret f16 data as float - std::vector f16_data; - TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data); - if (tensor->getDtype() == TOSA_REF_TYPE_FP64) - { - std::vector f64_data(f16_data.begin(), f16_data.end()); - tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); - } - else - { - tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); - } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); } - break; - case DType_BF16: { - std::vector fp32_data; - TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); - // Ensure valid bfloat16 stored in each float - for (auto f : fp32_data) - ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f); - if (tensor->getDtype() == TOSA_REF_TYPE_FP64) - { - std::vector f64_data(fp32_data.begin(), fp32_data.end()); - tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); - } - else - { - tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); - } + } + break; + case DType_FP32: { + std::vector fp32_data; + TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); } - break; - case DType_FP32: { - std::vector fp32_data; - TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); - if (tensor->getDtype() == TOSA_REF_TYPE_FP64) - { - std::vector f64_data(fp32_data.begin(), fp32_data.end()); - tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); - } - else - { - tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); - } + else 
+ { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); } - break; - case DType_BOOL: { - std::vector bool_data; - TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data); - - // std::vector::data() will return bit mask instead of array of bool array. - // Need to translate manually. - bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool)); - for (size_t i = 0; i < bool_data.size(); i++) - { - bool_array[i] = bool_data[i]; - } - tensor->setTensorValueBool(bool_data.size(), bool_array); + } + break; + case DType_BOOL: { + std::vector bool_data; + TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data); + + // std::vector::data() will return bit mask instead of array of bool array. + // Need to translate manually. + bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool)); + for (size_t i = 0; i < bool_data.size(); i++) + { + bool_array[i] = bool_data[i]; } - break; - default: - SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.", - EnumNameDType(ts->GetDtype())); + tensor->setTensorValueBool(bool_data.size(), bool_array); } + break; + default: + SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.", + EnumNameDType(ts->GetDtype())); } } @@ -593,11 +603,13 @@ int SubgraphTraverser::evaluateNextNode() for (auto tensor : currNode->getOutputs()) { if (!tensor->is_allocated()) - if (tensor->allocate()) + { + if (this->allocateTensor(tensor->getName())) { FATAL_ERROR("SubgraphTraverser::evaluateNextNode(): Failed to allocate Eigen tensor %s", tensor->getName().c_str()); } + } } if (currNode->eval()) @@ -607,41 +619,40 @@ int SubgraphTraverser::evaluateNextNode() } // free input tensor if all of its consumers have all of their outputs ready and it's not block's output - if (!currNode->getInMainBlock()) - { // we don't free it if the node is in main block and has nested blocks - for (auto 
tensor : currNode->getInputs()) + for (auto tensor : currNode->getInputs()) + { + bool in_use = false; + + auto tensor_check = findTensorByName(tensor->getName()); + if (tensor_check->getIsParentGraphOutput()) { - bool in_use = false; + // if it's parent's block output tensor, we can't free it + continue; + } - auto tensor_check = findTensorByName(tensor->getName()); - if (tensor_check->getIsParentGraphOutput()) + for (auto node : tensor->getConsumers()) + { + // If the node is inside a loop, the input tensor is still needed + if (!node->hasAllOutputsReady()) { - // if it's parent's block output tensor, we can't free it - continue; + in_use = true; } + } - for (auto node : tensor->getConsumers()) - { - // If the node is inside a loop, the input tensor is still needed - if (!node->hasAllOutputsReady()) - { - in_use = true; - } - } - for (auto name : block->GetOutputs()) + for (auto name : block->GetOutputs()) + { + if (name == tensor->getName()) { - if (name == tensor->getName()) - { - in_use = true; - } + in_use = true; } + } - if (!in_use) - { - tensor->deallocate(); - } + if (!in_use) + { + tensor->deallocate(); } } + // Search the output tensors of this node to see if // there are now new ready nodes available from completing this node for (TosaReference::Tensor* tensor : currNode->getOutputs()) diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h index 00989ee..ef6ea42 100644 --- a/reference_model/src/subgraph_traverser.h +++ b/reference_model/src/subgraph_traverser.h @@ -49,7 +49,8 @@ public: int linkTensorsAndNodes(); int validateGraph(); - int allocateTensor(); + int allocateInputTensors(); + int allocateTensor(std::string name); int dumpGraph(FILE* out) const; int dumpNextNodeList(FILE* out) const; -- cgit v1.2.1