aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefQuantizeWorkload.cpp
blob: ab2ee7fc4e90f1a91ffbc046d8af2a781dc460b7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefQuantizeWorkload.hpp"

#include <armnn/TypesUtils.hpp>


namespace armnn
{

namespace
{

/// Quantizes a contiguous run of float values into elements of type T.
///
/// @param input      Pointer to numValues float values (read-only).
/// @param output     Pointer to storage for numValues elements of type T.
/// @param numValues  Number of elements to convert.
/// @param scale      Quantization scale, forwarded to armnn::Quantize.
/// @param offset     Quantization offset (zero point), forwarded to armnn::Quantize.
template<typename T>
void QuantizeImpl(const void *input, void *output, size_t numValues, float scale, int offset)
{
    const float* const src = static_cast<const float*>(input);
    T* const dst = static_cast<T*>(output);

    for (size_t idx = 0; idx != numValues; ++idx)
    {
        dst[idx] = armnn::Quantize<T>(src[idx], scale, offset);
    }
}

} //namespace

/// Caches the values Execute() needs:
/// - the element count, read from the first input tensor info;
/// - the target data type, scale and offset, read from the first output tensor info.
/// NOTE(review): assumes input and output element counts match; presumably validated
/// before the workload is created. Confirm against the layer support checks.
RefQuantizeWorkload::RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) :
    BaseWorkload(descriptor, info),
    m_NumElements(info.m_InputTensorInfos[0].GetNumElements()),
    m_TargetType(info.m_OutputTensorInfos[0].GetDataType()),
    m_Scale(info.m_OutputTensorInfos[0].GetQuantizationScale()),
    m_Offset(info.m_OutputTensorInfos[0].GetQuantizationOffset())
{}

/// Converts the float input tensor into the quantized output tensor.
/// Both tensor handles are mapped for the duration of the conversion and then
/// unmapped (input first, then output).
void RefQuantizeWorkload::Execute() const
{
    const void* const src = m_Data.m_Inputs[0]->Map(true);
    void* const dst = m_Data.m_Outputs[0]->Map(true);

    // Only the asymmetric uint8 type uses the cached quantization offset;
    // the symmetric types are always quantized with a zero point of 0.
    switch (m_TargetType)
    {
        case DataType::QAsymmU8:
            QuantizeImpl<uint8_t>(src, dst, m_NumElements, m_Scale, m_Offset);
            break;
        case DataType::QSymmS8:
            QuantizeImpl<int8_t>(src, dst, m_NumElements, m_Scale, 0);
            break;
        case DataType::QSymmS16:
            QuantizeImpl<int16_t>(src, dst, m_NumElements, m_Scale, 0);
            break;
        default:
            // NOTE(review): BOOST_ASSERT_MSG is normally compiled out when NDEBUG is defined.
            // In that case an unsupported target type leaves the output unwritten.
            BOOST_ASSERT_MSG(false, "RefQuantizeWorkload: Non quantized output type encountered");
            break;
    }

    m_Data.m_Inputs[0]->Unmap();
    m_Data.m_Outputs[0]->Unmap();
}

} //namespace armnn