diff options
Diffstat (limited to 'src/backends/tosaReference/TosaRefBackend.cpp')
-rw-r--r-- | src/backends/tosaReference/TosaRefBackend.cpp | 47 |
1 files changed, 40 insertions, 7 deletions
diff --git a/src/backends/tosaReference/TosaRefBackend.cpp b/src/backends/tosaReference/TosaRefBackend.cpp index e3a516a5f9..554bb10f0f 100644 --- a/src/backends/tosaReference/TosaRefBackend.cpp +++ b/src/backends/tosaReference/TosaRefBackend.cpp @@ -83,16 +83,20 @@ OptimizationViews TosaRefBackend::OptimizeSubgraphView(const SubgraphView& subgr const ModelOptions& modelOptions) const { OptimizationViews optimizationViews(modelOptions); + auto handler = std::make_unique<TosaSerializationHandler>(); - // A main block should only be added once. - bool isMain = true; + std::vector<std::string> graphInputs; + std::vector<std::string> graphOutputs; + + std::vector<TosaSerializationOperator*> operators; + std::vector<TosaSerializationTensor*> tensors; auto it = subgraph.endIConnectable(); while (it != subgraph.beginIConnectable()) { --it; - Layer &base = *(PolymorphicDowncast<Layer*>(*it)); + Layer& base = *(PolymorphicDowncast<Layer*>(*it)); if(base.GetType() == armnn::LayerType::Input || base.GetType() == armnn::LayerType::Output) @@ -100,15 +104,44 @@ OptimizationViews TosaRefBackend::OptimizeSubgraphView(const SubgraphView& subgr continue; } - tosa::TosaSerializationBasicBlock* mappings = GetTosaMappingFromLayer(&base, isMain); - handler.get()->GetBlocks().push_back(mappings); + tosa::TosaSerializationBasicBlock* mappings = GetTosaMappingFromLayer(&base); - if(isMain) + // Loop through inputs to see if there are any graph inputs, if so save them. + // If it's an input to the graph "input" can be found in the string. + for (uint32_t i = 0; i < mappings->GetInputs().size(); i++) { - isMain = false; + std::basic_string<char> blockInputName = mappings->GetInputs()[i]; + + if (blockInputName.find("input") != std::string::npos) + { + graphInputs.push_back(blockInputName); + } } + + // Loop through outputs to see if there are any graph outputs, if so save them. + // If it's an output to the graph "output" can be found in the string. + for (uint32_t i = 0; i < mappings->GetOutputs().size(); i++) + { + std::basic_string<char> blockOutputName = mappings->GetOutputs()[i]; + + if (blockOutputName.find("output") != std::string::npos) + { + graphOutputs.push_back(blockOutputName); + } + } + + auto blockOperators = mappings->GetOperators(); + operators.insert(operators.end(), blockOperators.begin(), blockOperators.end()); + + auto blockTensors = mappings->GetTensors(); + tensors.insert(tensors.end(), blockTensors.begin(), blockTensors.end()); } + // Add all mappings to main block, the TOSA Reference Model requires the full graph to be in one block called main. + auto* block = new TosaSerializationBasicBlock("main", operators, tensors, graphInputs, graphOutputs); + + handler.get()->GetBlocks().push_back(block); + auto compiledBlob = std::make_unique<PreCompiledObjectPtr>(handler.release(), DeleteAsType<TosaSerializationHandler>); |