aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/TosaRefBackend.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/TosaRefBackend.cpp')
-rw-r--r--src/backends/tosaReference/TosaRefBackend.cpp47
1 files changed, 40 insertions, 7 deletions
diff --git a/src/backends/tosaReference/TosaRefBackend.cpp b/src/backends/tosaReference/TosaRefBackend.cpp
index e3a516a5f9..554bb10f0f 100644
--- a/src/backends/tosaReference/TosaRefBackend.cpp
+++ b/src/backends/tosaReference/TosaRefBackend.cpp
@@ -83,16 +83,20 @@ OptimizationViews TosaRefBackend::OptimizeSubgraphView(const SubgraphView& subgr
const ModelOptions& modelOptions) const
{
OptimizationViews optimizationViews(modelOptions);
+
auto handler = std::make_unique<TosaSerializationHandler>();
- // A main block should only be added once.
- bool isMain = true;
+ std::vector<std::string> graphInputs;
+ std::vector<std::string> graphOutputs;
+
+ std::vector<TosaSerializationOperator*> operators;
+ std::vector<TosaSerializationTensor*> tensors;
auto it = subgraph.endIConnectable();
while (it != subgraph.beginIConnectable())
{
--it;
- Layer &base = *(PolymorphicDowncast<Layer*>(*it));
+ Layer& base = *(PolymorphicDowncast<Layer*>(*it));
if(base.GetType() == armnn::LayerType::Input ||
base.GetType() == armnn::LayerType::Output)
@@ -100,15 +104,44 @@ OptimizationViews TosaRefBackend::OptimizeSubgraphView(const SubgraphView& subgr
continue;
}
- tosa::TosaSerializationBasicBlock* mappings = GetTosaMappingFromLayer(&base, isMain);
- handler.get()->GetBlocks().push_back(mappings);
+ tosa::TosaSerializationBasicBlock* mappings = GetTosaMappingFromLayer(&base);
- if(isMain)
+ // Loop through inputs to see if there are any graph inputs, if so save them.
+ // If it's an input to the graph "input" can be found in the string.
+ for (uint32_t i = 0; i < mappings->GetInputs().size(); i++)
{
- isMain = false;
+ std::basic_string<char> blockInputName = mappings->GetInputs()[i];
+
+ if (blockInputName.find("input") != std::string::npos)
+ {
+ graphInputs.push_back(blockInputName);
+ }
}
+
+ // Loop through outputs to see if there are any graph outputs, if so save them.
+ // If it's an output to the graph "output" can be found in the string.
+ for (uint32_t i = 0; i < mappings->GetOutputs().size(); i++)
+ {
+ std::basic_string<char> blockOutputName = mappings->GetOutputs()[i];
+
+ if (blockOutputName.find("output") != std::string::npos)
+ {
+ graphOutputs.push_back(blockOutputName);
+ }
+ }
+
+ auto blockOperators = mappings->GetOperators();
+ operators.insert(operators.end(), blockOperators.begin(), blockOperators.end());
+
+ auto blockTensors = mappings->GetTensors();
+ tensors.insert(tensors.end(), blockTensors.begin(), blockTensors.end());
}
+ // Add all mappings to main block, the TOSA Reference Model requires the full graph to be in one block called main.
+ auto* block = new TosaSerializationBasicBlock("main", operators, tensors, graphInputs, graphOutputs);
+
+ handler.get()->GetBlocks().push_back(block);
+
auto compiledBlob =
std::make_unique<PreCompiledObjectPtr>(handler.release(), DeleteAsType<TosaSerializationHandler>);