Diffstat (limited to 'src/backends/cl')
 src/backends/cl/CMakeLists.txt                                                  |   3
 src/backends/cl/ClBackendContext.cpp                                            | 133
 src/backends/cl/ClBackendContext.hpp                                            |   5
 src/backends/cl/ClContextControl.cpp                                            |  53
 src/backends/cl/ClContextControl.hpp                                            |  22
 src/backends/cl/ClImportTensorHandle.hpp                                        |   4
 src/backends/cl/ClLayerSupport.cpp                                              |  21
 src/backends/cl/ClLayerSupport.hpp                                              |   6
 src/backends/cl/ClTensorHandle.hpp                                              |   4
 src/backends/cl/ClTensorHandleFactory.cpp                                       |   6
 src/backends/cl/ClTensorHandleFactory.hpp                                       |   6
 src/backends/cl/ClWorkloadFactory.cpp                                           |   5
 src/backends/cl/IClTensorHandle.hpp                                             |  22
 src/backends/cl/backend.mk                                                      |   1
 src/backends/cl/test/CMakeLists.txt                                             |   4
 src/backends/cl/test/{DefaultAllocatorTests.cpp => ClDefaultAllocatorTests.cpp} |   2
 src/backends/cl/test/ClLayerTests.cpp                                           |  23
 src/backends/cl/workloads/CMakeLists.txt                                        |   2
 src/backends/cl/workloads/ClBatchMatMulWorkload.cpp                             | 203
 src/backends/cl/workloads/ClBatchMatMulWorkload.hpp                             |  41
 src/backends/cl/workloads/ClWorkloads.hpp                                       |   1
 21 files changed, 333 insertions(+), 234 deletions(-)
diff --git a/src/backends/cl/CMakeLists.txt b/src/backends/cl/CMakeLists.txt
index aeb90b069c..20c42061fc 100644
--- a/src/backends/cl/CMakeLists.txt
+++ b/src/backends/cl/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# Copyright © 2017 Arm Ltd. All rights reserved.
+# Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
@@ -44,7 +44,6 @@ if(ARMCOMPUTECL)
ClTensorHandleFactory.hpp
ClWorkloadFactory.cpp
ClWorkloadFactory.hpp
- IClTensorHandle.hpp
ICLTensorProxy.hpp
OpenClTimer.cpp
OpenClTimer.hpp
diff --git a/src/backends/cl/ClBackendContext.cpp b/src/backends/cl/ClBackendContext.cpp
index 62c6b038da..adee2763ba 100644
--- a/src/backends/cl/ClBackendContext.cpp
+++ b/src/backends/cl/ClBackendContext.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -20,20 +20,11 @@ namespace armnn
struct ClBackendContext::ClContextControlWrapper
{
- ClContextControlWrapper() {}
-
- bool IsInitialised()
- {
- return m_Initialised;
- }
-
- void Init(arm_compute::CLTuner* tuner,
- arm_compute::CLGEMMHeuristicsHandle* heuristicsHandle,
- bool profilingEnabled)
- {
- m_ClContextControl = ClContextControl(tuner, heuristicsHandle, profilingEnabled);
- m_Initialised = true;
- }
+ ClContextControlWrapper(arm_compute::CLTuner* tuner,
+ arm_compute::CLGEMMHeuristicsHandle* heuristicsHandle,
+ bool profilingEnabled)
+ : m_ClContextControl(tuner, heuristicsHandle, profilingEnabled)
+ {}
bool Sync()
{
@@ -62,106 +53,12 @@ struct ClBackendContext::ClContextControlWrapper
{
// There are no loaded networks left, so clear the CL cache to free up memory
m_ClContextControl.ClearClCache();
- m_Initialised = false;
}
}
-private:
- bool m_Initialised;
ClContextControl m_ClContextControl;
-
};
-/**
- * Returns a shared_ptr to the CLContextControlWrapper. This wraps the CLContextControl and ensures that we only create
- * and use one at a time.
- */
-std::shared_ptr<ClBackendContext::ClContextControlWrapper> ClBackendContext::Get()
-{
- static std::shared_ptr<ClBackendContext::ClContextControlWrapper> instance
- = std::make_shared<ClBackendContext::ClContextControlWrapper>();
- // Instantiated on first use.
- return instance;
-}
-
-std::string LowerString(std::string value)
-{
- std::transform(value.begin(), value.end(), value.begin(),
- [](unsigned char c){ return std::tolower(c); });
-
- return value;
-}
-
-enum class TuningLevel
-{
- None,
- Rapid,
- Normal,
- Exhaustive
-};
-
-
-TuningLevel ParseTuningLevel(const BackendOptions::Var& value, TuningLevel defaultValue)
-{
- if (value.IsInt())
- {
- int v = value.AsInt();
- if (v > static_cast<int>(TuningLevel::Exhaustive) ||
- v < static_cast<int>(TuningLevel::None))
- {
- ARMNN_LOG(warning) << "Invalid GpuAcc tuning level ("<< v << ") selected. "
- "Using default(" << static_cast<int>(defaultValue) << ")";
- } else
- {
- return static_cast<TuningLevel>(v);
- }
- }
- return defaultValue;
-}
-
-bool ParseBoolean(const BackendOptions::Var& value, bool defaultValue)
-{
- if (value.IsBool())
- {
- return value.AsBool();
- }
- return defaultValue;
-}
-
-std::string ParseFile(const BackendOptions::Var& value, std::string defaultValue)
-{
- if (value.IsString())
- {
- return value.AsString();
- }
- return defaultValue;
-}
-
-void ConfigureTuner(arm_compute::CLTuner &tuner, TuningLevel level)
-{
- tuner.set_tune_new_kernels(true); // Turn on tuning initially.
-
- switch (level)
- {
- case TuningLevel::Rapid:
- ARMNN_LOG(info) << "Gpu tuning is activated. TuningLevel: Rapid (1)";
- tuner.set_tuner_mode(arm_compute::CLTunerMode::RAPID);
- break;
- case TuningLevel::Normal:
- ARMNN_LOG(info) << "Gpu tuning is activated. TuningLevel: Normal (2)";
- tuner.set_tuner_mode(arm_compute::CLTunerMode::NORMAL);
- break;
- case TuningLevel::Exhaustive:
- ARMNN_LOG(info) << "Gpu tuning is activated. TuningLevel: Exhaustive (3)";
- tuner.set_tuner_mode(arm_compute::CLTunerMode::EXHAUSTIVE);
- break;
- case TuningLevel::None:
- default:
- tuner.set_tune_new_kernels(false); // Turn off tuning. Set to "use" only mode.
- break;
- }
-}
-
ClBackendContext::ClBackendContext(const IRuntime::CreationOptions& options)
: IBackendContext(options)
, m_TuningFile()
@@ -171,7 +68,6 @@ ClBackendContext::ClBackendContext(const IRuntime::CreationOptions& options)
arm_compute::CLTuner* tuner = nullptr;
arm_compute::CLGEMMHeuristicsHandle* mlgoTuner = nullptr;
bool useLegacyTunerAPI = options.m_GpuAccTunedParameters.get() != nullptr;
-
if (useLegacyTunerAPI)
{
auto clTunerParams = PolymorphicDowncast<ClTunedParameters*>(
@@ -217,17 +113,17 @@ ClBackendContext::ClBackendContext(const IRuntime::CreationOptions& options)
{
if (name == "KernelProfilingEnabled")
{
- kernelProfiling |= ParseBoolean(value, false);
+ kernelProfiling |= ParseBooleanBackendOption(value, false);
} else if (name == "TuningFile")
{
- m_TuningFile = ParseFile(value, "");
+ m_TuningFile = ParseStringBackendOption(value, "");
} else if (name == "TuningLevel")
{
tuningLevel = ParseTuningLevel(value, defaultTuningLevel);
}
else if (name == "MLGOTuningFilePath")
{
- m_MLGOTuningFile = ParseFile(value, "");
+ m_MLGOTuningFile = ParseStringBackendOption(value, "");
}
});
@@ -272,12 +168,11 @@ ClBackendContext::ClBackendContext(const IRuntime::CreationOptions& options)
tuner = m_Tuner.get();
}
- m_ClContextControlWrapper = Get();
-
- if (!m_ClContextControlWrapper->IsInitialised())
- {
- m_ClContextControlWrapper->Init(tuner, mlgoTuner, kernelProfiling);
- }
+ m_ClContextControlWrapper = std::make_unique<ClContextControlWrapper>(
+ tuner,
+ mlgoTuner,
+ kernelProfiling
+ );
}
bool ClBackendContext::BeforeLoadNetwork(NetworkId)
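
The constructor above now parses the GpuAcc backend options directly ("KernelProfilingEnabled", "TuningFile", "TuningLevel", "MLGOTuningFilePath") and owns its ClContextControlWrapper outright instead of sharing a singleton. A minimal sketch of exercising these options through the public runtime API; the option names come from the hunk above, while the values and file paths are purely illustrative:

    #include <armnn/BackendOptions.hpp>
    #include <armnn/IRuntime.hpp>

    armnn::IRuntime::CreationOptions options;
    options.m_BackendOptions.emplace_back(armnn::BackendOptions{"GpuAcc",
        {
            {"KernelProfilingEnabled", false},
            {"TuningLevel", 2},                // 0 None, 1 Rapid, 2 Normal, 3 Exhaustive
            {"TuningFile", "gpu_tuning.csv"},  // hypothetical path
            {"MLGOTuningFilePath", "mlgo.bin"} // hypothetical path
        }});
    armnn::IRuntimePtr runtime = armnn::IRuntime::Create(options);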
diff --git a/src/backends/cl/ClBackendContext.hpp b/src/backends/cl/ClBackendContext.hpp
index 276067727b..659d47b7c2 100644
--- a/src/backends/cl/ClBackendContext.hpp
+++ b/src/backends/cl/ClBackendContext.hpp
@@ -31,11 +31,8 @@ public:
private:
std::mutex m_Mutex;
-
struct ClContextControlWrapper;
- static std::shared_ptr<ClBackendContext::ClContextControlWrapper> Get();
-
- std::shared_ptr<ClBackendContext::ClContextControlWrapper> m_ClContextControlWrapper;
+ std::unique_ptr<ClContextControlWrapper> m_ClContextControlWrapper;
std::unordered_set<NetworkId> m_NetworkIds;
diff --git a/src/backends/cl/ClContextControl.cpp b/src/backends/cl/ClContextControl.cpp
index fd2d0f53eb..34eca961b4 100644
--- a/src/backends/cl/ClContextControl.cpp
+++ b/src/backends/cl/ClContextControl.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -166,55 +166,4 @@ void ClContextControl::ClearClCache()
DoLoadOpenClRuntime(true);
}
-armnn::IGpuAccTunedParameters* IGpuAccTunedParameters::CreateRaw(armnn::IGpuAccTunedParameters::Mode mode,
- armnn::IGpuAccTunedParameters::TuningLevel tuningLevel)
-{
- return new ClTunedParameters(mode, tuningLevel);
-}
-
-armnn::IGpuAccTunedParametersPtr IGpuAccTunedParameters::Create(armnn::IGpuAccTunedParameters::Mode mode,
- armnn::IGpuAccTunedParameters::TuningLevel tuningLevel)
-{
- return IGpuAccTunedParametersPtr(CreateRaw(mode, tuningLevel), &IGpuAccTunedParameters::Destroy);
-}
-
-void IGpuAccTunedParameters::Destroy(IGpuAccTunedParameters* params)
-{
- delete params;
-}
-
-ClTunedParameters::ClTunedParameters(armnn::IGpuAccTunedParameters::Mode mode,
- armnn::IGpuAccTunedParameters::TuningLevel tuningLevel)
- : m_Mode(mode)
- , m_TuningLevel(tuningLevel)
- , m_Tuner(mode == ClTunedParameters::Mode::UpdateTunedParameters)
-{
-}
-
-void ClTunedParameters::Load(const char* filename)
-{
- try
- {
- m_Tuner.load_from_file(filename);
- }
- catch (const std::exception& e)
- {
- throw armnn::Exception(std::string("Failed to load tuned parameters file '") + filename + "': " +
- e.what());
- }
-}
-
-void ClTunedParameters::Save(const char* filename) const
-{
- try
- {
- m_Tuner.save_to_file(filename);
- }
- catch (const std::exception& e)
- {
- throw armnn::Exception(std::string("Failed to save tuned parameters file to '") + filename + "': " +
- e.what());
- }
-}
-
} // namespace armnn
diff --git a/src/backends/cl/ClContextControl.hpp b/src/backends/cl/ClContextControl.hpp
index 4a640cdf22..7520d102a5 100644
--- a/src/backends/cl/ClContextControl.hpp
+++ b/src/backends/cl/ClContextControl.hpp
@@ -1,13 +1,10 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
-#include "armnn/IRuntime.hpp"
-
-#include <arm_compute/runtime/CL/CLTuner.h>
-#include <arm_compute/runtime/CL/CLGEMMHeuristicsHandle.h>
+#include <aclCommon/ArmComputeTuningUtils.hpp>
namespace armnn
{
@@ -42,19 +39,4 @@ private:
bool m_ProfilingEnabled;
};
-class ClTunedParameters : public IGpuAccTunedParameters
-{
-public:
- ClTunedParameters(armnn::IGpuAccTunedParameters::Mode mode, armnn::IGpuAccTunedParameters::TuningLevel tuningLevel);
-
- virtual void Load(const char* filename);
- virtual void Save(const char* filename) const;
-
- Mode m_Mode;
- TuningLevel m_TuningLevel;
-
- arm_compute::CLTuner m_Tuner;
- arm_compute::CLGEMMHeuristicsHandle m_HeuristicsHandle;
-};
-
} // namespace armnn
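
The IGpuAccTunedParameters/ClTunedParameters code removed here is not deleted from Arm NN; judging by the new include, it moves behind <aclCommon/ArmComputeTuningUtils.hpp>. For reference, the legacy tuner path that ClBackendContext still honours via options.m_GpuAccTunedParameters looks roughly like this (file name hypothetical):

    armnn::IGpuAccTunedParametersPtr tunedParams = armnn::IGpuAccTunedParameters::Create(
        armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters,
        armnn::IGpuAccTunedParameters::TuningLevel::Rapid);
    tunedParams->Load("gpu_tuning.csv"); // throws armnn::Exception if the file cannot be read

    armnn::IRuntime::CreationOptions options;
    options.m_GpuAccTunedParameters = tunedParams;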
diff --git a/src/backends/cl/ClImportTensorHandle.hpp b/src/backends/cl/ClImportTensorHandle.hpp
index 889a2ad5f3..a03a4e9ea6 100644
--- a/src/backends/cl/ClImportTensorHandle.hpp
+++ b/src/backends/cl/ClImportTensorHandle.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -19,7 +19,7 @@
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>
-#include <cl/IClTensorHandle.hpp>
+#include <aclCommon/IClTensorHandle.hpp>
#include <CL/cl_ext.h>
#include <arm_compute/core/CL/CLKernelLibrary.h>
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index a61a5bb640..cb2d756037 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -22,6 +22,7 @@
#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClActivationWorkload.hpp"
#include "workloads/ClArgMinMaxWorkload.hpp"
+#include "workloads/ClBatchMatMulWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClBatchToSpaceNdWorkload.hpp"
#include "workloads/ClCastWorkload.hpp"
@@ -201,6 +202,12 @@ bool ClLayerSupport::IsLayerSupported(const LayerType& type,
infos[1],
*(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
reasonIfUnsupported);
+ case LayerType::BatchMatMul:
+ return IsBatchMatMulSupported(infos[0],
+ infos[1],
+ infos[2],
+ *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
+ reasonIfUnsupported);
case LayerType::BatchNormalization:
return IsBatchNormalizationSupported(infos[0],
infos[1],
@@ -640,6 +647,20 @@ bool ClLayerSupport::IsArgMinMaxSupported(const TensorInfo& input,
descriptor);
}
+bool ClLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchMatMulValidate,
+ reasonIfUnsupported,
+ inputX,
+ inputY,
+ output,
+ descriptor);
+}
+
bool ClLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& mean,
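
With this hook in place, BatchMatMul support can be queried before the optimizer commits the layer to GpuAcc. An illustrative check, using the 2D simple shapes from the new tests and default descriptor settings (no transpose or adjoint, NCHW layouts):

    armnn::ClLayerSupport layerSupport;
    armnn::TensorInfo inputX({2, 2}, armnn::DataType::Float32);
    armnn::TensorInfo inputY({2, 2}, armnn::DataType::Float32);
    armnn::TensorInfo output({2, 2}, armnn::DataType::Float32);
    armnn::BatchMatMulDescriptor descriptor; // defaults: no transpose/adjoint, NCHW

    std::string reason;
    armnn::Optional<std::string&> reasonIfUnsupported(reason);
    bool supported = layerSupport.IsBatchMatMulSupported(inputX, inputY, output,
                                                         descriptor, reasonIfUnsupported);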
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 27311f74aa..2d784e3df8 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -40,6 +40,12 @@ public:
const ArgMinMaxDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsBatchMatMulSupported(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
+
bool IsBatchNormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& mean,
diff --git a/src/backends/cl/ClTensorHandle.hpp b/src/backends/cl/ClTensorHandle.hpp
index f63f1faa07..3d750f9059 100644
--- a/src/backends/cl/ClTensorHandle.hpp
+++ b/src/backends/cl/ClTensorHandle.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -18,7 +18,7 @@
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>
-#include <cl/IClTensorHandle.hpp>
+#include <aclCommon/IClTensorHandle.hpp>
namespace armnn
{
diff --git a/src/backends/cl/ClTensorHandleFactory.cpp b/src/backends/cl/ClTensorHandleFactory.cpp
index b8ee57f0bf..82e41d3ff6 100644
--- a/src/backends/cl/ClTensorHandleFactory.cpp
+++ b/src/backends/cl/ClTensorHandleFactory.cpp
@@ -108,12 +108,12 @@ bool ClTensorHandleFactory::SupportsSubTensors() const
MemorySourceFlags ClTensorHandleFactory::GetExportFlags() const
{
- return m_ExportFlags;
+ return MemorySourceFlags(MemorySource::Undefined);
}
MemorySourceFlags ClTensorHandleFactory::GetImportFlags() const
{
- return m_ImportFlags;
+ return MemorySourceFlags(MemorySource::Undefined);
}
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn
diff --git a/src/backends/cl/ClTensorHandleFactory.hpp b/src/backends/cl/ClTensorHandleFactory.hpp
index 3acab0bce7..8e1c7a8a02 100644
--- a/src/backends/cl/ClTensorHandleFactory.hpp
+++ b/src/backends/cl/ClTensorHandleFactory.hpp
@@ -24,8 +24,6 @@ public:
ClTensorHandleFactory(std::shared_ptr<ClMemoryManager> mgr)
: m_MemoryManager(mgr)
- , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
- , m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
{}
std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
@@ -56,8 +54,6 @@ public:
private:
mutable std::shared_ptr<ClMemoryManager> m_MemoryManager;
- MemorySourceFlags m_ImportFlags;
- MemorySourceFlags m_ExportFlags;
};
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index d0079abd38..6bf510a2ef 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -265,6 +265,11 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateWorkload(LayerType type,
auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor);
return MakeWorkload<ClArgMinMaxWorkload>(*argMinMaxQueueDescriptor, info, m_CLCompileContext);
}
+ case LayerType::BatchMatMul :
+ {
+ auto batchMatMulQueueDescriptor = PolymorphicDowncast<const BatchMatMulQueueDescriptor*>(&descriptor);
+ return std::make_unique<ClBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info, m_CLCompileContext);
+ }
case LayerType::BatchNormalization :
{
auto batchNormalizationQueueDescriptor
diff --git a/src/backends/cl/IClTensorHandle.hpp b/src/backends/cl/IClTensorHandle.hpp
deleted file mode 100644
index 48cf5f57d6..0000000000
--- a/src/backends/cl/IClTensorHandle.hpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-#pragma once
-
-#include <arm_compute/core/CL/ICLTensor.h>
-#include <arm_compute/runtime/MemoryGroup.h>
-
-namespace armnn
-{
-
-class IClTensorHandle : public IAclTensorHandle
-{
-public:
- virtual arm_compute::ICLTensor& GetTensor() = 0;
- virtual arm_compute::ICLTensor const& GetTensor() const = 0;
- virtual arm_compute::DataType GetDataType() const = 0;
- virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) = 0;
-};
-
-} //namespace armnn
\ No newline at end of file
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index 6fda16db05..1f97ae7cc8 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -30,6 +30,7 @@ BACKEND_SOURCES := \
workloads/ClActivationWorkload.cpp \
workloads/ClAdditionWorkload.cpp \
workloads/ClArgMinMaxWorkload.cpp \
+ workloads/ClBatchMatMulWorkload.cpp \
workloads/ClBatchNormalizationFloatWorkload.cpp \
workloads/ClBatchToSpaceNdWorkload.cpp \
workloads/ClCastWorkload.cpp \
diff --git a/src/backends/cl/test/CMakeLists.txt b/src/backends/cl/test/CMakeLists.txt
index ec1d0a6c2f..6568d48ce5 100644
--- a/src/backends/cl/test/CMakeLists.txt
+++ b/src/backends/cl/test/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# Copyright © 2017 Arm Ltd. All rights reserved.
+# Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
@@ -8,6 +8,7 @@ list(APPEND armnnClBackendUnitTests_sources
ClContextControlFixture.hpp
ClContextSerializerTests.cpp
ClCustomAllocatorTests.cpp
+ ClDefaultAllocatorTests.cpp
ClCreateWorkloadTests.cpp
ClEndToEndTests.cpp
ClImportTensorHandleFactoryTests.cpp
@@ -18,7 +19,6 @@ list(APPEND armnnClBackendUnitTests_sources
ClOptimizedNetworkTests.cpp
ClRuntimeTests.cpp
ClWorkloadFactoryHelper.hpp
- DefaultAllocatorTests.cpp
Fp16SupportTest.cpp
ICLTensorProxyTests.cpp
OpenClTimerTest.cpp
diff --git a/src/backends/cl/test/DefaultAllocatorTests.cpp b/src/backends/cl/test/ClDefaultAllocatorTests.cpp
index eaa30c8800..411a480815 100644
--- a/src/backends/cl/test/DefaultAllocatorTests.cpp
+++ b/src/backends/cl/test/ClDefaultAllocatorTests.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index 855697c9be..4ba2a9ec3b 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -73,6 +73,29 @@ ARMNN_AUTO_TEST_FIXTURE_WITH_THF(Tanh, ClContextControlFixture, TanhTest)
// Elu Activation
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(Elu, ClContextControlFixture, EluTest)
+// Batch Mat Mul
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DSimpleFloat32,
+ ClContextControlFixture,
+ BatchMatMul2DSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DSimpleFloat32,
+ ClContextControlFixture,
+ BatchMatMul3DSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DBatchFloat32,
+ ClContextControlFixture,
+ BatchMatMul3DBatchTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3DBroadcastFloat32,
+ ClContextControlFixture,
+ BatchMatMul3DBroadcastTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul3D2DBroadcastFloat32,
+ ClContextControlFixture,
+ BatchMatMul3D2DBroadcastTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DTinyFloat32,
+ ClContextControlFixture,
+ BatchMatMul2DTinyTest<DataType::Float32>);
+ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchMatMul2DTranspSimpleFloat32,
+ ClContextControlFixture,
+ BatchMatMul2DTranspSimpleTest<DataType::Float32>);
+
// Batch To Space
ARMNN_AUTO_TEST_FIXTURE_WITH_THF(BatchToSpaceNdNhwcFloat321,
ClContextControlFixture,
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index aef7fc7ad2..8616dec078 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ b/src/backends/cl/workloads/CMakeLists.txt
@@ -12,6 +12,8 @@ list(APPEND armnnClBackendWorkloads_sources
ClAdditionWorkload.hpp
ClArgMinMaxWorkload.cpp
ClArgMinMaxWorkload.hpp
+ ClBatchMatMulWorkload.cpp
+ ClBatchMatMulWorkload.hpp
ClBatchNormalizationFloatWorkload.cpp
ClBatchNormalizationFloatWorkload.hpp
ClBatchToSpaceNdWorkload.cpp
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
new file mode 100644
index 0000000000..4acdef5e5c
--- /dev/null
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.cpp
@@ -0,0 +1,203 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClBatchMatMulWorkload.hpp"
+
+#include "ClWorkloadUtils.hpp"
+
+#include <aclCommon/ArmComputeTensorUtils.hpp>
+#include <aclCommon/ArmComputeUtils.hpp>
+
+#include <armnn/utility/PolymorphicDowncast.hpp>
+
+#include <armnnUtils/Permute.hpp>
+
+#include <backendsCommon/WorkloadUtils.hpp>
+
+#include <cl/ClTensorHandle.hpp>
+
+#include <arm_compute/runtime/CL/functions/CLGEMM.h>
+#include <arm_compute/runtime/CL/functions/CLPermute.h>
+
+
+namespace armnn
+{
+arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor)
+{
+    if (descriptor.m_AdjointX || descriptor.m_AdjointY)
+    {
+        throw Exception("Support for adjoint not implemented.");
+    }
+    if (descriptor.m_DataLayoutX != armnn::DataLayout::NCHW || descriptor.m_DataLayoutY != armnn::DataLayout::NCHW)
+    {
+        throw Exception("Only MatMul on the last 2 dimensions is supported.");
+    }
+
+ arm_compute::Status statusGEMM = arm_compute::Status(arm_compute::ErrorCode::OK);
+ arm_compute::Status statusPermuteX = arm_compute::Status(arm_compute::ErrorCode::OK);
+ arm_compute::Status statusPermuteY = arm_compute::Status(arm_compute::ErrorCode::OK);
+
+ const auto aclInputXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputX, descriptor.m_DataLayoutX);
+ const auto aclInputYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(inputY, descriptor.m_DataLayoutY);
+ const auto aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+ arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
+ arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
+
+ if (descriptor.m_TransposeX == true)
+ {
+ auto permutationXVector = GeneratePermutationVectorOnLastTwoDimensions(inputX.GetNumDimensions());
+ const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
+ const TensorInfo permutedXInfo = armnnUtils::Permuted(inputX, permutationXVector);
+ aclPermutedXInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedXInfo);
+
+ statusPermuteX = arm_compute::CLPermute::validate(&aclInputXInfo,
+ &aclPermutedXInfo,
+ aclPermutationXVector);
+ }
+
+    if (descriptor.m_TransposeY == true)
+ {
+ auto permutationYVector = GeneratePermutationVectorOnLastTwoDimensions(inputY.GetNumDimensions());
+ const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
+ const TensorInfo permutedYInfo = armnnUtils::Permuted(inputY, permutationYVector);
+ aclPermutedYInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permutedYInfo);
+
+ statusPermuteY = arm_compute::CLPermute::validate(&aclInputYInfo,
+ &aclPermutedYInfo,
+ aclPermutationYVector);
+
+ }
+
+ const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
+ false, // is inputY reshaped
+ false); // is inputY reshaped only 1st run
+
+
+ statusGEMM = arm_compute::CLGEMM::validate(descriptor.m_TransposeX ? &aclPermutedXInfo : &aclInputXInfo,
+ descriptor.m_TransposeY ? &aclPermutedYInfo : &aclInputYInfo,
+ nullptr,
+ &aclOutputInfo,
+ 1.0,
+ 0,
+ gemm_info);
+
+ if (statusPermuteX.error_code() == arm_compute::ErrorCode::OK &&
+ statusPermuteY.error_code() == arm_compute::ErrorCode::OK &&
+ statusGEMM.error_code() == arm_compute::ErrorCode::OK)
+ {
+ return arm_compute::Status(arm_compute::ErrorCode::OK,
+ "All Batch Mat Mul layers validate status OK.");
+ }
+ else
+ {
+        return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
+                                   "BatchMatMul layer validate status failed. "
+                                   + statusGEMM.error_description()
+                                   + statusPermuteX.error_description()
+                                   + statusPermuteY.error_description());
+ }
+
+}
+
+ClBatchMatMulWorkload::ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ const arm_compute::CLCompileContext& clCompileContext)
+ : ClBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
+{
+ // Report Profiling Details
+ ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClBatchMatMulWorkload_Construct",
+ descriptor.m_Parameters,
+ info,
+ this->GetGuid());
+
+    if (descriptor.m_Parameters.m_AdjointX || descriptor.m_Parameters.m_AdjointY)
+    {
+        throw Exception("Support for adjoint not implemented.");
+    }
+    if (descriptor.m_Parameters.m_DataLayoutX != armnn::DataLayout::NCHW ||
+        descriptor.m_Parameters.m_DataLayoutY != armnn::DataLayout::NCHW)
+    {
+        throw Exception("Only MatMul on the last 2 dimensions is supported.");
+    }
+
+ m_Data.ValidateInputsOutputs("ClBatchMatMulWorkload", 2, 1);
+
+ const arm_compute::ICLTensor& inputX = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ const arm_compute::ICLTensor& inputY = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+ arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
+ inputX.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutX));
+ inputY.info()->set_data_layout(armcomputetensorutils::ConvertDataLayout(m_Data.m_Parameters.m_DataLayoutY));
+
+ arm_compute::TensorInfo aclPermutedXInfo = arm_compute::TensorInfo();
+ arm_compute::TensorInfo aclPermutedYInfo = arm_compute::TensorInfo();
+
+ if (descriptor.m_Parameters.m_TransposeX == true)
+ {
+ armnn::PermutationVector permutationXVector
+ = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[0].GetNumDimensions());
+ const TensorInfo permutedXInfo = armnnUtils::Permuted(info.m_InputTensorInfos[0], permutationXVector);
+ const auto aclPermutationXVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationXVector);
+ armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorX, permutedXInfo);
+ armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorX);
+
+ auto permuteLayerX = std::make_unique<arm_compute::CLPermute>();
+ permuteLayerX->configure(clCompileContext,
+ &inputX,
+ &m_PermutedTensorX,
+ aclPermutationXVector);
+ m_PermuteLayerX.reset(permuteLayerX.release());
+ }
+
+ if (descriptor.m_Parameters.m_TransposeY == true)
+ {
+        armnn::PermutationVector permutationYVector
+                = GeneratePermutationVectorOnLastTwoDimensions(info.m_InputTensorInfos[1].GetNumDimensions());
+        const TensorInfo permutedYInfo = armnnUtils::Permuted(info.m_InputTensorInfos[1], permutationYVector);
+ const auto aclPermutationYVector = armcomputetensorutils::BuildArmComputePermutationVector(permutationYVector);
+ armcomputetensorutils::BuildArmComputeTensor(m_PermutedTensorY, permutedYInfo);
+ armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_PermutedTensorY);
+
+        auto permuteLayerY = std::make_unique<arm_compute::CLPermute>();
+ permuteLayerY->configure(clCompileContext,
+ &inputY,
+ &m_PermutedTensorY,
+ aclPermutationYVector);
+ m_PermuteLayerY.reset(permuteLayerY.release());
+ }
+
+ const arm_compute::GEMMInfo& gemm_info = arm_compute::GEMMInfo(false, // is inputX reshaped
+ false, // is inputY reshaped
+ false); // is inputY reshaped only 1st run
+ auto gemmLayer = std::make_unique<arm_compute::CLGEMM>();
+ gemmLayer->configure(clCompileContext,
+ descriptor.m_Parameters.m_TransposeX ? &m_PermutedTensorX : &inputX,
+ descriptor.m_Parameters.m_TransposeY ? &m_PermutedTensorY : &inputY,
+ nullptr,
+ &output,
+ 1.0,
+ 0,
+ gemm_info);
+ m_GEMMLayer.reset(gemmLayer.release());
+}
+
+void ClBatchMatMulWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClBatchMatMulWorkload_Execute", this->GetGuid());
+ if (m_PermuteLayerX)
+ {
+ m_PermuteLayerX->run();
+ }
+ if (m_PermuteLayerY)
+ {
+ m_PermuteLayerY->run();
+ }
+ m_GEMMLayer->run();
+}
+} //namespace armnn
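
ClBatchMatMulValidate mirrors the graph the constructor builds: an optional CLPermute per transposed input feeding a CLGEMM, so the static check and the runtime workload stay in agreement on supported configurations. A hedged call-site sketch, assuming inputXInfo, inputYInfo, outputInfo, descriptor, queueDescriptor, workloadInfo and clCompileContext are already in scope:

    arm_compute::Status status = armnn::ClBatchMatMulValidate(inputXInfo, inputYInfo,
                                                              outputInfo, descriptor);
    if (status.error_code() == arm_compute::ErrorCode::OK)
    {
        auto workload = std::make_unique<armnn::ClBatchMatMulWorkload>(queueDescriptor,
                                                                       workloadInfo,
                                                                       clCompileContext);
        workload->Execute(); // runs the permutes (if any), then CLGEMM
    }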
diff --git a/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp b/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
new file mode 100644
index 0000000000..5277efc947
--- /dev/null
+++ b/src/backends/cl/workloads/ClBatchMatMulWorkload.hpp
@@ -0,0 +1,41 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ClBaseWorkload.hpp"
+
+#include <arm_compute/runtime/IFunction.h>
+#include <arm_compute/runtime/CL/CLTensor.h>
+#include <memory>
+
+namespace armnn
+{
+ arm_compute::Status ClBatchMatMulValidate(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor);
+
+ class ClBatchMatMulWorkload : public ClBaseWorkload<BatchMatMulQueueDescriptor>
+ {
+ public:
+ ClBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
+ const WorkloadInfo& info,
+ const arm_compute::CLCompileContext& clCompileContext);
+ virtual void Execute() const override;
+
+ private:
+ // ACL layers required to fully form a Batch Mat Mul layer.
+ std::unique_ptr<arm_compute::IFunction> m_GEMMLayer;
+ std::unique_ptr<arm_compute::IFunction> m_PermuteLayerX;
+ std::unique_ptr<arm_compute::IFunction> m_PermuteLayerY;
+
+ // Additional CL arm_compute::Tensors.
+ // Required to perform permutations.
+ arm_compute::CLTensor m_PermutedTensorX;
+ arm_compute::CLTensor m_PermutedTensorY;
+
+ };
+} //namespace armnn
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index c3a79b7583..44f3798d7d 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -10,6 +10,7 @@
#include "ClArgMinMaxWorkload.hpp"
#include "ClComparisonWorkload.hpp"
#include "ClConstantWorkload.hpp"
+#include "ClBatchMatMulWorkload.hpp"
#include "ClBatchNormalizationFloatWorkload.hpp"
#include "ClBatchToSpaceNdWorkload.hpp"
#include "ClCastWorkload.hpp"