aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClImportTensorHandleFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/ClImportTensorHandleFactory.cpp')
-rw-r--r--src/backends/cl/ClImportTensorHandleFactory.cpp27
1 files changed, 12 insertions, 15 deletions
diff --git a/src/backends/cl/ClImportTensorHandleFactory.cpp b/src/backends/cl/ClImportTensorHandleFactory.cpp
index 1812034814..594e05423e 100644
--- a/src/backends/cl/ClImportTensorHandleFactory.cpp
+++ b/src/backends/cl/ClImportTensorHandleFactory.cpp
@@ -4,7 +4,7 @@
//
#include "ClImportTensorHandleFactory.hpp"
-#include "ClTensorHandle.hpp"
+#include "ClImportTensorHandle.hpp"
#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
@@ -49,48 +49,45 @@ std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateSubTensorHandl
return nullptr;
}
- return std::make_unique<ClSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
+ return std::make_unique<ClImportSubTensorHandle>(
+ PolymorphicDowncast<IClImportTensorHandle*>(&parent), shape, coords);
}
std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
{
- return ClImportTensorHandleFactory::CreateTensorHandle(tensorInfo, false);
+ std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
+ GetImportFlags());
+ return tensorHandle;
}
std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
DataLayout dataLayout) const
{
- return ClImportTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, false);
+ std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
+ dataLayout,
+ GetImportFlags());
+ return tensorHandle;
}
std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
const bool IsMemoryManaged) const
{
- // If IsMemoryManaged is true then throw an exception.
if (IsMemoryManaged)
{
throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
}
- std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
- tensorHandle->SetImportEnabledFlag(true);
- tensorHandle->SetImportFlags(GetImportFlags());
- return tensorHandle;
+ return CreateTensorHandle(tensorInfo);
}
std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
DataLayout dataLayout,
const bool IsMemoryManaged) const
{
- // If IsMemoryManaged is true then throw an exception.
if (IsMemoryManaged)
{
throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
}
- std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
- // If we are not Managing the Memory then we must be importing
- tensorHandle->SetImportEnabledFlag(true);
- tensorHandle->SetImportFlags(GetImportFlags());
- return tensorHandle;
+ return CreateTensorHandle(tensorInfo, dataLayout);
}
const FactoryId& ClImportTensorHandleFactory::GetIdStatic()