aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClImportTensorHandleFactory.cpp
diff options
context:
space:
mode:
authorDavid Monahan <david.monahan@arm.com>2021-04-14 16:55:36 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-05-04 17:06:52 +0100
commite4a41dc5d5fc0f283c01b3260affdfdf6cfc1895 (patch)
treed0873f60422b46bdbe8ebe1dbef22bc87bf75175 /src/backends/cl/ClImportTensorHandleFactory.cpp
parentc21025dc3c07d60568dd27d816bcdf0575f7695a (diff)
downloadarmnn-e4a41dc5d5fc0f283c01b3260affdfdf6cfc1895.tar.gz
IVGCVSW-5727 Implement Import function of ClImportTensorHandle
* Split ClImportTensorHandle out from ClTenorHandle * Added implementation of Import function * Added Unit Tests Signed-off-by: David Monahan <david.monahan@arm.com> Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I07de2ca5bebf19dfb9a8dddea4b18340ffc31fad
Diffstat (limited to 'src/backends/cl/ClImportTensorHandleFactory.cpp')
-rw-r--r--src/backends/cl/ClImportTensorHandleFactory.cpp27
1 files changed, 12 insertions, 15 deletions
diff --git a/src/backends/cl/ClImportTensorHandleFactory.cpp b/src/backends/cl/ClImportTensorHandleFactory.cpp
index 1812034814..594e05423e 100644
--- a/src/backends/cl/ClImportTensorHandleFactory.cpp
+++ b/src/backends/cl/ClImportTensorHandleFactory.cpp
@@ -4,7 +4,7 @@
//
#include "ClImportTensorHandleFactory.hpp"
-#include "ClTensorHandle.hpp"
+#include "ClImportTensorHandle.hpp"
#include <armnn/utility/NumericCast.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
@@ -49,48 +49,45 @@ std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateSubTensorHandl
return nullptr;
}
- return std::make_unique<ClSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
+ return std::make_unique<ClImportSubTensorHandle>(
+ PolymorphicDowncast<IClImportTensorHandle*>(&parent), shape, coords);
}
std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
{
- return ClImportTensorHandleFactory::CreateTensorHandle(tensorInfo, false);
+ std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
+ GetImportFlags());
+ return tensorHandle;
}
std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
DataLayout dataLayout) const
{
- return ClImportTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, false);
+ std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
+ dataLayout,
+ GetImportFlags());
+ return tensorHandle;
}
std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
const bool IsMemoryManaged) const
{
- // If IsMemoryManaged is true then throw an exception.
if (IsMemoryManaged)
{
throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
}
- std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
- tensorHandle->SetImportEnabledFlag(true);
- tensorHandle->SetImportFlags(GetImportFlags());
- return tensorHandle;
+ return CreateTensorHandle(tensorInfo);
}
std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
DataLayout dataLayout,
const bool IsMemoryManaged) const
{
- // If IsMemoryManaged is true then throw an exception.
if (IsMemoryManaged)
{
throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
}
- std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
- // If we are not Managing the Memory then we must be importing
- tensorHandle->SetImportEnabledFlag(true);
- tensorHandle->SetImportFlags(GetImportFlags());
- return tensorHandle;
+ return CreateTensorHandle(tensorInfo, dataLayout);
}
const FactoryId& ClImportTensorHandleFactory::GetIdStatic()