author    David Beck <david.beck@arm.com>            2018-09-24 10:46:38 +0100
committer Matthew Bentham <matthew.bentham@arm.com>  2018-10-10 16:16:57 +0100
commit    711fa31d5d43b904d28bcd407cd7e921529a37ca (patch)
tree      729e74dc8681a8a8ac108bf349637ebbce00ba76 /src/backends/aclCommon
parent    5662c206864df4121eea29c541c24c0f62113809 (diff)
download  armnn-711fa31d5d43b904d28bcd407cd7e921529a37ca.tar.gz
IVGCVSW-1921: move common Acl code to a separate folder
Change-Id: I400be8e7c0cc5a31eb9d2a7396da145d50d51b6e
Diffstat (limited to 'src/backends/aclCommon')
-rw-r--r--  src/backends/aclCommon/ArmComputeTensorUtils.cpp  | 163
-rw-r--r--  src/backends/aclCommon/ArmComputeTensorUtils.hpp  | 216
-rw-r--r--  src/backends/aclCommon/ArmComputeUtils.hpp        | 125
-rw-r--r--  src/backends/aclCommon/CMakeLists.txt             |  15
-rw-r--r--  src/backends/aclCommon/common.cmake               |   9
5 files changed, 528 insertions, 0 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
new file mode 100644
index 0000000000..d48408c430
--- /dev/null
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -0,0 +1,163 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
+#include <backends/aclCommon/ArmComputeUtils.hpp>
+
+#include "armnn/Exceptions.hpp"
+#include <armnn/Descriptors.hpp>
+
+namespace armnn
+{
+namespace armcomputetensorutils
+{
+
+arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType)
+{
+ switch(dataType)
+ {
+ case armnn::DataType::Float16:
+ return arm_compute::DataType::F16;
+ case armnn::DataType::Float32:
+ return arm_compute::DataType::F32;
+ case armnn::DataType::QuantisedAsymm8:
+ return arm_compute::DataType::QASYMM8;
+ case armnn::DataType::Signed32:
+ return arm_compute::DataType::S32;
+ default:
+ BOOST_ASSERT_MSG(false, "Unknown data type");
+ return arm_compute::DataType::UNKNOWN;
+ }
+}
+
+arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape)
+{
+ arm_compute::TensorShape shape;
+
+ // armnn tensors are (batch, channels, height, width).
+ // arm_compute tensors are (width, height, channels, batch).
+ for (unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
+ {
+ // Note that our dimensions are stored in the opposite order to ACL's.
+ shape.set(tensorShape.GetNumDimensions() - i - 1, tensorShape[i]);
+
+ // TensorShape::set() flattens leading dimensions of size one, so a batch dimension of 1 is dropped
+ // from the shape; this is the form arm_compute tensors expect.
+ }
+
+ // Prevents an arm_compute issue where the tensor would be flattened to nothing.
+ if (shape.num_dimensions() == 0)
+ {
+ shape.set_num_dimensions(1);
+ }
+
+ return shape;
+}
+
+// Utility function used to build a TensorInfo object that can be used to initialise
+// ARM Compute Tensor and CLTensor allocators.
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
+{
+ const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
+ const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
+ const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
+ tensorInfo.GetQuantizationOffset());
+
+ return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
+}
+
+arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
+{
+ switch(dataLayout)
+ {
+ case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
+
+ case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
+
+ default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
+ std::to_string(static_cast<int>(dataLayout)) + "]");
+ }
+}
+
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
+ armnn::DataLayout dataLayout)
+{
+ const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
+ const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
+ const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
+ tensorInfo.GetQuantizationOffset());
+
+ arm_compute::TensorInfo clTensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
+ clTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
+
+ return clTensorInfo;
+}
+
+arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
+{
+ using arm_compute::PoolingType;
+ using arm_compute::DimensionRoundingType;
+ using arm_compute::PadStrideInfo;
+ using arm_compute::PoolingLayerInfo;
+ using arm_compute::Size2D;
+
+ // Resolve ARM Compute layer parameters.
+ const PoolingType poolingType = ConvertPoolingAlgorithmToAclPoolingType(descriptor.m_PoolType);
+
+ bool isGlobalPooling = (descriptor.m_StrideX==0 && descriptor.m_StrideY==0);
+ // Use the constructor specific to global pooling.
+ if(isGlobalPooling)
+ {
+ return arm_compute::PoolingLayerInfo(poolingType);
+ }
+
+ const DimensionRoundingType rounding = ConvertOutputShapeRoundingToAclDimensionRoundingType(
+ descriptor.m_OutputShapeRounding);
+ const PadStrideInfo padStrideInfo(descriptor.m_StrideX,
+ descriptor.m_StrideY,
+ descriptor.m_PadLeft,
+ descriptor.m_PadRight,
+ descriptor.m_PadTop,
+ descriptor.m_PadBottom,
+ rounding);
+
+ const bool excludePadding = (descriptor.m_PaddingMethod == PaddingMethod::Exclude);
+
+ const Size2D poolSize(descriptor.m_PoolWidth, descriptor.m_PoolHeight);
+
+ return arm_compute::PoolingLayerInfo(poolingType, poolSize, padStrideInfo, excludePadding);
+}
+
+arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& descriptor)
+{
+ const arm_compute::NormType normType =
+ ConvertNormalizationAlgorithmChannelToAclNormType(descriptor.m_NormChannelType);
+ return arm_compute::NormalizationLayerInfo(normType,
+ descriptor.m_NormSize,
+ descriptor.m_Alpha,
+ descriptor.m_Beta,
+ descriptor.m_K,
+ false);
+}
+
+arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& perm)
+{
+ arm_compute::PermutationVector aclPerm;
+
+ unsigned int start = 0;
+ while ((start < perm.GetSize()) && (start == perm[start]))
+ {
+ ++start;
+ }
+
+ for (unsigned int i = start; i < perm.GetSize(); ++i)
+ {
+ aclPerm.set(i - start, perm[i] - start);
+ }
+
+ return aclPerm;
+}
+
+} // namespace armcomputetensorutils
+} // namespace armnn
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
new file mode 100644
index 0000000000..18f41ee173
--- /dev/null
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
@@ -0,0 +1,216 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <armnn/Tensor.hpp>
+#include <armnn/DescriptorsFwd.hpp>
+
+#include <arm_compute/core/ITensor.h>
+#include <arm_compute/core/TensorInfo.h>
+#include <arm_compute/core/Types.h>
+
+#include <boost/cast.hpp>
+
+namespace armnn
+{
+class ITensorHandle;
+
+namespace armcomputetensorutils
+{
+
+/// Utility function to map an armnn::DataType to the corresponding arm_compute::DataType.
+arm_compute::DataType GetArmComputeDataType(armnn::DataType dataType);
+
+/// Utility function used to set up an arm_compute::TensorShape object from an armnn::TensorShape.
+arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& tensorShape);
+
+/// Utility function used to set up an arm_compute::TensorInfo object whose dimensions are based on the given
+/// armnn::TensorInfo.
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo);
+
+/// Utility function used to convert an armnn::DataLayout to the
+/// corresponding arm_compute::DataLayout.
+arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout);
+
+/// Utility function used to set up an arm_compute::TensorInfo object whose dimensions are based on the given
+/// armnn::TensorInfo and whose data layout is set from the given
+/// armnn::DataLayout.
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
+ armnn::DataLayout dataLayout);
+
+/// Utility function used to set up an arm_compute::PoolingLayerInfo object from an armnn::Pooling2dDescriptor.
+arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor);
+
+/// Utility function to set up an arm_compute::NormalizationLayerInfo object from an armnn::NormalizationDescriptor.
+arm_compute::NormalizationLayerInfo BuildArmComputeNormalizationLayerInfo(const NormalizationDescriptor& desc);
+
+/// Utility function used to set up an arm_compute::PermutationVector object from an armnn::PermutationVector.
+arm_compute::PermutationVector BuildArmComputePermutationVector(const armnn::PermutationVector& vector);
+
+/// Utility function used to set up an arm_compute::PadStrideInfo object from an armnn layer descriptor.
+template <typename Descriptor>
+arm_compute::PadStrideInfo BuildArmComputePadStrideInfo(const Descriptor &descriptor)
+{
+ return arm_compute::PadStrideInfo(descriptor.m_StrideX,
+ descriptor.m_StrideY,
+ descriptor.m_PadLeft,
+ descriptor.m_PadRight,
+ descriptor.m_PadTop,
+ descriptor.m_PadBottom,
+ arm_compute::DimensionRoundingType::FLOOR);
+}
+
+/// Sets up the given ArmCompute tensor's dimensions based on the given ArmNN tensor.
+template <typename Tensor>
+void BuildArmComputeTensor(Tensor& tensor, const armnn::TensorInfo& tensorInfo)
+{
+ tensor.allocator()->init(BuildArmComputeTensorInfo(tensorInfo));
+}
+
+/// Sets up the given ArmCompute tensor's dimensions based on the given ArmNN tensor.
+template <typename Tensor>
+void BuildArmComputeTensor(Tensor& tensor, const armnn::TensorInfo& tensorInfo, DataLayout dataLayout)
+{
+ tensor.allocator()->init(BuildArmComputeTensorInfo(tensorInfo, dataLayout));
+}
+
+template <typename Tensor>
+void InitialiseArmComputeTensorEmpty(Tensor& tensor)
+{
+ tensor.allocator()->allocate();
+}
+
+/// Utility function to free unused tensors after a workload is configured and prepared
+template <typename Tensor>
+void FreeTensorIfUnused(std::unique_ptr<Tensor>& tensor)
+{
+ if (tensor && !tensor->is_used())
+ {
+ tensor.reset(nullptr);
+ }
+}
+
+// Helper function to obtain byte offset into tensor data
+inline size_t GetTensorOffset(const arm_compute::ITensorInfo& info,
+ uint32_t batchIndex,
+ uint32_t channelIndex,
+ uint32_t y,
+ uint32_t x)
+{
+ arm_compute::Coordinates coords;
+ coords.set(3, static_cast<int>(batchIndex));
+ coords.set(2, static_cast<int>(channelIndex));
+ coords.set(1, static_cast<int>(y));
+ coords.set(0, static_cast<int>(x));
+ return info.offset_element_in_bytes(coords);
+}
+
+// Helper function to obtain element offset into data buffer representing tensor data (assuming no strides).
+inline size_t GetLinearBufferOffset(const arm_compute::ITensorInfo& info,
+ uint32_t batchIndex,
+ uint32_t channelIndex,
+ uint32_t y,
+ uint32_t x)
+{
+ const arm_compute::TensorShape& shape = info.tensor_shape();
+ uint32_t width = static_cast<uint32_t>(shape[0]);
+ uint32_t height = static_cast<uint32_t>(shape[1]);
+ uint32_t numChannels = static_cast<uint32_t>(shape[2]);
+ return ((batchIndex * numChannels + channelIndex) * height + y) * width + x;
+}
+
+template <typename T>
+void CopyArmComputeITensorData(const arm_compute::ITensor& srcTensor, T* dstData)
+{
+ // If MaxNumOfTensorDimensions is increased, this loop will need fixing.
+ static_assert(MaxNumOfTensorDimensions == 4, "Please update CopyArmComputeITensorData");
+ {
+ const arm_compute::ITensorInfo& info = *srcTensor.info();
+ const arm_compute::TensorShape& shape = info.tensor_shape();
+ const uint8_t* const bufferPtr = srcTensor.buffer();
+ uint32_t width = static_cast<uint32_t>(shape[0]);
+ uint32_t height = static_cast<uint32_t>(shape[1]);
+ uint32_t numChannels = static_cast<uint32_t>(shape[2]);
+ uint32_t numBatches = static_cast<uint32_t>(shape[3]);
+
+ for (unsigned int batchIndex = 0; batchIndex < numBatches; ++batchIndex)
+ {
+ for (unsigned int channelIndex = 0; channelIndex < numChannels; ++channelIndex)
+ {
+ for (unsigned int y = 0; y < height; ++y)
+ {
+ // Copies one row from arm_compute tensor buffer to linear memory buffer.
+ // A row is the largest contiguous region we can copy, as the tensor data may be using strides.
+ memcpy(dstData + GetLinearBufferOffset(info, batchIndex, channelIndex, y, 0),
+ bufferPtr + GetTensorOffset(info, batchIndex, channelIndex, y, 0),
+ width * sizeof(T));
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
+void CopyArmComputeITensorData(const T* srcData, arm_compute::ITensor& dstTensor)
+{
+ // If MaxNumOfTensorDimensions is increased, this loop will need fixing.
+ static_assert(MaxNumOfTensorDimensions == 4, "Please update CopyArmComputeITensorData");
+ {
+ const arm_compute::ITensorInfo& info = *dstTensor.info();
+ const arm_compute::TensorShape& shape = info.tensor_shape();
+ uint8_t* const bufferPtr = dstTensor.buffer();
+ uint32_t width = static_cast<uint32_t>(shape[0]);
+ uint32_t height = static_cast<uint32_t>(shape[1]);
+ uint32_t numChannels = static_cast<uint32_t>(shape[2]);
+ uint32_t numBatches = static_cast<uint32_t>(shape[3]);
+
+ for (unsigned int batchIndex = 0; batchIndex < numBatches; ++batchIndex)
+ {
+ for (unsigned int channelIndex = 0; channelIndex < numChannels; ++channelIndex)
+ {
+ for (unsigned int y = 0; y < height; ++y)
+ {
+ // Copies one row from linear memory buffer to arm_compute tensor buffer.
+ // A row is the largest contiguous region we can copy, as the tensor data may be using strides.
+ memcpy(bufferPtr + GetTensorOffset(info, batchIndex, channelIndex, y, 0),
+ srcData + GetLinearBufferOffset(info, batchIndex, channelIndex, y, 0),
+ width * sizeof(T));
+ }
+ }
+ }
+ }
+}
+
+/// Construct a TensorShape object from an ArmCompute object based on arm_compute::Dimensions.
+/// \tparam ArmComputeType Any type that implements the Dimensions interface
+/// \tparam T Shape value type
+/// \param shapelike An ArmCompute object that implements the Dimensions interface
+/// \param initial A default value to initialise the shape with
+/// \return A TensorShape object filled from the Acl shapelike object.
+template<typename ArmComputeType, typename T>
+TensorShape GetTensorShape(const ArmComputeType& shapelike, T initial)
+{
+ std::vector<unsigned int> s(MaxNumOfTensorDimensions, initial);
+ for (unsigned int i=0; i < shapelike.num_dimensions(); ++i)
+ {
+ s[(shapelike.num_dimensions()-1)-i] = boost::numeric_cast<unsigned int>(shapelike[i]);
+ }
+ return TensorShape(boost::numeric_cast<unsigned int>(shapelike.num_dimensions()), s.data());
+}
+
+/// Get the strides from an ACL strides object
+inline TensorShape GetStrides(const arm_compute::Strides& strides)
+{
+ return GetTensorShape(strides, 0U);
+}
+
+/// Get the shape from an ACL shape object
+inline TensorShape GetShape(const arm_compute::TensorShape& shape)
+{
+ return GetTensorShape(shape, 1U);
+}
+
+} // namespace armcomputetensorutils
+} // namespace armnn
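As a rough sketch (not part of the commit, and assuming the arm_compute runtime Tensor type is available in the build), the allocation and copy helpers above might be combined like this to fill an ACL tensor from a linear buffer:

#include <backends/aclCommon/ArmComputeTensorUtils.hpp>

#include <arm_compute/runtime/Tensor.h>

#include <vector>

void ExampleCopyIntoAclTensor(const armnn::TensorInfo& info, const std::vector<float>& srcData)
{
    using namespace armnn::armcomputetensorutils;

    arm_compute::Tensor tensor;
    BuildArmComputeTensor(tensor, info);                // Initialise dimensions and data type from the armnn info.
    InitialiseArmComputeTensorEmpty(tensor);            // Allocate the backing memory.
    CopyArmComputeITensorData(srcData.data(), tensor);  // Row-by-row copy that honours ACL strides/padding.
}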
diff --git a/src/backends/aclCommon/ArmComputeUtils.hpp b/src/backends/aclCommon/ArmComputeUtils.hpp
new file mode 100644
index 0000000000..db472964ea
--- /dev/null
+++ b/src/backends/aclCommon/ArmComputeUtils.hpp
@@ -0,0 +1,125 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#if ARMCOMPUTENEON_ENABLED || ARMCOMPUTECL_ENABLED
+
+#include <armnn/Tensor.hpp>
+#include <armnn/Descriptors.hpp>
+
+#include <arm_compute/core/Types.h>
+
+namespace armnn
+{
+
+inline arm_compute::NormalizationLayerInfo
+CreateAclNormalizationLayerInfoForL2Normalization(const armnn::TensorInfo& tensorInfo)
+{
+ const unsigned int depth = tensorInfo.GetShape()[1];
+
+ // At the time of writing, {CL|Neon}L2Normalization performs the reduction only along dimension 0. This version of
+ // L2 Normalization always performs the reduction along the depth axis, though. Thus, we repurpose
+ // {CL|Neon}NormalizationLayers to act as depthwise L2 normalizations by carefully choosing the normalization
+ // parameters.
+ //
+ // Please refer to both the reference implementation of the normalization layer and the implementation of
+ // {CL|Neon}NormalizationLayer when checking the derivations for the parameter values below.
+
+ // Make sure normalization covers the entire depth range. ACL requires the normalization size to be odd.
+ // CL: This does not leave any of the extra kernel threads idle; see the usage of the RADIUS parameter in
+ // ACL's normalization_layer_cross_map() CL function.
+ const uint32_t normSize = depth * 2u + 1u;
+
+ // See ACL's NormalizationLayerInfo::scale_coeff() definition.
+ // For the reference implementation, to make alpha_ become 1, we'd have to use alpha = normSize instead.
+ const float alpha = 1.0f;
+
+ // Don't offset the reduction.
+ const float kappa = 0.0f;
+
+ // pow(reduction, -0.5) = 1 / sqrt(reduction)
+ const float beta = 0.5f;
+
+ return arm_compute::NormalizationLayerInfo(arm_compute::NormType::CROSS_MAP, normSize, alpha, beta, kappa, false);
+}
+
+inline arm_compute::ActivationLayerInfo::ActivationFunction
+ConvertActivationFunctionToAclActivationFunction(ActivationFunction armnnFunction)
+{
+ using AclActivationFunction = arm_compute::ActivationLayerInfo::ActivationFunction;
+
+ switch (armnnFunction)
+ {
+ case ActivationFunction::Linear: return AclActivationFunction::LINEAR;
+ // Arm compute's 'logistic' function is non-parameterized, so it is exactly a sigmoid function.
+ case ActivationFunction::Sigmoid: return AclActivationFunction::LOGISTIC;
+ case ActivationFunction::ReLu: return AclActivationFunction::RELU;
+ case ActivationFunction::BoundedReLu: return AclActivationFunction::LU_BOUNDED_RELU;
+ case ActivationFunction::SoftReLu: return AclActivationFunction::SOFT_RELU;
+ case ActivationFunction::LeakyReLu: return AclActivationFunction::LEAKY_RELU;
+ case ActivationFunction::Abs: return AclActivationFunction::ABS;
+ case ActivationFunction::Sqrt: return AclActivationFunction::SQRT;
+ case ActivationFunction::Square: return AclActivationFunction::SQUARE;
+ case ActivationFunction::TanH: return AclActivationFunction::TANH;
+ default: throw InvalidArgumentException("Unsupported activation function");
+ }
+}
+
+inline arm_compute::ActivationLayerInfo
+ConvertActivationDescriptorToAclActivationLayerInfo(const ActivationDescriptor& actDesc)
+{
+ return arm_compute::ActivationLayerInfo(ConvertActivationFunctionToAclActivationFunction(actDesc.m_Function),
+ actDesc.m_A, actDesc.m_B);
+}
+
+inline arm_compute::PoolingType ConvertPoolingAlgorithmToAclPoolingType(PoolingAlgorithm poolingAlgorithm)
+{
+ using arm_compute::PoolingType;
+
+ switch (poolingAlgorithm)
+ {
+ case PoolingAlgorithm::Max: return PoolingType::MAX;
+ case PoolingAlgorithm::Average: return PoolingType::AVG;
+ case PoolingAlgorithm::L2: return PoolingType::L2;
+ default: throw InvalidArgumentException("Unsupported pooling algorithm");
+ }
+}
+
+inline arm_compute::DimensionRoundingType ConvertOutputShapeRoundingToAclDimensionRoundingType(OutputShapeRounding
+ rounding)
+{
+ using arm_compute::DimensionRoundingType;
+
+ switch (rounding)
+ {
+ case OutputShapeRounding::Ceiling: return DimensionRoundingType::CEIL;
+ case OutputShapeRounding::Floor: return DimensionRoundingType::FLOOR;
+ default: throw InvalidArgumentException("Unsupported Output Shape Rounding type");
+ }
+}
+
+inline arm_compute::NormType
+ConvertNormalizationAlgorithmChannelToAclNormType(NormalizationAlgorithmChannel channelType)
+{
+ using arm_compute::NormType;
+ switch (channelType)
+ {
+ case NormalizationAlgorithmChannel::Across: return NormType::CROSS_MAP;
+ case NormalizationAlgorithmChannel::Within: return NormType::IN_MAP_2D;
+ default: throw InvalidArgumentException("Unsupported normalization algorithm channel type");
+ }
+}
+
+inline arm_compute::FullyConnectedLayerInfo
+ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(const FullyConnectedDescriptor& fullyConnectedDesc)
+{
+ arm_compute::FullyConnectedLayerInfo fc_info;
+ fc_info.transpose_weights = fullyConnectedDesc.m_TransposeWeightMatrix;
+ return fc_info;
+}
+
+} // namespace armnn
+
+#endif // ARMCOMPUTENEON_ENABLED || ARMCOMPUTECL_ENABLED
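A short illustrative sketch (not part of the commit, and assuming ARMCOMPUTENEON_ENABLED or ARMCOMPUTECL_ENABLED is defined) of the L2-normalization parameter derivation above: for an assumed armnn tensor with 16 channels, the helper yields normSize = 2 * 16 + 1 = 33 with alpha = 1, beta = 0.5 and kappa = 0:

#include <backends/aclCommon/ArmComputeUtils.hpp>

#include <armnn/Tensor.hpp>

arm_compute::NormalizationLayerInfo ExampleL2NormInfo()
{
    // armnn layout is (batch, channels, height, width); depth = channels = 16.
    const armnn::TensorInfo info(armnn::TensorShape({ 1, 16, 8, 8 }), armnn::DataType::Float32);

    // Cross-map normalization covering the whole depth: size 33, alpha 1, beta 0.5, kappa 0.
    return armnn::CreateAclNormalizationLayerInfoForL2Normalization(info);
}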
diff --git a/src/backends/aclCommon/CMakeLists.txt b/src/backends/aclCommon/CMakeLists.txt
new file mode 100644
index 0000000000..42f914263a
--- /dev/null
+++ b/src/backends/aclCommon/CMakeLists.txt
@@ -0,0 +1,15 @@
+#
+# Copyright © 2017 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+
+list(APPEND armnnAclCommon_sources
+ ArmComputeTensorUtils.hpp
+ ArmComputeTensorUtils.cpp
+ ArmComputeUtils.hpp
+)
+
+add_library(armnnAclCommon STATIC ${armnnAclCommon_sources})
+target_include_directories(armnnAclCommon PRIVATE ${PROJECT_SOURCE_DIR}/src)
+target_include_directories(armnnAclCommon PRIVATE ${PROJECT_SOURCE_DIR}/src/armnn)
+target_include_directories(armnnAclCommon PRIVATE ${PROJECT_SOURCE_DIR}/src/armnnUtils)
diff --git a/src/backends/aclCommon/common.cmake b/src/backends/aclCommon/common.cmake
new file mode 100644
index 0000000000..d9d035f307
--- /dev/null
+++ b/src/backends/aclCommon/common.cmake
@@ -0,0 +1,9 @@
+#
+# Copyright © 2017 Arm Ltd. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+
+if(ARMCOMPUTENEON OR ARMCOMPUTECL)
+ add_subdirectory(${PROJECT_SOURCE_DIR}/src/backends/aclCommon)
+ list(APPEND armnnLibraries armnnAclCommon)
+endif()