diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
commit | c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch) | |
tree | bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp | |
parent | 4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff) | |
download | armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz |
Release 18.08
Diffstat (limited to 'src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp')
-rw-r--r-- | src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp | 16 |
1 file changed, 11 insertions, 5 deletions
diff --git a/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp b/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp index 0186d3f5e5..cd653657e1 100644 --- a/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp +++ b/src/armnn/backends/RefWorkloads/RefFullyConnectedUint8Workload.cpp @@ -14,6 +14,12 @@ namespace armnn { +RefFullyConnectedUint8Workload::RefFullyConnectedUint8Workload( + const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) + : Uint8Workload<FullyConnectedQueueDescriptor>(descriptor, info), + m_Weight(std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Weight))), + m_Bias(descriptor.m_Parameters.m_BiasEnabled + ? std::make_unique<ScopedCpuTensorHandle>(*(descriptor.m_Bias)) : nullptr) {} void RefFullyConnectedUint8Workload::Execute() const { @@ -22,18 +28,18 @@ void RefFullyConnectedUint8Workload::Execute() const const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - const uint8_t* weightData = m_Data.m_Weight->GetConstTensor<uint8_t>(); + const uint8_t* weightData = m_Weight->GetConstTensor<uint8_t>(); auto dequant = Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo); - auto weight = Dequantize(weightData, m_Data.m_Weight->GetTensorInfo()); + auto weight = Dequantize(weightData, m_Weight->GetTensorInfo()); - std::vector<float> results(inputInfo.GetNumElements()); + std::vector<float> results(outputInfo.GetNumElements()); if (m_Data.m_Parameters.m_BiasEnabled) { - const int32_t* biasData = m_Data.m_Bias->GetConstTensor<int32_t>(); - auto bias = Dequantize(biasData, m_Data.m_Bias->GetTensorInfo()); + const int32_t* biasData = m_Bias->GetConstTensor<int32_t>(); + auto bias = Dequantize(biasData, m_Bias->GetTensorInfo()); FullyConnected(dequant.data(), results.data(), |