aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/workloads')
-rw-r--r--src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp38
-rw-r--r--src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.hpp7
-rw-r--r--src/backends/cl/workloads/ClConvertFp16ToFp32Workload.cpp46
-rw-r--r--src/backends/cl/workloads/ClConvertFp16ToFp32Workload.hpp13
-rw-r--r--src/backends/cl/workloads/ClConvertFp32ToFp16Workload.cpp46
-rw-r--r--src/backends/cl/workloads/ClConvertFp32ToFp16Workload.hpp13
-rw-r--r--src/backends/cl/workloads/ClConvolution2dWorkload.cpp2
-rw-r--r--src/backends/cl/workloads/ClConvolution2dWorkload.hpp2
-rw-r--r--src/backends/cl/workloads/ClFloorFloatWorkload.cpp39
-rw-r--r--src/backends/cl/workloads/ClFloorFloatWorkload.hpp5
-rw-r--r--src/backends/cl/workloads/ClL2NormalizationFloatWorkload.cpp38
-rw-r--r--src/backends/cl/workloads/ClL2NormalizationFloatWorkload.hpp6
-rw-r--r--src/backends/cl/workloads/ClLstmFloatWorkload.cpp38
-rw-r--r--src/backends/cl/workloads/ClLstmFloatWorkload.hpp5
-rw-r--r--src/backends/cl/workloads/ClNormalizationFloatWorkload.cpp38
-rw-r--r--src/backends/cl/workloads/ClNormalizationFloatWorkload.hpp5
16 files changed, 337 insertions, 4 deletions
diff --git a/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp b/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp
index 992abc2f56..389605f17d 100644
--- a/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.cpp
@@ -124,4 +124,42 @@ void ClBatchNormalizationFloatWorkload::FreeUnusedTensors()
FreeTensorIfUnused(m_Beta);
}
+void ClBatchNormalizationFloatWorkload::ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
+    this->m_Data.m_Inputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original input handle, then rethrow (throw; preserves the dynamic type).
+        this->m_Data.m_Inputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+// Replace output tensor handle with the given TensorHandle
+void ClBatchNormalizationFloatWorkload::ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    // Fix: operate on m_Outputs, not m_Inputs (copy-paste from the input variant corrupted the input slot).
+    ITensorHandle* backupHandle = this->m_Data.m_Outputs[slot];
+    this->m_Data.m_Outputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original output handle, then rethrow.
+        this->m_Data.m_Outputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+void ClBatchNormalizationFloatWorkload::Reconfigure()
+{
+ throw armnn::UnimplementedException("Reconfigure not implemented for this workload");
+}
+
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.hpp b/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.hpp
index dc76703382..d47663671e 100644
--- a/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.hpp
@@ -32,6 +32,12 @@ public:
using FloatWorkload<BatchNormalizationQueueDescriptor>::FloatWorkload;
void Execute() const override;
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
+
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
+
private:
mutable arm_compute::CLBatchNormalizationLayer m_Layer;
@@ -41,6 +47,7 @@ private:
std::unique_ptr<arm_compute::CLTensor> m_Beta;
void FreeUnusedTensors();
+ virtual void Reconfigure();
};
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClConvertFp16ToFp32Workload.cpp b/src/backends/cl/workloads/ClConvertFp16ToFp32Workload.cpp
index 867770a112..8ccf157aca 100644
--- a/src/backends/cl/workloads/ClConvertFp16ToFp32Workload.cpp
+++ b/src/backends/cl/workloads/ClConvertFp16ToFp32Workload.cpp
@@ -25,9 +25,13 @@ ClConvertFp16ToFp32Workload::ClConvertFp16ToFp32Workload(
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor();
+ // Create Proxy tensor and set the initial tensor handle to it
+ m_InputProxy = std::make_unique<ICLTensorProxy>(&input);
+ m_OutputProxy = std::make_unique<ICLTensorProxy>(&output);
+
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClConvertFp16ToFp32Workload_configure");
- m_Layer.configure(clCompileContext, &input, &output, g_AclConvertPolicy, 0);
+ m_Layer.configure(clCompileContext, m_InputProxy.get(), m_OutputProxy.get(), g_AclConvertPolicy, 0);
}
}
@@ -57,5 +61,45 @@ arm_compute::Status ClConvertFp16ToFp32WorkloadValidate(const TensorInfo& input,
return aclStatus;
}
+void ClConvertFp16ToFp32Workload::ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
+    this->m_Data.m_Inputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original input handle, then rethrow (throw; preserves the dynamic type).
+        this->m_Data.m_Inputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+// Replace output tensor handle with the given TensorHandle
+void ClConvertFp16ToFp32Workload::ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    // Fix: operate on m_Outputs, not m_Inputs — Reconfigure() reads m_Outputs[0], so the
+    // replacement handle never took effect and the input slot was corrupted.
+    ITensorHandle* backupHandle = this->m_Data.m_Outputs[slot];
+    this->m_Data.m_Outputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        this->m_Data.m_Outputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+void ClConvertFp16ToFp32Workload::Reconfigure()
+{
+ arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ m_InputProxy->set(&input);
+ m_OutputProxy->set(&output);
+}
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClConvertFp16ToFp32Workload.hpp b/src/backends/cl/workloads/ClConvertFp16ToFp32Workload.hpp
index b392c0be2e..3c6fcd6c08 100644
--- a/src/backends/cl/workloads/ClConvertFp16ToFp32Workload.hpp
+++ b/src/backends/cl/workloads/ClConvertFp16ToFp32Workload.hpp
@@ -9,6 +9,8 @@
#include <arm_compute/runtime/CL/functions/CLDepthConvertLayer.h>
+#include <cl/ICLTensorProxy.hpp>
+
namespace armnn
{
@@ -21,8 +23,19 @@ public:
const arm_compute::CLCompileContext& clCompileContext);
virtual void Execute() const override;
+ bool SupportsTensorHandleReplacement() const override { return true;};
+
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
+
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
private:
mutable arm_compute::CLDepthConvertLayer m_Layer;
+ virtual void Reconfigure();
+
+ std::unique_ptr<ICLTensorProxy> m_InputProxy;
+ std::unique_ptr<ICLTensorProxy> m_OutputProxy;
};
arm_compute::Status ClConvertFp16ToFp32WorkloadValidate(const TensorInfo& input, const TensorInfo& output);
diff --git a/src/backends/cl/workloads/ClConvertFp32ToFp16Workload.cpp b/src/backends/cl/workloads/ClConvertFp32ToFp16Workload.cpp
index 017fcaf454..a44a80c997 100644
--- a/src/backends/cl/workloads/ClConvertFp32ToFp16Workload.cpp
+++ b/src/backends/cl/workloads/ClConvertFp32ToFp16Workload.cpp
@@ -25,9 +25,13 @@ ClConvertFp32ToFp16Workload::ClConvertFp32ToFp16Workload(
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor();
+ // Create Proxy tensor and set the initial tensor handle to it
+ m_InputProxy = std::make_unique<ICLTensorProxy>(&input);
+ m_OutputProxy = std::make_unique<ICLTensorProxy>(&output);
+
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClConvertFp32ToFp16Workload_configure");
- m_Layer.configure(clCompileContext, &input, &output, g_AclConvertPolicy, 0);
+ m_Layer.configure(clCompileContext, m_InputProxy.get(), m_OutputProxy.get(), g_AclConvertPolicy, 0);
}
}
@@ -57,5 +61,45 @@ arm_compute::Status ClConvertFp32ToFp16WorkloadValidate(const TensorInfo& input,
return aclStatus;
}
+void ClConvertFp32ToFp16Workload::ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
+    this->m_Data.m_Inputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original input handle, then rethrow (throw; preserves the dynamic type).
+        this->m_Data.m_Inputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+// Replace output tensor handle with the given TensorHandle
+void ClConvertFp32ToFp16Workload::ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    // Fix: operate on m_Outputs, not m_Inputs — Reconfigure() reads m_Outputs[0], so the
+    // replacement handle never took effect and the input slot was corrupted.
+    ITensorHandle* backupHandle = this->m_Data.m_Outputs[slot];
+    this->m_Data.m_Outputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        this->m_Data.m_Outputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+void ClConvertFp32ToFp16Workload::Reconfigure()
+{
+ arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+ m_InputProxy->set(&input);
+ m_OutputProxy->set(&output);
+}
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClConvertFp32ToFp16Workload.hpp b/src/backends/cl/workloads/ClConvertFp32ToFp16Workload.hpp
index 1d777b5256..6ce563e4f4 100644
--- a/src/backends/cl/workloads/ClConvertFp32ToFp16Workload.hpp
+++ b/src/backends/cl/workloads/ClConvertFp32ToFp16Workload.hpp
@@ -9,6 +9,8 @@
#include <arm_compute/runtime/CL/functions/CLDepthConvertLayer.h>
+#include <cl/ICLTensorProxy.hpp>
+
namespace armnn
{
@@ -21,8 +23,19 @@ public:
const arm_compute::CLCompileContext& clCompileContext);
virtual void Execute() const override;
+ bool SupportsTensorHandleReplacement() const override { return true;};
+
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
+
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
private:
mutable arm_compute::CLDepthConvertLayer m_Layer;
+ virtual void Reconfigure();
+
+ std::unique_ptr<ICLTensorProxy> m_InputProxy;
+ std::unique_ptr<ICLTensorProxy> m_OutputProxy;
};
arm_compute::Status ClConvertFp32ToFp16WorkloadValidate(const TensorInfo& input, const TensorInfo& output);
diff --git a/src/backends/cl/workloads/ClConvolution2dWorkload.cpp b/src/backends/cl/workloads/ClConvolution2dWorkload.cpp
index cdfa885f67..bf82fbf255 100644
--- a/src/backends/cl/workloads/ClConvolution2dWorkload.cpp
+++ b/src/backends/cl/workloads/ClConvolution2dWorkload.cpp
@@ -180,9 +180,9 @@ void ClConvolution2dWorkload::FreeUnusedTensors()
void ClConvolution2dWorkload::Reconfigure()
{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClConvolution2dWorkload_Reconfigure");
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
m_InputProxy->set(&input);
m_OutputProxy->set(&output);
}
diff --git a/src/backends/cl/workloads/ClConvolution2dWorkload.hpp b/src/backends/cl/workloads/ClConvolution2dWorkload.hpp
index 891d5096cd..e4177e4327 100644
--- a/src/backends/cl/workloads/ClConvolution2dWorkload.hpp
+++ b/src/backends/cl/workloads/ClConvolution2dWorkload.hpp
@@ -40,6 +40,8 @@ public:
arm_compute::ConvolutionMethod GetConvolutionMethod() const;
+ bool SupportsTensorHandleReplacement() const override { return true;};
+
protected:
void Reconfigure() override;
diff --git a/src/backends/cl/workloads/ClFloorFloatWorkload.cpp b/src/backends/cl/workloads/ClFloorFloatWorkload.cpp
index 679e225c63..0aae1a30e3 100644
--- a/src/backends/cl/workloads/ClFloorFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClFloorFloatWorkload.cpp
@@ -29,7 +29,6 @@ ClFloorFloatWorkload::ClFloorFloatWorkload(const FloorQueueDescriptor& descripto
arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
-
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClFloorFloatWorkload_configure");
m_Layer.configure(clCompileContext, &input, &output);
@@ -42,4 +41,42 @@ void ClFloorFloatWorkload::Execute() const
RunClFunction(m_Layer, CHECK_LOCATION());
}
+void ClFloorFloatWorkload::ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
+    this->m_Data.m_Inputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original input handle, then rethrow (throw; preserves the dynamic type).
+        this->m_Data.m_Inputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+// Replace output tensor handle with the given TensorHandle
+void ClFloorFloatWorkload::ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    // Fix: operate on m_Outputs, not m_Inputs (copy-paste from the input variant corrupted the input slot).
+    ITensorHandle* backupHandle = this->m_Data.m_Outputs[slot];
+    this->m_Data.m_Outputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original output handle, then rethrow.
+        this->m_Data.m_Outputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+void ClFloorFloatWorkload::Reconfigure()
+{
+ throw armnn::UnimplementedException("Reconfigure not implemented for this workload");
+}
+
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClFloorFloatWorkload.hpp b/src/backends/cl/workloads/ClFloorFloatWorkload.hpp
index 5740c6887a..dbe5f6f163 100644
--- a/src/backends/cl/workloads/ClFloorFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClFloorFloatWorkload.hpp
@@ -23,9 +23,14 @@ public:
const arm_compute::CLCompileContext& clCompileContext);
void Execute() const override;
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
private:
mutable arm_compute::CLFloor m_Layer;
+ virtual void Reconfigure();
};
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClL2NormalizationFloatWorkload.cpp b/src/backends/cl/workloads/ClL2NormalizationFloatWorkload.cpp
index b34153fff0..d120fb28f6 100644
--- a/src/backends/cl/workloads/ClL2NormalizationFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClL2NormalizationFloatWorkload.cpp
@@ -60,4 +60,42 @@ void ClL2NormalizationFloatWorkload::Execute() const
RunClFunction(m_Layer, CHECK_LOCATION());
}
+void ClL2NormalizationFloatWorkload::ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
+    this->m_Data.m_Inputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original input handle, then rethrow (throw; preserves the dynamic type).
+        this->m_Data.m_Inputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+// Replace output tensor handle with the given TensorHandle
+void ClL2NormalizationFloatWorkload::ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    // Fix: operate on m_Outputs, not m_Inputs (copy-paste from the input variant corrupted the input slot).
+    ITensorHandle* backupHandle = this->m_Data.m_Outputs[slot];
+    this->m_Data.m_Outputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original output handle, then rethrow.
+        this->m_Data.m_Outputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+void ClL2NormalizationFloatWorkload::Reconfigure()
+{
+ throw armnn::UnimplementedException("Reconfigure not implemented for this workload");
+}
+
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClL2NormalizationFloatWorkload.hpp b/src/backends/cl/workloads/ClL2NormalizationFloatWorkload.hpp
index cfa1a97eec..67e7b8b7b1 100644
--- a/src/backends/cl/workloads/ClL2NormalizationFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClL2NormalizationFloatWorkload.hpp
@@ -24,10 +24,16 @@ public:
const arm_compute::CLCompileContext& clCompileContext);
void Execute() const override;
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
+
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
private:
// Purposely not a CLL2Normalize function. See constructor.
mutable arm_compute::CLL2NormalizeLayer m_Layer;
+ virtual void Reconfigure();
};
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
index d8d95f5c74..37dfab6a5f 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp
@@ -446,4 +446,42 @@ void ClLstmFloatWorkload::FreeUnusedTensors()
FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
}
+void ClLstmFloatWorkload::ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
+    this->m_Data.m_Inputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original input handle, then rethrow (throw; preserves the dynamic type).
+        this->m_Data.m_Inputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+// Replace output tensor handle with the given TensorHandle
+void ClLstmFloatWorkload::ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    // Fix: operate on m_Outputs, not m_Inputs (copy-paste from the input variant corrupted the input slot).
+    ITensorHandle* backupHandle = this->m_Data.m_Outputs[slot];
+    this->m_Data.m_Outputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original output handle, then rethrow.
+        this->m_Data.m_Outputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+void ClLstmFloatWorkload::Reconfigure()
+{
+ throw armnn::UnimplementedException("Reconfigure not implemented for this workload");
+}
+
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
index b9faca8b54..54c5c600dc 100644
--- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp
@@ -22,9 +22,14 @@ public:
const WorkloadInfo& info,
const arm_compute::CLCompileContext& clCompileContext);
void Execute() const override;
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
private:
mutable arm_compute::CLLSTMLayer m_LstmLayer;
+ virtual void Reconfigure();
std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
diff --git a/src/backends/cl/workloads/ClNormalizationFloatWorkload.cpp b/src/backends/cl/workloads/ClNormalizationFloatWorkload.cpp
index d98532d7d1..8de8dd5c3b 100644
--- a/src/backends/cl/workloads/ClNormalizationFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClNormalizationFloatWorkload.cpp
@@ -62,4 +62,42 @@ void ClNormalizationFloatWorkload::Execute() const
RunClFunction(m_NormalizationLayer, CHECK_LOCATION());
}
+void ClNormalizationFloatWorkload::ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    ITensorHandle* backupHandle = this->m_Data.m_Inputs[slot];
+    this->m_Data.m_Inputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original input handle, then rethrow (throw; preserves the dynamic type).
+        this->m_Data.m_Inputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+// Replace output tensor handle with the given TensorHandle
+void ClNormalizationFloatWorkload::ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot)
+{
+    // Fix: operate on m_Outputs, not m_Inputs (copy-paste from the input variant corrupted the input slot).
+    ITensorHandle* backupHandle = this->m_Data.m_Outputs[slot];
+    this->m_Data.m_Outputs[slot] = tensorHandle;
+    try
+    {
+        Reconfigure();
+    }
+    catch(const armnn::UnimplementedException&)
+    {
+        // Cannot reconfigure: restore the original output handle, then rethrow.
+        this->m_Data.m_Outputs[slot] = backupHandle;
+        throw;
+    }
+}
+
+void ClNormalizationFloatWorkload::Reconfigure()
+{
+ throw armnn::UnimplementedException("Reconfigure not implemented for this workload");
+}
+
} //namespace armnn
diff --git a/src/backends/cl/workloads/ClNormalizationFloatWorkload.hpp b/src/backends/cl/workloads/ClNormalizationFloatWorkload.hpp
index 40b2693cd4..d9db0f2de3 100644
--- a/src/backends/cl/workloads/ClNormalizationFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClNormalizationFloatWorkload.hpp
@@ -23,9 +23,14 @@ public:
const WorkloadInfo& info,
const arm_compute::CLCompileContext& clCompileContext);
void Execute() const override;
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
private:
mutable arm_compute::CLNormalizationLayer m_NormalizationLayer;
+ virtual void Reconfigure();
};
} //namespace armnn