From 4fcda0101ec3d110c1d6d7bee5c83416b645528a Mon Sep 17 00:00:00 2001 From: telsoa01 Date: Fri, 9 Mar 2018 14:13:49 +0000 Subject: Release 18.02 Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6 --- src/armnn/optimizations/Optimization.hpp | 123 +++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 src/armnn/optimizations/Optimization.hpp (limited to 'src/armnn/optimizations/Optimization.hpp') diff --git a/src/armnn/optimizations/Optimization.hpp b/src/armnn/optimizations/Optimization.hpp new file mode 100644 index 0000000000..89e03ff88d --- /dev/null +++ b/src/armnn/optimizations/Optimization.hpp @@ -0,0 +1,123 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "Graph.hpp" +#include "LayersFwd.hpp" + +namespace armnn +{ + +class Optimization +{ +public: + virtual void Run(Graph& graph, Graph::Iterator& pos) const = 0; +protected: + ~Optimization() = default; +}; + +// Wrappers +// The implementation of the following wrappers make use of the CRTP C++ idiom +// (curiously recurring template pattern). +// For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern + +/// Wrapper Optimization base class that calls Wrapped::Run for every layer of type BaseType. +/// - Wrapped class mustn't remove the base layer. +/// - Base layer is removed if left unconnected after applying the wrapped optimization. +template +class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped +{ +public: + using Wrapped::Wrapped; + + void Run(Graph& graph, Graph::Iterator& pos) const override + { + Layer* const base = *pos; + + if (base->GetType() == LayerEnumOf()) + { + Wrapped::Run(graph, *boost::polymorphic_downcast(base)); + } + } + +protected: + ~OptimizeForTypeImpl() = default; +}; + +/// Specialization that calls Wrapped::Run for any layer type +template +class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped +{ +public: + using Wrapped::Wrapped; + + void Run(Graph& graph, Graph::Iterator& pos) const override + { + Wrapped::Run(graph, **pos); + } + +protected: + ~OptimizeForTypeImpl() = default; +}; + +template +class OptimizeForType final : public OptimizeForTypeImpl +{ +public: + using OptimizeForTypeImpl::OptimizeForTypeImpl; +}; + +/// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType. +/// - Wrapped class mustn't remove the base layer. +/// - Wrapped class mustn't affect existing connections in the same output. It might add new ones. +/// - Base and children layers are removed if left unconnected after applying the wrapped optimization. +template +class OptimizeForConnectionImpl : public Wrapped +{ +public: + using Wrapped::Wrapped; + + void Run(Graph& graph, BaseType& base) const + { + for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output) + { + for (auto&& childInput : output->GetConnections()) + { + if (childInput->GetOwningLayer().GetType() == LayerEnumOf()) + { + Wrapped::Run(graph, *childInput); + } + } + + // Remove unconnected children + for (unsigned int i = 0; i < output->GetNumConnections();) + { + Layer* child = &output->GetConnection(i)->GetOwningLayer(); + + if (child->IsOutputUnconnected()) + { + graph.EraseLayer(child); + } + else + { + ++i; + } + } + } + } + +protected: + ~OptimizeForConnectionImpl() = default; +}; + +template +class OptimizeForConnection final + : public OptimizeForTypeImpl> +{ +public: + using OptimizeForTypeImpl>::OptimizeForTypeImpl; +}; + +} // namespace armnn -- cgit v1.2.1