aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt1
-rw-r--r--src/backends/reference/workloads/RefShapeWorkload.hpp48
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp5
3 files changed, 52 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 09e02e67bd..7a769e5246 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -143,6 +143,7 @@ list(APPEND armnnRefBackendWorkloads_sources
RefResizeBilinearWorkload.hpp
RefResizeWorkload.cpp
RefResizeWorkload.hpp
+ RefShapeWorkload.hpp
RefSliceWorkload.cpp
RefSliceWorkload.hpp
RefSoftmaxWorkload.cpp
diff --git a/src/backends/reference/workloads/RefShapeWorkload.hpp b/src/backends/reference/workloads/RefShapeWorkload.hpp
new file mode 100644
index 0000000000..8e2a410b0c
--- /dev/null
+++ b/src/backends/reference/workloads/RefShapeWorkload.hpp
@@ -0,0 +1,48 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+#include "RefWorkloadUtils.hpp"
+
+namespace armnn
+{
+
+struct RefShapeWorkload : public BaseWorkload<ShapeQueueDescriptor>
+{
+public:
+ using BaseWorkload<ShapeQueueDescriptor>::BaseWorkload;
+ virtual void Execute() const override
+ {
+ Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+ }
+ void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
+ {
+ Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
+ }
+
+private:
+ void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+ {
+ const TensorShape Shape = GetTensorInfo(inputs[0]).GetShape();
+
+ const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
+
+ unsigned int numBytes =
+ GetTensorInfo(inputs[0]).GetNumDimensions() * GetDataTypeSize(outputInfo.GetDataType());
+
+ std::memcpy(outputs[0]->Map(), &Shape, numBytes);
+ outputs[0]->Unmap();
+ }
+};
+
+} //namespace armnn
+
+
+
+
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index d3995f2b82..afe63d13c0 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -35,10 +35,10 @@
#include "RefDequantizeWorkload.hpp"
#include "RefElementwiseWorkload.hpp"
#include "RefElementwiseUnaryWorkload.hpp"
+#include "RefFakeQuantizationFloat32Workload.hpp"
#include "RefFillWorkload.hpp"
-#include "RefFullyConnectedWorkload.hpp"
#include "RefFloorWorkload.hpp"
-#include "RefFakeQuantizationFloat32Workload.hpp"
+#include "RefFullyConnectedWorkload.hpp"
#include "RefGatherWorkload.hpp"
#include "RefInstanceNormalizationWorkload.hpp"
#include "RefL2NormalizationWorkload.hpp"
@@ -59,6 +59,7 @@
#include "RefReshapeWorkload.hpp"
#include "RefResizeBilinearWorkload.hpp"
#include "RefResizeWorkload.hpp"
+#include "RefShapeWorkload.hpp"
#include "RefSliceWorkload.hpp"
#include "RefSplitterWorkload.hpp"
#include "RefSoftmaxWorkload.hpp"