diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnn/test/TensorHandleStrategyTest.cpp | 12 | ||||
-rw-r--r-- | src/backends/backendsCommon/ITensorHandleFactory.hpp | 23 | ||||
-rw-r--r-- | src/backends/cl/ClTensorHandleFactory.cpp | 11 | ||||
-rw-r--r-- | src/backends/cl/ClTensorHandleFactory.hpp | 5 | ||||
-rw-r--r-- | src/backends/neon/NeonTensorHandleFactory.cpp | 11 | ||||
-rw-r--r-- | src/backends/neon/NeonTensorHandleFactory.hpp | 5 | ||||
-rw-r--r-- | src/backends/reference/RefTensorHandleFactory.cpp | 9 | ||||
-rw-r--r-- | src/backends/reference/RefTensorHandleFactory.hpp | 6 |
8 files changed, 62 insertions, 20 deletions
diff --git a/src/armnn/test/TensorHandleStrategyTest.cpp b/src/armnn/test/TensorHandleStrategyTest.cpp index ceb6e4dbc2..3c53b13e1a 100644 --- a/src/armnn/test/TensorHandleStrategyTest.cpp +++ b/src/armnn/test/TensorHandleStrategyTest.cpp @@ -45,15 +45,13 @@ public: return nullptr; } - std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged) const override + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override { return nullptr; } std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout, - const bool IsMemoryManaged) const override + DataLayout dataLayout) const override { return nullptr; } @@ -85,15 +83,13 @@ public: return nullptr; } - std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged) const override + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override { return nullptr; } std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout, - const bool IsMemoryManaged) const override + DataLayout dataLayout) const override { return nullptr; } diff --git a/src/backends/backendsCommon/ITensorHandleFactory.hpp b/src/backends/backendsCommon/ITensorHandleFactory.hpp index c6deaef6bb..2e4742301b 100644 --- a/src/backends/backendsCommon/ITensorHandleFactory.hpp +++ b/src/backends/backendsCommon/ITensorHandleFactory.hpp @@ -8,6 +8,9 @@ #include <armnn/IRuntime.hpp> #include <armnn/MemorySources.hpp> #include <armnn/Types.hpp> +#include "ITensorHandle.hpp" + +#include <boost/core/ignore_unused.hpp> namespace armnn { @@ -25,12 +28,28 @@ public: TensorShape const& subTensorShape, unsigned int const* subTensorOrigin) const = 0; + virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const = 0; + + virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const = 0; + + // Utility Functions for backends which require TensorHandles to have unmanaged memory. + // These should be overloaded if required to facilitate direct import of input tensors + // and direct export of output tensors. virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged = true) const = 0; + const bool IsMemoryManaged) const + { + boost::ignore_unused(IsMemoryManaged); + return CreateTensorHandle(tensorInfo); + } virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, - const bool IsMemoryManaged = true) const = 0; + const bool IsMemoryManaged) const + { + boost::ignore_unused(IsMemoryManaged); + return CreateTensorHandle(tensorInfo, dataLayout); + } virtual const FactoryId& GetId() const = 0; diff --git a/src/backends/cl/ClTensorHandleFactory.cpp b/src/backends/cl/ClTensorHandleFactory.cpp index 3d9908a1ac..9df3f1a4a6 100644 --- a/src/backends/cl/ClTensorHandleFactory.cpp +++ b/src/backends/cl/ClTensorHandleFactory.cpp @@ -45,6 +45,17 @@ std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateSubTensorHandle(ITen boost::polymorphic_downcast<IClTensorHandle *>(&parent), shape, coords); } +std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const +{ + return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, true); +} + +std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const +{ + return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true); +} + std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged) const { diff --git a/src/backends/cl/ClTensorHandleFactory.hpp b/src/backends/cl/ClTensorHandleFactory.hpp index ea3728f7f7..f0d427a6fb 100644 --- a/src/backends/cl/ClTensorHandleFactory.hpp +++ b/src/backends/cl/ClTensorHandleFactory.hpp @@ -28,6 +28,11 @@ public: const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const override; + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override; + + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const override; + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged = true) const override; diff --git a/src/backends/neon/NeonTensorHandleFactory.cpp b/src/backends/neon/NeonTensorHandleFactory.cpp index 8296b8315c..4ccbb7b64f 100644 --- a/src/backends/neon/NeonTensorHandleFactory.cpp +++ b/src/backends/neon/NeonTensorHandleFactory.cpp @@ -39,6 +39,17 @@ std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateSubTensorHandle(IT boost::polymorphic_downcast<IAclTensorHandle*>(&parent), shape, coords); } +std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const +{ + return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, true); +} + +std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const +{ + return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true); +} + std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged) const { diff --git a/src/backends/neon/NeonTensorHandleFactory.hpp b/src/backends/neon/NeonTensorHandleFactory.hpp index b03433352e..d9b64045e6 100644 --- a/src/backends/neon/NeonTensorHandleFactory.hpp +++ b/src/backends/neon/NeonTensorHandleFactory.hpp @@ -26,6 +26,11 @@ public: const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const override; + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override; + + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const override; + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged = true) const override; diff --git a/src/backends/reference/RefTensorHandleFactory.cpp b/src/backends/reference/RefTensorHandleFactory.cpp index 089f5e3325..c97a779cb3 100644 --- a/src/backends/reference/RefTensorHandleFactory.cpp +++ b/src/backends/reference/RefTensorHandleFactory.cpp @@ -27,18 +27,15 @@ std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateSubTensorHandle(ITe return nullptr; } -std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged) const +std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const { - boost::ignore_unused(IsMemoryManaged); return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager, m_ImportFlags); } std::unique_ptr<ITensorHandle> RefTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout, - const bool IsMemoryManaged) const + DataLayout dataLayout) const { - boost::ignore_unused(dataLayout, IsMemoryManaged); + boost::ignore_unused(dataLayout); return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager, m_ImportFlags); } diff --git a/src/backends/reference/RefTensorHandleFactory.hpp b/src/backends/reference/RefTensorHandleFactory.hpp index ca6af72f71..220e6fd0de 100644 --- a/src/backends/reference/RefTensorHandleFactory.hpp +++ b/src/backends/reference/RefTensorHandleFactory.hpp @@ -28,12 +28,10 @@ public: TensorShape const& subTensorShape, unsigned int const* subTensorOrigin) const override; - std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, - const bool IsMemoryManaged = true) const override; + std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override; std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, - DataLayout dataLayout, - const bool IsMemoryManaged = true) const override; + DataLayout dataLayout) const override; static const FactoryId& GetIdStatic(); |