aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp')
-rw-r--r--src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp60
1 files changed, 60 insertions, 0 deletions
diff --git a/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp
new file mode 100644
index 0000000000..0186d3f5e5
--- /dev/null
+++ b/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp
@@ -0,0 +1,60 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include "RefFullyConnectedUint8Workload.hpp"
+
+#include "FullyConnected.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include "Profiling.hpp"
+
+#include <vector>
+
+namespace armnn
+{
+
+void RefFullyConnectedUint8Workload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedUint8Workload_Execute");
+
+ const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const uint8_t* weightData = m_Data.m_Weight->GetConstTensor<uint8_t>();
+
+ auto dequant = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo);
+
+ auto weight = Dequantize(weightData, m_Data.m_Weight->GetTensorInfo());
+
+ std::vector<float> results(inputInfo.GetNumElements());
+
+ if (m_Data.m_Parameters.m_BiasEnabled)
+ {
+ const int32_t* biasData = m_Data.m_Bias->GetConstTensor<int32_t>();
+ auto bias = Dequantize(biasData, m_Data.m_Bias->GetTensorInfo());
+
+ FullyConnected(dequant.data(),
+ results.data(),
+ inputInfo,
+ outputInfo,
+ weight.data(),
+ bias.data(),
+ m_Data.m_Parameters.m_TransposeWeightMatrix);
+ }
+ else
+ {
+ FullyConnected(dequant.data(),
+ results.data(),
+ inputInfo,
+ outputInfo,
+ weight.data(),
+ nullptr,
+ m_Data.m_Parameters.m_TransposeWeightMatrix);
+ }
+
+ Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
+}
+
+} //namespace armnn