aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-10-04 13:10:16 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-10-09 10:43:00 +0000
commit8168f407f0f2715250f388089f26ed39683ac00a (patch)
treefeae8f29b6910aba496141c5672a2fdac0b8efe5
parent784db773ec0fb32562e8889b769bc04450159161 (diff)
downloadarmnn-8168f407f0f2715250f388089f26ed39683ac00a.tar.gz
IVGCVSW-3889 Add CL workload for INSTANCE_NORMALIZATION
!android-nn-driver:2039 Signed-off-by: Kevin May <kevin.may@arm.com> Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: I621dd80920b58b8b795ed13917b88850519c8177
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp16
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp10
-rw-r--r--src/backends/backendsCommon/common.mk1
-rw-r--r--src/backends/backendsCommon/test/CMakeLists.txt2
-rw-r--r--src/backends/backendsCommon/test/LayerTests.hpp1
-rw-r--r--src/backends/backendsCommon/test/layerTests/InstanceNormalizationTestImpl.cpp267
-rw-r--r--src/backends/backendsCommon/test/layerTests/InstanceNormalizationTestImpl.hpp36
-rw-r--r--src/backends/cl/ClLayerSupport.cpp13
-rw-r--r--src/backends/cl/ClLayerSupport.hpp5
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp355
-rw-r--r--src/backends/cl/ClWorkloadFactory.hpp173
-rw-r--r--src/backends/cl/backend.mk1
-rw-r--r--src/backends/cl/test/ClLayerTests.cpp13
-rw-r--r--src/backends/cl/workloads/CMakeLists.txt2
-rw-r--r--src/backends/cl/workloads/ClInstanceNormalizationWorkload.cpp59
-rw-r--r--src/backends/cl/workloads/ClInstanceNormalizationWorkload.hpp29
-rw-r--r--src/backends/cl/workloads/ClWorkloads.hpp1
17 files changed, 700 insertions, 284 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index aca5023f97..89277d798f 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1261,22 +1261,6 @@ void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workload
ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
-
- ValidatePointer(m_Beta, descriptorName, "beta");
- ValidatePointer(m_Eps, descriptorName, "epsilon");
- ValidatePointer(m_Gamma, descriptorName, "gamma");
-
- const TensorInfo& beta = m_Beta->GetTensorInfo();
- const TensorInfo& epsilon = m_Eps->GetTensorInfo();
- const TensorInfo& gamma = m_Gamma->GetTensorInfo();
-
- ValidateTensorNumDimensions(beta, descriptorName, 1, "beta");
- ValidateTensorNumDimensions(epsilon, descriptorName, 1, "epsilon");
- ValidateTensorNumDimensions(gamma, descriptorName, 1, "gamma");
-
- ValidateTensorDataTypesMatch(inputTensorInfo, beta, descriptorName, "input", "beta");
- ValidateTensorDataTypesMatch(inputTensorInfo, epsilon, descriptorName, "input", "epsilon");
- ValidateTensorDataTypesMatch(inputTensorInfo, gamma, descriptorName, "input", "gamma");
}
void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 14d7b588e1..1bf3aa7509 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -309,16 +309,6 @@ struct FakeQuantizationQueueDescriptor : QueueDescriptorWithParameters<FakeQuant
struct InstanceNormalizationQueueDescriptor : QueueDescriptorWithParameters<InstanceNormalizationDescriptor>
{
- InstanceNormalizationQueueDescriptor()
- : m_Beta(nullptr)
- , m_Eps(nullptr)
- , m_Gamma(nullptr)
- {
- }
-
- const ConstCpuTensorHandle* m_Beta;
- const ConstCpuTensorHandle* m_Eps;
- const ConstCpuTensorHandle* m_Gamma;
void Validate(const WorkloadInfo& workloadInfo) const;
};
diff --git a/src/backends/backendsCommon/common.mk b/src/backends/backendsCommon/common.mk
index ae9fdec269..3da2259966 100644
--- a/src/backends/backendsCommon/common.mk
+++ b/src/backends/backendsCommon/common.mk
@@ -55,6 +55,7 @@ COMMON_TEST_SOURCES := \
test/layerTests/FullyConnectedTestImpl.cpp \
test/layerTests/GatherTestImpl.cpp \
test/layerTests/GreaterTestImpl.cpp \
+ test/layerTests/InstanceNormalizationTestImpl.cpp \
test/layerTests/L2NormalizationTestImpl.cpp \
test/layerTests/LstmTestImpl.cpp \
test/layerTests/MaximumTestImpl.cpp \
diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt
index 034fe4cb1b..f7d58bf3d5 100644
--- a/src/backends/backendsCommon/test/CMakeLists.txt
+++ b/src/backends/backendsCommon/test/CMakeLists.txt
@@ -86,6 +86,8 @@ list(APPEND armnnBackendsCommonUnitTests_sources
layerTests/GatherTestImpl.hpp
layerTests/GreaterTestImpl.cpp
layerTests/GreaterTestImpl.hpp
+ layerTests/InstanceNormalizationTestImpl.cpp
+ layerTests/InstanceNormalizationTestImpl.hpp
layerTests/L2NormalizationTestImpl.cpp
layerTests/L2NormalizationTestImpl.hpp
layerTests/LayerTestResult.hpp
diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp
index 84125b55d3..239d0d5e79 100644
--- a/src/backends/backendsCommon/test/LayerTests.hpp
+++ b/src/backends/backendsCommon/test/LayerTests.hpp
@@ -27,6 +27,7 @@
#include <backendsCommon/test/layerTests/FullyConnectedTestImpl.hpp>
#include <backendsCommon/test/layerTests/GatherTestImpl.hpp>
#include <backendsCommon/test/layerTests/GreaterTestImpl.hpp>
+#include <backendsCommon/test/layerTests/InstanceNormalizationTestImpl.hpp>
#include <backendsCommon/test/layerTests/L2NormalizationTestImpl.hpp>
#include <backendsCommon/test/layerTests/LstmTestImpl.hpp>
#include <backendsCommon/test/layerTests/MaximumTestImpl.hpp>
diff --git a/src/backends/backendsCommon/test/layerTests/InstanceNormalizationTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/InstanceNormalizationTestImpl.cpp
new file mode 100644
index 0000000000..4e9cbbf40d
--- /dev/null
+++ b/src/backends/backendsCommon/test/layerTests/InstanceNormalizationTestImpl.cpp
@@ -0,0 +1,267 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "InstanceNormalizationTestImpl.hpp"
+
+#include <ResolveType.hpp>
+
+#include <armnn/ArmNN.hpp>
+
+#include <backendsCommon/CpuTensorHandle.hpp>
+#include <backendsCommon/IBackendInternal.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+#include <backendsCommon/test/DataLayoutUtils.hpp>
+#include <backendsCommon/test/QuantizeHelper.hpp>
+#include <backendsCommon/test/TensorCopyUtils.hpp>
+#include <backendsCommon/test/WorkloadTestUtils.hpp>
+
+#include <test/TensorHelpers.hpp>
+
+namespace
+{
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> InstanceNormTestImpl(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::TensorInfo& inputTensorInfo,
+ const armnn::TensorInfo& outputTensorInfo,
+ const std::vector<float>& inputValues,
+ const std::vector<float>& expectedOutputValues,
+ armnn::InstanceNormalizationQueueDescriptor descriptor,
+ float qScale = 0.0f,
+ int32_t qOffset = 0)
+{
+ auto inputTensor = MakeTensor<T, 4>(inputTensorInfo, QuantizedVector<T>(qScale, qOffset, inputValues));
+
+ LayerTestResult<T, 4> result(outputTensorInfo);
+
+ result.outputExpected =
+ MakeTensor<T, 4>(outputTensorInfo, QuantizedVector<T>(qScale, qOffset, expectedOutputValues));
+
+ std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
+ std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
+
+ armnn::WorkloadInfo info;
+
+
+ AddInputToWorkload(descriptor, info, inputTensorInfo, inputHandle.get());
+ AddOutputToWorkload(descriptor, info, outputTensorInfo, outputHandle.get());
+
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateInstanceNormalization(descriptor, info);
+
+ inputHandle->Allocate();
+ outputHandle->Allocate();
+
+ CopyDataToITensorHandle(inputHandle.get(), &inputTensor[0][0][0][0]);
+
+ workload->Execute();
+
+ CopyDataFromITensorHandle(&result.output[0][0][0][0], outputHandle.get());
+
+ return result;
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> InstanceNormTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout)
+{
+ // BatchSize: 2
+ // Height: 2
+ // Width: 2
+ // Channels: 2
+
+ const armnn::TensorShape inputOutputShape{ 2, 2, 2, 2 };
+
+ armnn::TensorInfo inputTensorInfo(inputOutputShape, ArmnnType);
+ armnn::TensorInfo outputTensorInfo(inputOutputShape, ArmnnType);
+
+ std::vector<float> inputValues
+ {
+ // Batch 0, Height 0, Width 0 x Channel (2)
+ 0.f, 1.f,
+ // Batch 0, Height 0, Width 1 x Channel (2)
+ 0.f, 2.f,
+
+ // Batch 0, Height 1, Width 0 x Channel (2)
+ 0.f, 2.f,
+ // Batch 0, Height 1, Width 1 x Channel (2)
+ 0.f, 4.f,
+
+ // Batch 1, Height 0, Width 0 x Channel (2)
+ 1.f, -1.f,
+ // Batch 1, Height 0, Width 1 x Channel (2)
+ -1.f, 2.f,
+
+ // Batch 1, Height 1, Width 0 x Channel (2)
+ -1.f, -2.f,
+ // Batch 1, Height 1, Width 1 x Channel (2)
+ 1.f, 4.f
+ };
+
+ std::vector<float> expectedOutputValues
+ {
+ // Batch 0, Height 0, Width 0 x Channel (2)
+ 0.f, -1.1470304f,
+ // Batch 0, Height 0, Width 1 x Channel (2)
+ 0.f, -0.22940612f,
+ // Batch 0, Height 1, Width 0 x Channel (2)
+ 0.f, -0.22940612f,
+ // Batch 0, Height 1, Width 1 x Channel (2)
+ 0.f, 1.6058424f,
+
+ // Batch 1, Height 0, Width 0 x Channel (2)
+ 0.99995005f, -0.7337929f,
+ // Batch 1, Height 0, Width 1 x Channel (2)
+ -0.99995005f, 0.52413774f,
+
+ // Batch 1, Height 1, Width 0 x Channel (2)
+ -0.99995005f, -1.1531031f,
+ // Batch 1, Height 1, Width 1 x Channel (2)
+ 0.99995005f, 1.3627582f
+ };
+
+ if (dataLayout == armnn::DataLayout::NCHW)
+ {
+ PermuteTensorNhwcToNchw(inputTensorInfo, inputValues);
+ PermuteTensorNhwcToNchw(outputTensorInfo, expectedOutputValues);
+ }
+
+ armnn::InstanceNormalizationQueueDescriptor descriptor;
+ descriptor.m_Parameters.m_Eps = 0.0001f;
+ descriptor.m_Parameters.m_Beta = 0.0f;
+ descriptor.m_Parameters.m_Gamma = 1.0f;
+ descriptor.m_Parameters.m_DataLayout = dataLayout;
+
+ return InstanceNormTestImpl<ArmnnType>(
+ workloadFactory,
+ memoryManager,
+ inputTensorInfo,
+ outputTensorInfo,
+ inputValues,
+ expectedOutputValues,
+ descriptor);
+}
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> InstanceNormTest2(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout)
+{
+ // BatchSize: 2
+ // Height: 2
+ // Width: 2
+ // Channels: 2
+
+ const armnn::TensorShape inputOutputShape{ 2, 2, 2, 2 };
+
+ armnn::TensorInfo inputTensorInfo(inputOutputShape, ArmnnType);
+ armnn::TensorInfo outputTensorInfo(inputOutputShape, ArmnnType);
+
+ std::vector<float> inputValues
+ {
+ // Batch 0, Height 0, Width 0 x Channel (2)
+ 0.f, 1.f,
+ // Batch 0, Height 0, Width 1 x Channel (2)
+ 0.f, 2.f,
+
+ // Batch 0, Height 1, Width 0 x Channel (2)
+ 0.f, 2.f,
+ // Batch 0, Height 1, Width 1 x Channel (2)
+ 0.f, 4.f,
+
+ // Batch 1, Height 0, Width 0 x Channel (2)
+ 1.f, -1.f,
+ // Batch 1, Height 0, Width 1 x Channel (2)
+ -1.f, 2.f,
+
+ // Batch 1, Height 1, Width 0 x Channel (2)
+ -1.f, -2.f,
+ // Batch 1, Height 1, Width 1 x Channel (2)
+ 1.f, 4.f
+ };
+
+ std::vector<float> expectedOutputValues
+ {
+ // Batch 0, Height 0, Width 0 x Channel (2)
+ 10.f, 7.7059393f,
+ // Batch 0, Height 0, Width 1 x Channel (2)
+ 10.f, 9.541187f,
+
+ // Batch 0, Height 1, Width 0 x Channel (2)
+ 10.f, 9.541187f,
+ // Batch 0, Height 1, Width 1 x Channel (2)
+ 10.f, 13.211685f,
+
+ // Batch 1, Height 0, Width 0 x Channel (2)
+ 11.9999f, 8.532414f,
+ // Batch 1, Height 0, Width 1 x Channel (2)
+ 8.0001f, 11.048275f,
+
+ // Batch 1, Height 1, Width 0 x Channel (2)
+ 8.0001f, 7.693794f,
+ // Batch 1, Height 1, Width 1 x Channel (2)
+ 11.9999f, 12.725516f
+ };
+
+ if (dataLayout == armnn::DataLayout::NCHW)
+ {
+ PermuteTensorNhwcToNchw(inputTensorInfo, inputValues);
+ PermuteTensorNhwcToNchw(outputTensorInfo, expectedOutputValues);
+ }
+
+ armnn::InstanceNormalizationQueueDescriptor descriptor;
+ descriptor.m_Parameters.m_Eps = 0.0001f;
+ descriptor.m_Parameters.m_Beta = 10.0f;
+ descriptor.m_Parameters.m_Gamma = 2.0f;
+ descriptor.m_Parameters.m_DataLayout = dataLayout;
+
+ return InstanceNormTestImpl<ArmnnType>(
+ workloadFactory,
+ memoryManager,
+ inputTensorInfo,
+ outputTensorInfo,
+ inputValues,
+ expectedOutputValues,
+ descriptor);
+}
+
+} // anonymous namespace
+
+LayerTestResult<float, 4> InstanceNormFloat32Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout)
+{
+ return InstanceNormTest<armnn::DataType::Float32>(workloadFactory, memoryManager, dataLayout);
+}
+
+LayerTestResult<armnn::Half, 4> InstanceNormFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout)
+{
+ return InstanceNormTest<armnn::DataType::Float16>(workloadFactory, memoryManager, dataLayout);
+}
+
+LayerTestResult<float, 4> InstanceNormFloat32Test2(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout)
+{
+ return InstanceNormTest2<armnn::DataType::Float32>(workloadFactory, memoryManager, dataLayout);
+}
+
+LayerTestResult<armnn::Half, 4> InstanceNormFloat16Test2(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout)
+{
+ return InstanceNormTest2<armnn::DataType::Float16>(workloadFactory, memoryManager, dataLayout);
+}
diff --git a/src/backends/backendsCommon/test/layerTests/InstanceNormalizationTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/InstanceNormalizationTestImpl.hpp
new file mode 100644
index 0000000000..4030bee6b1
--- /dev/null
+++ b/src/backends/backendsCommon/test/layerTests/InstanceNormalizationTestImpl.hpp
@@ -0,0 +1,36 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "LayerTestResult.hpp"
+
+#include <Half.hpp>
+
+#include <armnn/Types.hpp>
+
+#include <backendsCommon/IBackendInternal.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+LayerTestResult<float, 4> InstanceNormFloat32Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout);
+
+LayerTestResult<armnn::Half, 4> InstanceNormFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout);
+
+LayerTestResult<float, 4> InstanceNormFloat32Test2(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout);
+
+LayerTestResult<armnn::Half, 4> InstanceNormFloat16Test2(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ armnn::DataLayout dataLayout);
+
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 450391581e..c5ed8bff2a 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -30,6 +30,7 @@
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClGreaterWorkload.hpp"
+#include "workloads/ClInstanceNormalizationWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
#include "workloads/ClLstmFloatWorkload.hpp"
#include "workloads/ClMaximumWorkload.hpp"
@@ -410,6 +411,18 @@ bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
return IsClBackendSupported(reasonIfUnsupported);
}
+bool ClLayerSupport::IsInstanceNormalizationSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const InstanceNormalizationDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClInstanceNormalizationWorkloadValidate,
+ reasonIfUnsupported,
+ input,
+ output,
+ descriptor);
+}
+
bool ClLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const L2NormalizationDescriptor& descriptor,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 1a37315d1d..59e849316f 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -110,6 +110,11 @@ public:
bool IsInputSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsInstanceNormalizationSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const InstanceNormalizationDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsL2NormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const L2NormalizationDescriptor& descriptor,
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index ea3c27ebf2..c427ae7e12 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -101,8 +101,8 @@ std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const Tenso
return tensorHandle;
}
-std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorHandle& parent,
- TensorShape const& subTensorShape,
+std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorHandle& parent,
+ TensorShape const& subTensorShape,
unsigned int const* subTensorOrigin) const
{
arm_compute::Coordinates coords;
@@ -126,89 +126,86 @@ std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorH
boost::polymorphic_downcast<IClTensorHandle*>(&parent), shape, coords);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
-{
- return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
-}
-
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
-{
- return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
-}
-
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+ const WorkloadInfo& info) const
{
return MakeWorkload<ClAbsWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+ const WorkloadInfo& info) const
{
return MakeWorkload<ClActivationWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClSoftmaxFloatWorkload, ClSoftmaxUint8Workload>(descriptor, info,
- m_MemoryManager->GetIntraLayerManager());
+ return MakeWorkload<ClAdditionWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchNormalization(
+ const BatchNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClSplitterWorkload>(descriptor, info);
+ return MakeWorkload<ClBatchNormalizationFloatWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return CreateConcat(descriptor, info);
+ return MakeWorkload<ClBatchToSpaceNdWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateFullyConnected(
- const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClFullyConnectedWorkload>(descriptor, info, m_MemoryManager->GetIntraLayerManager());
+ return MakeWorkload<ClConcatWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClPermuteWorkload>(descriptor, info);
+ return MakeWorkload<ClConstantWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp16ToFp32(
+ const ConvertFp16ToFp32QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClPooling2dWorkload>(descriptor, info);
+ return MakeWorkload<ClConvertFp16ToFp32Workload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePrelu(const armnn::PreluQueueDescriptor &descriptor,
- const armnn::WorkloadInfo &info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp32ToFp16(
+ const ConvertFp32ToFp16QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClPreluWorkload>(descriptor, info);
+ return MakeWorkload<ClConvertFp32ToFp16Workload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
return MakeWorkload<ClConvolution2dWorkload>(descriptor, info, m_MemoryManager->GetIntraLayerManager());
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthwiseConvolution2d(
- const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClDepthwiseConvolutionWorkload>(descriptor, info);
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDetectionPostProcess(
- const armnn::DetectionPostProcessQueueDescriptor& descriptor, const armnn::WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClDepthToSpaceWorkload>(descriptor, info);
+}
+
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthwiseConvolution2d(
+ const DepthwiseConvolution2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return MakeWorkload<ClDepthwiseConvolutionWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor,
@@ -217,207 +214,212 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDequantize(const DequantizeQ
return MakeWorkload<ClDequantizeWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDetectionPostProcess(
+ const DetectionPostProcessQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClNormalizationFloatWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClAdditionWorkload>(descriptor, info);
+ return MakeWorkload<ClDivisionFloatWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMultiplication(
- const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClMultiplicationWorkload>(descriptor, info);
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision(
- const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFakeQuantization(
+ const FakeQuantizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClDivisionFloatWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClSubtractionWorkload>(descriptor, info);
+ return MakeWorkload<ClFloorFloatWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateBatchNormalization(
- const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClBatchNormalizationFloatWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClFullyConnectedWorkload>(descriptor, info, m_MemoryManager->GetIntraLayerManager());
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
- {
- throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemCopy workload");
- }
-
- return MakeWorkload<CopyMemGenericWorkload>(descriptor, info);
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
- {
- throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemImport workload");
- }
-
- return std::make_unique<ImportMemGenericWorkload>(descriptor, info);
+ return MakeWorkload<ClGreaterFloat32Workload, ClGreaterUint8Workload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClResizeWorkload>(descriptor, info);
+ return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateResizeBilinear(
- const ResizeBilinearQueueDescriptor& descriptor,
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInstanceNormalization(
+ const InstanceNormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- ResizeQueueDescriptor resizeDescriptor;
- resizeDescriptor.m_Inputs = descriptor.m_Inputs;
- resizeDescriptor.m_Outputs = descriptor.m_Outputs;
-
- resizeDescriptor.m_Parameters.m_Method = ResizeMethod::Bilinear;
- resizeDescriptor.m_Parameters.m_DataLayout = descriptor.m_Parameters.m_DataLayout;
- resizeDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight;
- resizeDescriptor.m_Parameters.m_TargetWidth = descriptor.m_Parameters.m_TargetWidth;
+ return MakeWorkload<ClInstanceNormalizationWorkload>(descriptor, info);
+}
- return CreateResize(resizeDescriptor, info);
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return MakeWorkload<ClL2NormalizationFloatWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFakeQuantization(
- const FakeQuantizationQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClLstmFloatWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClQuantizeWorkload>(descriptor, info);
+ return MakeWorkload<ClMaximumWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClL2NormalizationFloatWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClMeanWorkload>(descriptor, info);
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClConcatWorkload>(descriptor, info);
+ if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
+ {
+ throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemCopy workload");
+ }
+
+ return MakeWorkload<CopyMemGenericWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClConstantWorkload>(descriptor, info);
+ if (descriptor.m_Inputs.empty() || !descriptor.m_Inputs[0])
+ {
+ throw InvalidArgumentException("ClWorkloadFactory: Invalid null input for MemImport workload");
+ }
+
+ return std::make_unique<ImportMemGenericWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClReshapeWorkload>(descriptor, info);
+ return CreateConcat(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClSpaceToBatchNdWorkload>(descriptor, info);
+ return MakeWorkload<ClMinimumWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClFloorFloatWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClMultiplicationWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClLstmFloatWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClNormalizationFloatWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp16ToFp32(
- const ConvertFp16ToFp32QueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClConvertFp16ToFp32Workload>(descriptor, info);
+ return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp32ToFp16(
- const ConvertFp32ToFp16QueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClConvertFp32ToFp16Workload>(descriptor, info);
+ return MakeWorkload<ClPadWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<ClMaximumWorkload>(descriptor, info);
+ return MakeWorkload<ClPermuteWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClMeanWorkload>(descriptor, info);
+ return MakeWorkload<ClPooling2dWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClPadWorkload>(descriptor, info);
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
+ const WorkloadInfo &info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClPreluWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClBatchToSpaceNdWorkload>(descriptor, info);
+ return MakeWorkload<ClQuantizeWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClStridedSliceWorkload>(descriptor, info);
+ return MakeWorkload<ClQuantizedLstmWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<ClMinimumWorkload>(descriptor, info);
+ return MakeWorkload<ClReshapeWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClGreaterFloat32Workload, ClGreaterUint8Workload>(descriptor, info);
+ return MakeWorkload<ClResizeWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ ResizeQueueDescriptor resizeDescriptor;
+ resizeDescriptor.m_Inputs = descriptor.m_Inputs;
+ resizeDescriptor.m_Outputs = descriptor.m_Outputs;
+
+ resizeDescriptor.m_Parameters.m_Method = ResizeMethod::Bilinear;
+ resizeDescriptor.m_Parameters.m_DataLayout = descriptor.m_Parameters.m_DataLayout;
+ resizeDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight;
+ resizeDescriptor.m_Parameters.m_TargetWidth = descriptor.m_Parameters.m_TargetWidth;
+
+ return CreateResize(resizeDescriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
@@ -426,23 +428,17 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRsqrt(const RsqrtQueueDescri
return MakeWorkload<ClRsqrtWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
-{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
-}
-
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGather(const armnn::GatherQueueDescriptor& descriptor,
- const armnn::WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<ClSoftmaxFloatWorkload, ClSoftmaxUint8Workload>(descriptor, info,
+ m_MemoryManager->GetIntraLayerManager());
}
-std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateTransposeConvolution2d(
- const TransposeConvolution2dQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClTransposeConvolution2dWorkload>(descriptor, info, m_MemoryManager->GetIntraLayerManager());
+ return MakeWorkload<ClSpaceToBatchNdWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
@@ -451,10 +447,10 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSpaceToDepth(const SpaceToDe
return MakeWorkload<ClSpaceToDepthWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
{
- return MakeWorkload<ClQuantizedLstmWorkload>(descriptor, info);
+ return MakeWorkload<ClSplitterWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
@@ -463,10 +459,23 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStack(const StackQueueDescri
return MakeWorkload<ClStackWorkload>(descriptor, info);
}
-std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<ClDepthToSpaceWorkload>(descriptor, info);
+ return MakeWorkload<ClStridedSliceWorkload>(descriptor, info);
+}
+
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return MakeWorkload<ClSubtractionWorkload>(descriptor, info);
+}
+
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateTransposeConvolution2d(
+ const TransposeConvolution2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return MakeWorkload<ClTransposeConvolution2dWorkload>(descriptor, info, m_MemoryManager->GetIntraLayerManager());
}
} // namespace armnn
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 7ea7f261a1..9dbc615a4e 100644
--- a/src/backends/cl/ClWorkloadFactory.hpp
+++ b/src/backends/cl/ClWorkloadFactory.hpp
@@ -38,167 +38,170 @@ public:
DataLayout dataLayout,
const bool IsMemoryManaged = true) const override;
- std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
-
- std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
-
std::unique_ptr<IWorkload> CreateAbs(const AbsQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
std::unique_ptr<IWorkload> CreateActivation(const ActivationQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
-
- std::unique_ptr<IWorkload> CreateSplitter(const SplitterQueueDescriptor& descriptor,
+ std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- ARMNN_DEPRECATED_MSG("Use CreateConcat instead")
- std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
+ std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreatePermute(const PermuteQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
+ std::unique_ptr<IWorkload> CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
std::unique_ptr<IWorkload> CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateDebug(const DebugQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
std::unique_ptr<IWorkload> CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
-
std::unique_ptr<IWorkload> CreateDequantize(const DequantizeQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateNormalization(const NormalizationQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor,
+ std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor,
+ std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- ARMNN_DEPRECATED_MSG("Use CreateResize instead")
- std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateMaximum(const MaximumQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor,
+ std::unique_ptr<IWorkload> CreateMean(const MeanQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const override;
+
+ std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateConcat instead")
+ std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateMinimum(const MinimumQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateNormalization(const NormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateMaximum(const MaximumQueueDescriptor& descriptor,
+ std::unique_ptr<IWorkload> CreatePermute(const PermuteQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateMean(const MeanQueueDescriptor& descriptor,
- const WorkloadInfo& Info) const override;
+ std::unique_ptr<IWorkload> CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
+ std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateMinimum(const MinimumQueueDescriptor& descriptor,
+ std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateDebug(const DebugQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateResize instead")
+ std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
std::unique_ptr<IWorkload> CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
-
- std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
std::unique_ptr<IWorkload> CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
- std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
- const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateSplitter(const SplitterQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
std::unique_ptr<IWorkload> CreateStack(const StackQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
+ std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
+ std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
private:
template<typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
static std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor,
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index 20ece76953..b78bae1582 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -39,6 +39,7 @@ BACKEND_SOURCES := \
workloads/ClFloorFloatWorkload.cpp \
workloads/ClFullyConnectedWorkload.cpp \
workloads/ClGreaterWorkload.cpp \
+ workloads/ClInstanceNormalizationWorkload.cpp \
workloads/ClL2NormalizationFloatWorkload.cpp \
workloads/ClLstmFloatWorkload.cpp \
workloads/ClMaximumWorkload.cpp \
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index 76e5061ce2..8879928d27 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -294,6 +294,19 @@ ARMNN_AUTO_TEST_CASE(Multiplication5d, Multiplication5dTest)
ARMNN_AUTO_TEST_CASE(BatchNormFloat32, BatchNormFloat32Test)
ARMNN_AUTO_TEST_CASE(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest)
+// InstanceNormalization
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nchw, InstanceNormFloat32Test, DataLayout::NCHW);
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nchw, InstanceNormFloat16Test, DataLayout::NCHW);
+
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nhwc, InstanceNormFloat32Test, DataLayout::NHWC);
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nhwc, InstanceNormFloat16Test, DataLayout::NHWC);
+
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nchw2, InstanceNormFloat32Test2, DataLayout::NCHW);
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nchw2, InstanceNormFloat16Test2, DataLayout::NCHW);
+
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat32Nhwc2, InstanceNormFloat32Test2, DataLayout::NHWC);
+ARMNN_AUTO_TEST_CASE(InstanceNormFloat16Nhwc2, InstanceNormFloat16Test2, DataLayout::NHWC);
+
// L2 Normalization
ARMNN_AUTO_TEST_CASE(L2Normalization1d, L2Normalization1dTest, DataLayout::NCHW)
ARMNN_AUTO_TEST_CASE(L2Normalization2d, L2Normalization2dTest, DataLayout::NCHW)
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index affe869c89..a9f320d51f 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ b/src/backends/cl/workloads/CMakeLists.txt
@@ -38,6 +38,8 @@ list(APPEND armnnClBackendWorkloads_sources
ClFullyConnectedWorkload.hpp
ClGreaterWorkload.cpp
ClGreaterWorkload.hpp
+ ClInstanceNormalizationWorkload.cpp
+ ClInstanceNormalizationWorkload.hpp
ClL2NormalizationFloatWorkload.cpp
ClL2NormalizationFloatWorkload.hpp
ClLstmFloatWorkload.cpp
diff --git a/src/backends/cl/workloads/ClInstanceNormalizationWorkload.cpp b/src/backends/cl/workloads/ClInstanceNormalizationWorkload.cpp
new file mode 100644
index 0000000000..50cf345a7f
--- /dev/null
+++ b/src/backends/cl/workloads/ClInstanceNormalizationWorkload.cpp
@@ -0,0 +1,59 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClInstanceNormalizationWorkload.hpp"
+#include "ClWorkloadUtils.hpp"
+
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+
+#include <cl/ClTensorHandle.hpp>
+
+using namespace armnn::armcomputetensorutils;
+
+namespace armnn
+{
+
+arm_compute::Status ClInstanceNormalizationWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const InstanceNormalizationDescriptor& descriptor)
+{
+ const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
+
+ return arm_compute::CLInstanceNormalizationLayer::validate(&aclInputInfo,
+ &aclOutputInfo,
+ descriptor.m_Gamma,
+ descriptor.m_Beta,
+ descriptor.m_Eps);
+}
+
+ClInstanceNormalizationWorkload::ClInstanceNormalizationWorkload(
+ const InstanceNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload<InstanceNormalizationQueueDescriptor>(descriptor, info)
+{
+ m_Data.ValidateInputsOutputs("ClInstanceNormalizationWorkload", 1, 1);
+
+ arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
+ arm_compute::DataLayout aclDataLayout = ConvertDataLayout(m_Data.m_Parameters.m_DataLayout);
+ input.info()->set_data_layout(aclDataLayout);
+ output.info()->set_data_layout(aclDataLayout);
+
+ m_Layer.configure(&input,
+ &output,
+ descriptor.m_Parameters.m_Gamma,
+ descriptor.m_Parameters.m_Beta,
+ descriptor.m_Parameters.m_Eps);
+};
+
+void ClInstanceNormalizationWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClInstanceNormalizationWorkload_Execute");
+ RunClFunction(m_Layer, CHECK_LOCATION());
+}
+
+} // namespace armnn
diff --git a/src/backends/cl/workloads/ClInstanceNormalizationWorkload.hpp b/src/backends/cl/workloads/ClInstanceNormalizationWorkload.hpp
new file mode 100644
index 0000000000..0e37bdcc9b
--- /dev/null
+++ b/src/backends/cl/workloads/ClInstanceNormalizationWorkload.hpp
@@ -0,0 +1,29 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+
+#include <arm_compute/runtime/CL/functions/CLInstanceNormalizationLayer.h>
+
+namespace armnn
+{
+
+arm_compute::Status ClInstanceNormalizationWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const InstanceNormalizationDescriptor& descriptor);
+
+class ClInstanceNormalizationWorkload : public BaseWorkload<InstanceNormalizationQueueDescriptor>
+{
+public:
+ ClInstanceNormalizationWorkload(const InstanceNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void Execute() const override;
+
+private:
+ mutable arm_compute::CLInstanceNormalizationLayer m_Layer;
+};
+
+} // namespace armnn
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index 5c42e764b8..cd6ca5fe17 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -18,6 +18,7 @@
#include "ClFloorFloatWorkload.hpp"
#include "ClFullyConnectedWorkload.hpp"
#include "ClGreaterWorkload.hpp"
+#include "ClInstanceNormalizationWorkload.hpp"
#include "ClL2NormalizationFloatWorkload.hpp"
#include "ClLstmFloatWorkload.hpp"
#include "ClConcatWorkload.hpp"