# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Defines the basic numeric type classes for tensors.
import enum
from typing import Any

import numpy as np

from .numeric_util import round_up_divide


class BaseType(enum.Flag):
    """Bit flags classifying the element category of a DataType.

    The primitive flags (Signed, Unsigned, Asymmetric, Int) are combined into
    the composite integer kinds; Float, BFloat, Bool, String, Resource,
    Variant and Complex are standalone categories.
    """

    Signed = 1
    Unsigned = 2
    Asymmetric = 4
    Int = 8
    # Composite integer kinds. Asymmetric marks the variants that DataType
    # prints with the "quint"/"qint" stems.
    SignedInt = Int | Signed
    UnsignedInt = Int | Unsigned
    AsymmSInt = Int | Asymmetric | Signed
    AsymmUInt = Int | Asymmetric | Unsigned
    Float = 16
    BFloat = 32
    Bool = 64
    String = 128
    Resource = 256
    Variant = 512
    Complex = 1024


class DataType:
    """Describes a tensor element type.

    A DataType pairs a BaseType category with the number of bits used to store
    one element. Instances compare and hash by (type, bits). The standard types
    are exposed as class attributes (DataType.int8, DataType.float32, ...),
    which are assigned after the class body.
    """

    __slots__ = "type", "bits"

    # Standard instances, declared here for type checkers and assigned at
    # module level below.
    int4: Any
    int8: Any
    int16: Any
    int32: Any
    int48: Any
    int64: Any
    uint8: Any
    uint16: Any
    uint32: Any
    uint64: Any
    quint4: Any
    quint8: Any
    quint12: Any
    quint16: Any
    quint32: Any
    qint4: Any
    qint8: Any
    qint12: Any
    qint16: Any
    qint32: Any
    float16: Any
    float32: Any
    float64: Any
    string: Any
    bool: Any
    resource: Any
    variant: Any
    complex64: Any
    complex128: Any

    def __init__(self, type_, bits):
        """Create a data type.

        Args:
            type_: the BaseType category of the element.
            bits: number of bits per element.
        """
        self.type = type_
        self.bits = bits

    def __eq__(self, other):
        """Return True when both the base type and the bit width match.

        For objects that are not DataType instances this returns NotImplemented,
        so e.g. ``dtype == None`` evaluates to False instead of raising
        AttributeError.
        """
        if not isinstance(other, DataType):
            return NotImplemented
        return self.type == other.type and self.bits == other.bits

    def __hash__(self):
        # Hashes the same (type, bits) pair that __eq__ compares, so equal
        # instances hash equally.
        return hash((self.type, self.bits))

    def size_in_bytes(self):
        """Return the element size in bytes: bits / 8, rounded up via round_up_divide."""
        return round_up_divide(self.bits, 8)

    def size_in_bits(self):
        """Return the element size in bits."""
        return self.bits

    def __str__(self):
        """Return the type name, e.g. "int8", "quint16" or "bool".

        Raises:
            KeyError: if the base type has no entry in stem_name.
        """
        stem, needs_format = DataType.stem_name[self.type]
        if not needs_format:
            return stem
        else:
            return stem % (self.bits,)

    __repr__ = __str__

    def as_numpy_type(self):
        """Return the equivalent numpy dtype.

        The dtype string is a numpy kind code ("u", "i", "f", "c" or "b")
        followed by size_in_bytes(). For example, int16 gives "i2", float32
        gives "f4", and bool gives "b1" (numpy's boolean).

        Raises:
            AssertionError: (when assertions are enabled) if the base type has
                no numpy kind code, e.g. the quint/qint types, bfloat, string,
                resource or variant.
        """
        numpy_dtype_code = {
            BaseType.UnsignedInt: "u",
            BaseType.SignedInt: "i",
            BaseType.Float: "f",
            BaseType.Complex: "c",
            BaseType.Bool: "b",
        }
        assert self.type in numpy_dtype_code, f"Failed to interpret {self} as a numpy dtype"
        # NOTE(review): widths without a numpy equivalent (e.g. int48 -> "i6")
        # are passed straight to np.dtype, which presumably rejects them — confirm.
        return np.dtype(numpy_dtype_code[self.type] + str(self.size_in_bytes()))

    # Maps each BaseType to (name stem, whether the stem takes the bit width
    # via %-formatting); used by __str__.
    stem_name = {
        BaseType.UnsignedInt: ("uint%s", True),
        BaseType.SignedInt: ("int%s", True),
        BaseType.AsymmUInt: ("quint%s", True),
        BaseType.AsymmSInt: ("qint%s", True),
        BaseType.Float: ("float%s", True),
        BaseType.BFloat: ("bfloat%s", True),
        BaseType.Bool: ("bool", False),
        BaseType.String: ("string", False),
        BaseType.Resource: ("resource", False),
        BaseType.Variant: ("variant", False),
        BaseType.Complex: ("complex%s", True),
    }


# generate the standard set of data types
DataType.int4 = DataType(BaseType.SignedInt, 4)
DataType.int8 = DataType(BaseType.SignedInt, 8)
DataType.int16 = DataType(BaseType.SignedInt, 16)
DataType.int32 = DataType(BaseType.SignedInt, 32)
DataType.int48 = DataType(BaseType.SignedInt, 48)
DataType.int64 = DataType(BaseType.SignedInt, 64)
DataType.uint8 = DataType(BaseType.UnsignedInt, 8)
DataType.uint16 = DataType(BaseType.UnsignedInt, 16)
DataType.uint32 = DataType(BaseType.UnsignedInt, 32)
DataType.uint64 = DataType(BaseType.UnsignedInt, 64)

DataType.quint4 = DataType(BaseType.AsymmUInt, 4)
DataType.quint8 = DataType(BaseType.AsymmUInt, 8)
DataType.quint12 = DataType(BaseType.AsymmUInt, 12)
DataType.quint16 = DataType(BaseType.AsymmUInt, 16)
DataType.quint32 = DataType(BaseType.AsymmUInt, 32)

DataType.qint4 = DataType(BaseType.AsymmSInt, 4)
DataType.qint8 = DataType(BaseType.AsymmSInt, 8)
DataType.qint12 = DataType(BaseType.AsymmSInt, 12)
DataType.qint16 = DataType(BaseType.AsymmSInt, 16)
DataType.qint32 = DataType(BaseType.AsymmSInt, 32)

DataType.float16 = DataType(BaseType.Float, 16)
DataType.float32 = DataType(BaseType.Float, 32)
DataType.float64 = DataType(BaseType.Float, 64)

DataType.string = DataType(BaseType.String, 64)
DataType.bool = DataType(BaseType.Bool, 8)
DataType.resource = DataType(BaseType.Resource, 8)
DataType.variant = DataType(BaseType.Variant, 8)
DataType.complex64 = DataType(BaseType.Complex, 64)
DataType.complex128 = DataType(BaseType.Complex, 128)