aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/data_type.py
blob: 1d3e94ed8066aeab69c2abebaa6b7a0c2df57e2e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Description:
# Defines the basic numeric type classes for tensors.

from .numeric_util import round_up_divide
import enum


class BaseType(enum.Flag):
    """Bit flags describing the fundamental category of a tensor element type.

    The first four members are single property bits. The integer categories
    (``SignedInt``, ``UnsignedInt``, ``AsymmSInt``, ``AsymmUInt``) are unions
    of those bits, so e.g. ``BaseType.Int in BaseType.AsymmUInt`` holds. The
    remaining categories each occupy their own bit.
    """

    # Primitive property bits used to compose the integer categories.
    Signed = 1 << 0
    Unsigned = 1 << 1
    Asymmetric = 1 << 2
    Int = 1 << 3

    # Integer categories composed from the property bits above.
    SignedInt = Signed | Int
    UnsignedInt = Unsigned | Int
    AsymmSInt = Signed | Asymmetric | Int
    AsymmUInt = Unsigned | Asymmetric | Int

    # Standalone categories, one bit each.
    Float = 1 << 4
    BFloat = 1 << 5
    Bool = 1 << 6
    String = 1 << 7
    Resource = 1 << 8
    Variant = 1 << 9


class DataType:
    """Defines a data type. Consists of a base type, and the number of bits used for this type.

    Instances compare and hash by their ``(type, bits)`` pair, so two
    separately constructed instances describing the same type are
    interchangeable (e.g. as dict keys or set members).
    """

    __slots__ = "type", "bits"

    def __init__(self, type_, bits):
        """Create a data type.

        Args:
            type_: The ``BaseType`` flag describing the element category.
            bits: Number of bits used for one element of this type.
        """
        self.type = type_
        self.bits = bits

    def __eq__(self, other):
        # Comparing against a non-DataType (None, a string, ...) must not blow
        # up with AttributeError; returning NotImplemented lets Python fall back
        # to its default handling, so e.g. ``DataType.int8 == None`` is False.
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.bits == other.bits

    def __hash__(self):
        # Must stay consistent with __eq__: equal instances hash equally.
        return hash((self.type, self.bits))

    def size_in_bytes(self):
        """Return the storage size of one element in bytes.

        Computed as ``round_up_divide(bits, 8)`` (presumably a ceiling
        division, so e.g. 4-bit and 12-bit types report 1 and 2 bytes).
        """
        return round_up_divide(self.bits, 8)

    def size_in_bits(self):
        """Return the number of bits used for one element."""
        return self.bits

    def __str__(self):
        """Return the type's name, e.g. ``int8``, ``quint12`` or ``bool``.

        Raises:
            KeyError: if ``self.type`` has no entry in ``stem_name``.
        """
        stem, needs_format = DataType.stem_name[self.type]
        if not needs_format:
            return stem
        # The stem contains a "%s" placeholder for the bit width.
        return stem % (self.bits,)

    __repr__ = __str__

    # Maps each BaseType to (name stem, whether the bit width is substituted
    # into the stem's "%s" placeholder). Used by __str__.
    stem_name = {
        BaseType.UnsignedInt: ("uint%s", True),
        BaseType.SignedInt: ("int%s", True),
        BaseType.AsymmUInt: ("quint%s", True),
        BaseType.AsymmSInt: ("qint%s", True),
        BaseType.Float: ("float%s", True),
        BaseType.BFloat: ("bfloat%s", True),
        BaseType.Bool: ("bool", False),
        BaseType.String: ("string", False),
        BaseType.Resource: ("resource", False),
        BaseType.Variant: ("variant", False),
    }


# Populate DataType with the standard set of data types, exposed as class
# attributes such as DataType.int8 or DataType.float32.

# Signed integers.
DataType.int8 = DataType(type_=BaseType.SignedInt, bits=8)
DataType.int16 = DataType(type_=BaseType.SignedInt, bits=16)
DataType.int32 = DataType(type_=BaseType.SignedInt, bits=32)
DataType.int64 = DataType(type_=BaseType.SignedInt, bits=64)

# Unsigned integers.
DataType.uint8 = DataType(type_=BaseType.UnsignedInt, bits=8)
DataType.uint16 = DataType(type_=BaseType.UnsignedInt, bits=16)
DataType.uint32 = DataType(type_=BaseType.UnsignedInt, bits=32)
DataType.uint64 = DataType(type_=BaseType.UnsignedInt, bits=64)

# Asymmetric unsigned integers, including the 4- and 12-bit widths.
DataType.quint4 = DataType(type_=BaseType.AsymmUInt, bits=4)
DataType.quint8 = DataType(type_=BaseType.AsymmUInt, bits=8)
DataType.quint12 = DataType(type_=BaseType.AsymmUInt, bits=12)
DataType.quint16 = DataType(type_=BaseType.AsymmUInt, bits=16)
DataType.quint32 = DataType(type_=BaseType.AsymmUInt, bits=32)

# Asymmetric signed integers, including the 4- and 12-bit widths.
DataType.qint4 = DataType(type_=BaseType.AsymmSInt, bits=4)
DataType.qint8 = DataType(type_=BaseType.AsymmSInt, bits=8)
DataType.qint12 = DataType(type_=BaseType.AsymmSInt, bits=12)
DataType.qint16 = DataType(type_=BaseType.AsymmSInt, bits=16)
DataType.qint32 = DataType(type_=BaseType.AsymmSInt, bits=32)

# Floating point.
DataType.float16 = DataType(type_=BaseType.Float, bits=16)
DataType.float32 = DataType(type_=BaseType.Float, bits=32)
DataType.float64 = DataType(type_=BaseType.Float, bits=64)

# Non-numeric types; their names carry no bit width, but each still records one.
DataType.string = DataType(type_=BaseType.String, bits=64)
DataType.bool = DataType(type_=BaseType.Bool, bits=8)
DataType.resource = DataType(type_=BaseType.Resource, bits=8)
DataType.variant = DataType(type_=BaseType.Variant, bits=8)