diff options
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index ac5159a855..c2bf27aa9b 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -24,6 +24,7 @@ #include <armnn/Logging.hpp> #include <armnn/utility/Assert.hpp> #include <armnn/utility/IgnoreUnused.hpp> +#include <armnn/utility/PolymorphicDowncast.hpp> #include <ProfilingService.hpp> @@ -53,12 +54,12 @@ armnn::INetworkPtr INetwork::Create() void INetwork::Destroy(INetwork* network) { - delete boost::polymorphic_downcast<Network*>(network); + delete PolymorphicDowncast<Network*>(network); } void IOptimizedNetwork::Destroy(IOptimizedNetwork* network) { - delete boost::polymorphic_downcast<OptimizedNetwork*>(network); + delete PolymorphicDowncast<OptimizedNetwork*>(network); } Status OptimizedNetwork::PrintGraph() @@ -149,7 +150,7 @@ bool CheckScaleSetOnQuantizedType(Layer* layer, Optional<std::vector<std::string template <typename LayerT> LayerT* ConvertBf16ToFp32Weight(Layer* l) { - LayerT* layer = boost::polymorphic_downcast<LayerT*>(l); + LayerT* layer = PolymorphicDowncast<LayerT*>(l); if ((layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected) && layer->m_Weight) { @@ -1015,12 +1016,12 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, throw InvalidArgumentException("BFloat16 and Float16 optimization cannot be enabled at the same time."); } - const Network& network = *boost::polymorphic_downcast<const Network*>(&inNetwork); + const Network& network = *PolymorphicDowncast<const Network*>(&inNetwork); std::unique_ptr<Graph> graph = std::make_unique<Graph>(network.GetGraph()); auto optNet = IOptimizedNetworkPtr(new OptimizedNetwork(std::move(graph)), &IOptimizedNetwork::Destroy); - OptimizedNetwork* optNetObjPtr = boost::polymorphic_downcast<OptimizedNetwork*>(optNet.get()); + OptimizedNetwork* optNetObjPtr = PolymorphicDowncast<OptimizedNetwork*>(optNet.get()); // Get the optimized graph Graph& optGraph = optNetObjPtr->GetGraph(); |