diff options
Diffstat (limited to 'src/graph/Utils.cpp')
-rw-r--r-- | src/graph/Utils.cpp | 94 |
1 files changed, 62 insertions, 32 deletions
diff --git a/src/graph/Utils.cpp b/src/graph/Utils.cpp index 8e12689fb9..452d8ec7b2 100644 --- a/src/graph/Utils.cpp +++ b/src/graph/Utils.cpp @@ -23,8 +23,8 @@ */ #include "arm_compute/graph/Utils.h" -#include "arm_compute/graph/GraphContext.h" #include "arm_compute/graph/backends/BackendRegistry.h" +#include "arm_compute/graph/GraphContext.h" #include "arm_compute/graph/mutators/GraphMutators.h" namespace arm_compute @@ -33,41 +33,38 @@ namespace graph { bool is_target_supported(Target target) { - return backends::BackendRegistry::get().contains(target) && backends::BackendRegistry::get().find_backend(target)->is_backend_supported(); + return backends::BackendRegistry::get().contains(target) && + backends::BackendRegistry::get().find_backend(target)->is_backend_supported(); } Target get_default_target() { - if(is_target_supported(Target::NEON)) + if (is_target_supported(Target::NEON)) { return Target::NEON; } - if(is_target_supported(Target::CL)) + if (is_target_supported(Target::CL)) { return Target::CL; } - if(is_target_supported(Target::GC)) - { - return Target::GC; - } ARM_COMPUTE_ERROR("No backend exists!"); } void force_target_to_graph(Graph &g, Target target) { auto &nodes = g.nodes(); - for(auto &node : nodes) + for (auto &node : nodes) { - if(node) + if (node) { node->set_assigned_target(target); } } auto &tensors = g.tensors(); - for(auto &tensor : tensors) + for (auto &tensor : tensors) { - if(tensor) + if (tensor) { tensor->desc().target = target; } @@ -76,19 +73,18 @@ void force_target_to_graph(Graph &g, Target target) PassManager create_default_pass_manager(Target target, const GraphConfig &cfg) { + ARM_COMPUTE_UNUSED(target); PassManager pm; - const bool is_target_gc = target == Target::GC; - // Passes that mutate graph IR - if(cfg.use_synthetic_type) + if (cfg.use_synthetic_type) { - switch(cfg.synthetic_type) + switch (cfg.synthetic_type) { case DataType::QASYMM8: case DataType::QASYMM8_SIGNED: { - pm.append(std::make_unique<SyntheticDataTypeMutator>(cfg.synthetic_type), !is_target_gc); + pm.append(std::make_unique<SyntheticDataTypeMutator>(cfg.synthetic_type)); break; } default: @@ -98,13 +94,13 @@ PassManager create_default_pass_manager(Target target, const GraphConfig &cfg) } } } - pm.append(std::make_unique<NodeFusionMutator>(), !is_target_gc); + pm.append(std::make_unique<NodeFusionMutator>()); pm.append(std::make_unique<GroupedConvolutionMutator>()); - pm.append(std::make_unique<InPlaceOperationMutator>(), !is_target_gc); + pm.append(std::make_unique<InPlaceOperationMutator>()); // Passes that mutate backend information - pm.append(std::make_unique<DepthConcatSubTensorMutator>(), !is_target_gc); - pm.append(std::make_unique<SplitLayerSubTensorMutator>(), !is_target_gc); + pm.append(std::make_unique<DepthConcatSubTensorMutator>()); + pm.append(std::make_unique<SplitLayerSubTensorMutator>()); pm.append(std::make_unique<NodeExecutionMethodMutator>()); return pm; @@ -112,21 +108,32 @@ PassManager create_default_pass_manager(Target target, const GraphConfig &cfg) void release_default_graph_context(GraphContext &ctx) { - for(const auto &backend : backends::BackendRegistry::get().backends()) + for (const auto &backend : backends::BackendRegistry::get().backends()) { - if(backend.second->is_backend_supported()) + if (backend.second->is_backend_supported()) { backend.second->release_backend_context(ctx); } } } +void sync_backends() +{ + for (const auto &backend : backends::BackendRegistry::get().backends()) + { + if (backend.second->backend_allocator()) + { + backend.second->sync(); + } + } +} + void setup_requested_backend_context(GraphContext &ctx, Target target) { - if(backends::BackendRegistry::get().contains(target)) + if (backends::BackendRegistry::get().contains(target)) { const auto &backend = backends::BackendRegistry::get().find_backend(target); - if(backend->is_backend_supported()) + if (backend->is_backend_supported()) { backend->setup_backend_context(ctx); } @@ -135,20 +142,22 @@ void setup_requested_backend_context(GraphContext &ctx, Target target) size_t get_dimension_size(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension) { - ARM_COMPUTE_ERROR_ON_MSG(descriptor.layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); + ARM_COMPUTE_ERROR_ON_MSG(descriptor.layout == DataLayout::UNKNOWN, + "Cannot retrieve the dimension index for an unknown layout!"); return descriptor.shape[get_dimension_idx(descriptor.layout, data_layout_dimension)]; } size_t get_dimension_idx(DataLayout data_layout, const DataLayoutDimension data_layout_dimension) { - ARM_COMPUTE_ERROR_ON_MSG(data_layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); + ARM_COMPUTE_ERROR_ON_MSG(data_layout == DataLayout::UNKNOWN, + "Cannot retrieve the dimension index for an unknown layout!"); /* Return the index based on the data layout * [N C H W] * [3 2 1 0] * [N H W C] */ - switch(data_layout_dimension) + switch (data_layout_dimension) { case DataLayoutDimension::CHANNEL: return (data_layout == DataLayout::NCHW) ? 2 : 0; @@ -175,22 +184,42 @@ std::vector<NodeIdxPair> get_driving_nodes(const INode &node) const Graph *g = node.graph(); ARM_COMPUTE_ERROR_ON(g == nullptr); - for(auto &output_edge_id : node.output_edges()) + for (auto &output_edge_id : node.output_edges()) { auto output_edge = g->edge(output_edge_id); - if(output_edge != nullptr) + if (output_edge != nullptr) { ARM_COMPUTE_ERROR_ON(output_edge->consumer() == nullptr); - driving_nodes.push_back({ output_edge->consumer_id(), output_edge->consumer_idx() }); + driving_nodes.push_back({output_edge->consumer_id(), output_edge->consumer_idx()}); } } return driving_nodes; } +std::vector<NodeIdxPair> get_driver_nodes(const INode &node) +{ + std::vector<NodeIdxPair> driver_nodes; + + const Graph *g = node.graph(); + ARM_COMPUTE_ERROR_ON(g == nullptr); + + for (auto &input_edge_id : node.input_edges()) + { + auto input_edge = g->edge(input_edge_id); + if (input_edge != nullptr) + { + ARM_COMPUTE_ERROR_ON(input_edge->producer() == nullptr); + driver_nodes.push_back({input_edge->producer_id(), input_edge->producer_idx()}); + } + } + + return driver_nodes; +} + void configure_tensor(Tensor *tensor) { - if(tensor != nullptr && tensor->handle() == nullptr) + if (tensor != nullptr && tensor->handle() == nullptr) { Target target = tensor->desc().target; backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(target); @@ -199,5 +228,6 @@ void configure_tensor(Tensor *tensor) tensor->set_handle(std::move(handle)); } } + } // namespace graph } // namespace arm_compute |