diff options
Diffstat (limited to 'ethosu/vela/data_type.py')
-rw-r--r-- | ethosu/vela/data_type.py | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py new file mode 100644 index 00000000..1d3e94ed --- /dev/null +++ b/ethosu/vela/data_type.py @@ -0,0 +1,116 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Description: +# Defines the basic numeric type classes for tensors. + +from .numeric_util import round_up_divide +import enum + + +class BaseType(enum.Flag): + Signed = 1 + Unsigned = 2 + Asymmetric = 4 + Int = 8 + SignedInt = Int | Signed + UnsignedInt = Int | Unsigned + AsymmSInt = Int | Asymmetric | Signed + AsymmUInt = Int | Asymmetric | Unsigned + Float = 16 + BFloat = 32 + Bool = 64 + String = 128 + Resource = 256 + Variant = 512 + + +class DataType: + """Defines a data type. Consists of a base type, and the number of bits used for this type""" + + __slots__ = "type", "bits" + + def __init__(self, type_, bits): + self.type = type_ + self.bits = bits + + def __eq__(self, other): + return self.type == other.type and self.bits == other.bits + + def __hash__(self): + return hash((self.type, self.bits)) + + def size_in_bytes(self): + return round_up_divide(self.bits, 8) + + def size_in_bits(self): + return self.bits + + def __str__(self): + stem, needs_format = DataType.stem_name[self.type] + if not needs_format: + return stem + else: + return stem % (self.bits,) + + __repr__ = __str__ + + stem_name = { + BaseType.UnsignedInt: ("uint%s", True), + BaseType.SignedInt: ("int%s", True), + BaseType.AsymmUInt: ("quint%s", True), + BaseType.AsymmSInt: ("qint%s", True), + BaseType.Float: ("float%s", True), + BaseType.BFloat: ("bfloat%s", True), + BaseType.Bool: ("bool", False), + BaseType.String: ("string", False), + BaseType.Resource: ("resource", False), + BaseType.Variant: ("variant", False), + } + + +# generate the standard set of data types +DataType.int8 = DataType(BaseType.SignedInt, 8) +DataType.int16 = DataType(BaseType.SignedInt, 16) +DataType.int32 = DataType(BaseType.SignedInt, 32) +DataType.int64 = DataType(BaseType.SignedInt, 64) + +DataType.uint8 = DataType(BaseType.UnsignedInt, 8) +DataType.uint16 = DataType(BaseType.UnsignedInt, 16) +DataType.uint32 = DataType(BaseType.UnsignedInt, 32) +DataType.uint64 = DataType(BaseType.UnsignedInt, 64) + +DataType.quint4 = DataType(BaseType.AsymmUInt, 4) +DataType.quint8 = DataType(BaseType.AsymmUInt, 8) +DataType.quint12 = DataType(BaseType.AsymmUInt, 12) +DataType.quint16 = DataType(BaseType.AsymmUInt, 16) +DataType.quint32 = DataType(BaseType.AsymmUInt, 32) + +DataType.qint4 = DataType(BaseType.AsymmSInt, 4) +DataType.qint8 = DataType(BaseType.AsymmSInt, 8) +DataType.qint12 = DataType(BaseType.AsymmSInt, 12) +DataType.qint16 = DataType(BaseType.AsymmSInt, 16) +DataType.qint32 = DataType(BaseType.AsymmSInt, 32) + +DataType.float16 = DataType(BaseType.Float, 16) +DataType.float32 = DataType(BaseType.Float, 32) +DataType.float64 = DataType(BaseType.Float, 64) + +DataType.string = DataType(BaseType.String, 64) +DataType.bool = DataType(BaseType.Bool, 8) +DataType.resource = DataType(BaseType.Resource, 8) +DataType.variant = DataType(BaseType.Variant, 8) |