aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp47
-rw-r--r--src/backends/reference/RefLayerSupport.hpp11
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp13
-rw-r--r--src/backends/reference/RefWorkloadFactory.hpp6
-rw-r--r--src/backends/reference/backend.mk2
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp13
-rw-r--r--src/backends/reference/workloads/BaseIterator.hpp34
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/Decoders.hpp18
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp25
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.hpp26
-rw-r--r--src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp75
-rw-r--r--src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp34
-rw-r--r--src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp64
-rw-r--r--src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp33
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp2
16 files changed, 407 insertions, 0 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 52c079fae4..f48c120203 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1105,6 +1105,53 @@ bool RefLayerSupport::IsL2NormalizationSupported(const TensorInfo& input,
return supported;
}
+bool RefLayerSupport::IsLogicalBinarySupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const LogicalBinaryDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ IgnoreUnused(descriptor);
+
+ std::array<DataType, 1> supportedTypes =
+ {
+ DataType::Boolean
+ };
+
+ bool supported = true;
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference LogicalBinary: input 0 type not supported");
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference LogicalBinary: input 1 type not supported");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference LogicalBinary: input and output types do not match");
+
+ return supported;
+}
+
+bool RefLayerSupport::IsLogicalUnarySupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ElementwiseUnaryDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ IgnoreUnused(descriptor);
+
+ std::array<DataType, 1> supportedTypes =
+ {
+ DataType::Boolean
+ };
+
+ bool supported = true;
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference LogicalUnary: input type not supported");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference LogicalUnary: input and output types do not match");
+
+ return supported;
+}
+
bool RefLayerSupport::IsLogSoftmaxSupported(const TensorInfo& input,
const TensorInfo& output,
const LogSoftmaxDescriptor& descriptor,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index a233082aaa..318eb4064b 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -182,6 +182,17 @@ public:
const L2NormalizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsLogicalBinarySupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const LogicalBinaryDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const override;
+
+ bool IsLogicalUnarySupported(const TensorInfo& input,
+ const TensorInfo& output,
+ const ElementwiseUnaryDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const override;
+
bool IsLogSoftmaxSupported(const TensorInfo& input,
const TensorInfo& output,
const LogSoftmaxDescriptor& descriptor,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index e7e57b15d1..9080028e72 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -401,6 +401,19 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2Nor
return std::make_unique<RefL2NormalizationWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<RefLogicalBinaryWorkload>(descriptor, info);
+}
+
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<RefLogicalUnaryWorkload>(descriptor, info);
+}
+
+
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index 5f22c9eac2..8c3d719ae0 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -162,6 +162,12 @@ public:
std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
+ std::unique_ptr<IWorkload> CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index bf5f340afc..b4aa3a0953 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -69,6 +69,8 @@ BACKEND_SOURCES := \
workloads/RefGatherWorkload.cpp \
workloads/RefInstanceNormalizationWorkload.cpp \
workloads/RefL2NormalizationWorkload.cpp \
+ workloads/RefLogicalBinaryWorkload.cpp \
+ workloads/RefLogicalUnaryWorkload.cpp \
workloads/RefLogSoftmaxWorkload.cpp \
workloads/RefLstmWorkload.cpp \
workloads/RefMeanWorkload.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 7542a64711..60400c514e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -2215,4 +2215,17 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(Exp3dQuantisedAsymm8, Exp3dTest<DataType::QAsymmU8
ARMNN_AUTO_TEST_CASE_WITH_THF(Exp2dQuantisedSymm16, Exp2dTest<DataType::QSymmS16>)
ARMNN_AUTO_TEST_CASE_WITH_THF(Exp3dQuantisedSymm16, Exp3dTest<DataType::QSymmS16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalNot, LogicalNotTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalNotInt, LogicalNotIntTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalAnd, LogicalAndTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalOr, LogicalOrTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalAndInt, LogicalAndIntTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalOrInt, LogicalOrIntTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalAndBroadcast1, LogicalAndBroadcast1Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalOrBroadcast1, LogicalOrBroadcast1Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalAndBroadcast2, LogicalAndBroadcast2Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalOrBroadcast2, LogicalOrBroadcast2Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalAndBroadcast3, LogicalAndBroadcast3Test)
+ARMNN_AUTO_TEST_CASE_WITH_THF(LogicalOrBroadcast3, LogicalOrBroadcast3Test)
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp
index a10f383e90..73e24691d9 100644
--- a/src/backends/reference/workloads/BaseIterator.hpp
+++ b/src/backends/reference/workloads/BaseIterator.hpp
@@ -515,6 +515,40 @@ public:
}
};
+class BooleanDecoderBool : public TypedIterator<const uint8_t, Decoder<bool>>
+{
+public:
+ BooleanDecoderBool(const uint8_t* data)
+ : TypedIterator(data) {}
+
+ BooleanDecoderBool()
+ : BooleanDecoderBool(nullptr) {}
+
+ bool Get() const override
+ {
+ return *m_Iterator;
+ }
+
+ std::vector<float> DecodeTensor(const TensorShape& tensorShape,
+ const unsigned int channelMultiplier,
+ const bool isDepthwise) override
+ {
+ IgnoreUnused(channelMultiplier, isDepthwise);
+
+ const unsigned int size = tensorShape.GetNumElements();
+ std::vector<float> decodedTensor;
+ decodedTensor.reserve(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ this->operator[](i);
+ decodedTensor.emplace_back(*m_Iterator);
+ }
+
+ return decodedTensor;
+ }
+};
+
class QASymm8Encoder : public TypedIterator<uint8_t, Encoder<float>>
{
public:
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index cd9efc96af..1b20e5bf2d 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -107,6 +107,10 @@ list(APPEND armnnRefBackendWorkloads_sources
RefInstanceNormalizationWorkload.hpp
RefL2NormalizationWorkload.cpp
RefL2NormalizationWorkload.hpp
+ RefLogicalBinaryWorkload.cpp
+ RefLogicalBinaryWorkload.hpp
+ RefLogicalUnaryWorkload.cpp
+ RefLogicalUnaryWorkload.hpp
RefLogSoftmaxWorkload.cpp
RefLogSoftmaxWorkload.hpp
RefLstmWorkload.cpp
diff --git a/src/backends/reference/workloads/Decoders.hpp b/src/backends/reference/workloads/Decoders.hpp
index 08e0140fad..0b3f36047d 100644
--- a/src/backends/reference/workloads/Decoders.hpp
+++ b/src/backends/reference/workloads/Decoders.hpp
@@ -150,6 +150,24 @@ inline std::unique_ptr<Decoder<float>> MakeDecoder(const TensorInfo& info, const
}
template<>
+inline std::unique_ptr<Decoder<bool>> MakeDecoder(const TensorInfo& info, const void* data)
+{
+ switch(info.GetDataType())
+ {
+ case DataType::Boolean:
+ {
+ return std::make_unique<BooleanDecoderBool>(static_cast<const uint8_t*>(data));
+ }
+ default:
+ {
+ ARMNN_ASSERT_MSG(false, "Unsupported Data Type!");
+ break;
+ }
+ }
+ return nullptr;
+}
+
+template<>
inline std::unique_ptr<Decoder<int32_t>> MakeDecoder(const TensorInfo& info, const void* data)
{
switch(info.GetDataType())
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index afae188bd6..d6f3f42478 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -37,6 +37,26 @@ ElementwiseUnaryFunction<Functor>::ElementwiseUnaryFunction(const TensorShape& i
BroadcastLoop(inShape, outShape).Unroll(Functor(), 0, inData, outData);
}
+template <typename Functor>
+LogicalBinaryFunction<Functor>::LogicalBinaryFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ Decoder<InType>& inData0,
+ Decoder<InType>& inData1,
+ Encoder<OutType>& outData)
+{
+ BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
+}
+
+template <typename Functor>
+LogicalUnaryFunction<Functor>::LogicalUnaryFunction(const TensorShape& inShape,
+ const TensorShape& outShape,
+ Decoder<InType>& inData,
+ Encoder<OutType>& outData)
+{
+ BroadcastLoop(inShape, outShape).Unroll(Functor(), 0, inData, outData);
+}
+
} //namespace armnn
template struct armnn::ElementwiseBinaryFunction<std::plus<float>>;
@@ -67,3 +87,8 @@ template struct armnn::ElementwiseUnaryFunction<armnn::exp<float>>;
template struct armnn::ElementwiseUnaryFunction<std::negate<float>>;
template struct armnn::ElementwiseUnaryFunction<armnn::rsqrt<float>>;
template struct armnn::ElementwiseUnaryFunction<armnn::sqrt<float>>;
+
+// Logical Unary
+template struct armnn::LogicalUnaryFunction<std::logical_not<bool>>;
+template struct armnn::LogicalBinaryFunction<std::logical_and<bool>>;
+template struct armnn::LogicalBinaryFunction<std::logical_or<bool>>;
diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp
index 8259ba5ac7..ef4a2dc7d5 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.hpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.hpp
@@ -37,4 +37,30 @@ struct ElementwiseUnaryFunction
Encoder<OutType>& outData);
};
+template <typename Functor>
+struct LogicalBinaryFunction
+{
+ using OutType = bool;
+ using InType = bool;
+
+ LogicalBinaryFunction(const TensorShape& inShape0,
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ Decoder<InType>& inData0,
+ Decoder<InType>& inData1,
+ Encoder<OutType>& outData);
+};
+
+template <typename Functor>
+struct LogicalUnaryFunction
+{
+ using OutType = bool;
+ using InType = bool;
+
+ LogicalUnaryFunction(const TensorShape& inShape,
+ const TensorShape& outShape,
+ Decoder<InType>& inData,
+ Encoder<OutType>& outData);
+};
+
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp b/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp
new file mode 100644
index 0000000000..1b4e8f9aa0
--- /dev/null
+++ b/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp
@@ -0,0 +1,75 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefLogicalBinaryWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+namespace armnn
+{
+
+RefLogicalBinaryWorkload::RefLogicalBinaryWorkload(const LogicalBinaryQueueDescriptor& desc,
+ const WorkloadInfo& info)
+ : BaseWorkload<LogicalBinaryQueueDescriptor>(desc, info)
+{}
+
+void RefLogicalBinaryWorkload::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ m_Input0 = MakeDecoder<InType>(inputInfo0);
+ m_Input1 = MakeDecoder<InType>(inputInfo1);
+ m_Output = MakeEncoder<OutType>(outputInfo);
+}
+
+void RefLogicalBinaryWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefLogicalBinaryWorkload_Execute");
+
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inShape0 = inputInfo0.GetShape();
+ const TensorShape& inShape1 = inputInfo1.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ m_Input0->Reset(m_Data.m_Inputs[0]->Map());
+ m_Input1->Reset(m_Data.m_Inputs[1]->Map());
+ m_Output->Reset(m_Data.m_Outputs[0]->Map());
+
+ using AndFunction = LogicalBinaryFunction<std::logical_and<bool>>;
+ using OrFunction = LogicalBinaryFunction<std::logical_or<bool>>;
+
+ switch (m_Data.m_Parameters.m_Operation)
+ {
+ case LogicalBinaryOperation::LogicalAnd:
+ {
+ AndFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case LogicalBinaryOperation::LogicalOr:
+ {
+ OrFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ default:
+ {
+ throw InvalidArgumentException(std::string("Unsupported Logical Binary operation") +
+ GetLogicalBinaryOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
+ }
+ }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp b/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp
new file mode 100644
index 0000000000..4d6baf5fa4
--- /dev/null
+++ b/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefLogicalBinaryWorkload : public BaseWorkload<LogicalBinaryQueueDescriptor>
+{
+public:
+ using BaseWorkload<LogicalBinaryQueueDescriptor>::m_Data;
+
+ RefLogicalBinaryWorkload(const LogicalBinaryQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void PostAllocationConfigure() override;
+ virtual void Execute() const override;
+
+private:
+ using InType = bool;
+ using OutType = bool;
+
+ std::unique_ptr<Decoder<InType>> m_Input0;
+ std::unique_ptr<Decoder<InType>> m_Input1;
+ std::unique_ptr<Encoder<OutType>> m_Output;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp b/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp
new file mode 100644
index 0000000000..76eb5ac39f
--- /dev/null
+++ b/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp
@@ -0,0 +1,64 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefLogicalUnaryWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+namespace armnn
+{
+
+RefLogicalUnaryWorkload::RefLogicalUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
+ const WorkloadInfo& info)
+ : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
+{}
+
+void RefLogicalUnaryWorkload::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ m_Input = MakeDecoder<InType>(inputInfo);
+ m_Output = MakeEncoder<OutType>(outputInfo);
+}
+
+void RefLogicalUnaryWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefLogicalUnaryWorkload_Execute");
+
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inShape = inputInfo.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ m_Input->Reset(m_Data.m_Inputs[0]->Map());
+ m_Output->Reset(m_Data.m_Outputs[0]->Map());
+
+ using NotFunction = LogicalUnaryFunction<std::logical_not<bool>>;
+
+ switch (m_Data.m_Parameters.m_Operation)
+ {
+ case UnaryOperation::LogicalNot:
+ {
+ NotFunction(inShape, outShape, *m_Input, *m_Output);
+ break;
+ }
+ default:
+ {
+ throw InvalidArgumentException(std::string("Unsupported Logical Unary operation") +
+ GetUnaryOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
+ }
+ }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp b/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp
new file mode 100644
index 0000000000..0d8b35495c
--- /dev/null
+++ b/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefLogicalUnaryWorkload : public BaseWorkload<ElementwiseUnaryQueueDescriptor>
+{
+public:
+ using BaseWorkload<ElementwiseUnaryQueueDescriptor>::m_Data;
+
+ RefLogicalUnaryWorkload(const ElementwiseUnaryQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void PostAllocationConfigure() override;
+ virtual void Execute() const override;
+
+private:
+ using InType = bool;
+ using OutType = bool;
+
+ std::unique_ptr<Decoder<InType>> m_Input;
+ std::unique_ptr<Encoder<OutType>> m_Output;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index fc47cff84f..390b2a8d55 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -41,6 +41,8 @@
#include "RefGatherWorkload.hpp"
#include "RefInstanceNormalizationWorkload.hpp"
#include "RefL2NormalizationWorkload.hpp"
+#include "RefLogicalBinaryWorkload.hpp"
+#include "RefLogicalUnaryWorkload.hpp"
#include "RefLogSoftmaxWorkload.hpp"
#include "RefLstmWorkload.hpp"
#include "RefMeanWorkload.hpp"