Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp')
-rw-r--r-- src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp | 16
1 file changed, 11 insertions(+), 5 deletions(-)
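For context, a minimal sketch of the header this diff presupposes (RefFullyConnectedUint8Workload.hpp). The constructor signature and the m_Weight/m_Bias members are taken from the hunks below; the include paths and the rest of the class layout are assumptions, not the verbatim header:

// RefFullyConnectedUint8Workload.hpp -- sketch only, layout assumed
#include "backends/Workload.hpp"
#include "backends/WorkloadData.hpp"
#include "backends/CpuTensorHandle.hpp"

#include <memory>

namespace armnn
{

class RefFullyConnectedUint8Workload : public Uint8Workload<FullyConnectedQueueDescriptor>
{
public:
    explicit RefFullyConnectedUint8Workload(const FullyConnectedQueueDescriptor& descriptor,
                                            const WorkloadInfo& info);
    virtual void Execute() const override;

private:
    // Constant tensors copied out of the queue descriptor at construction time,
    // so Execute() reads them from the workload itself rather than through m_Data.
    std::unique_ptr<ScopedCpuTensorHandle> m_Weight;
    std::unique_ptr<ScopedCpuTensorHandle> m_Bias;
};

} // namespace armnn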
diff --git a/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp
index 0186d3f5e5..cd653657e1 100644
--- a/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp
+++ b/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp
@@ -14,6 +14,12 @@
namespace armnn
{
+RefFullyConnectedUint8Workload::RefFullyConnectedUint8Workload(
+ const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : Uint8Workload<FullyConnectedQueueDescriptor>(descriptor, info),
+ m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight))),
+ m_Bias(descriptor.m_Parameters.m_BiasEnabled
+ ? std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias)) : nullptr) {}
void RefFullyConnectedUint8Workload::Execute() const
{
@@ -22,18 +28,18 @@ void RefFullyConnectedUint8Workload::Execute() const
const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
- const uint8_t* weightData = m_Data.m_Weight->GetConstTensor<uint8_t>();
+ const uint8_t* weightData = m_Weight->GetConstTensor<uint8_t>();
auto dequant = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo);
- auto weight = Dequantize(weightData, m_Data.m_Weight->GetTensorInfo());
+ auto weight = Dequantize(weightData, m_Weight->GetTensorInfo());
- std::vector<float> results(inputInfo.GetNumElements());
+ std::vector<float> results(outputInfo.GetNumElements());
if (m_Data.m_Parameters.m_BiasEnabled)
{
- const int32_t* biasData = m_Data.m_Bias->GetConstTensor<int32_t>();
- auto bias = Dequantize(biasData, m_Data.m_Bias->GetTensorInfo());
+ const int32_t* biasData = m_Bias->GetConstTensor<int32_t>();
+ auto bias = Dequantize(biasData, m_Bias->GetTensorInfo());
FullyConnected(dequant.data(),
results.data(),