diff options
Diffstat (limited to 'src/backends/cl/ClImportTensorHandleFactory.cpp')
-rw-r--r-- | src/backends/cl/ClImportTensorHandleFactory.cpp | 27 |
1 files changed, 12 insertions, 15 deletions
diff --git a/src/backends/cl/ClImportTensorHandleFactory.cpp b/src/backends/cl/ClImportTensorHandleFactory.cpp index 1812034814..594e05423e 100644 --- a/src/backends/cl/ClImportTensorHandleFactory.cpp +++ b/src/backends/cl/ClImportTensorHandleFactory.cpp @@ -4,7 +4,7 @@ // #include "ClImportTensorHandleFactory.hpp" -#include "ClTensorHandle.hpp" +#include "ClImportTensorHandle.hpp" #include <armnn/utility/NumericCast.hpp> #include <armnn/utility/PolymorphicDowncast.hpp> @@ -49,48 +49,45 @@ std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateSubTensorHandl return nullptr; } - return std::make_unique<ClSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords); + return std::make_unique<ClImportSubTensorHandle>( + PolymorphicDowncast<IClImportTensorHandle*>(&parent), shape, coords); } std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const { - return ClImportTensorHandleFactory::CreateTensorHandle(tensorInfo, false); + std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo, + GetImportFlags()); + return tensorHandle; } std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout) const { - return ClImportTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, false); + std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo, + dataLayout, + GetImportFlags()); + return tensorHandle; } std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged) const { - // If IsMemoryManaged is true then throw an exception. if (IsMemoryManaged) { throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors."); } - std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo); - tensorHandle->SetImportEnabledFlag(true); - tensorHandle->SetImportFlags(GetImportFlags()); - return tensorHandle; + return CreateTensorHandle(tensorInfo); } std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, const bool IsMemoryManaged) const { - // If IsMemoryManaged is true then throw an exception. if (IsMemoryManaged) { throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors."); } - std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout); - // If we are not Managing the Memory then we must be importing - tensorHandle->SetImportEnabledFlag(true); - tensorHandle->SetImportFlags(GetImportFlags()); - return tensorHandle; + return CreateTensorHandle(tensorInfo, dataLayout); } const FactoryId& ClImportTensorHandleFactory::GetIdStatic() |