aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/shape4d.py
blob: e26389a1e3020a600790e43a330bf8df600870bd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
# Defines the class Shape4D.
from collections import namedtuple

from .numeric_util import full_shape
from .numeric_util import round_up_divide


class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
    """
    4D Shape (in NHWC format)

    An immutable named tuple with fields ``batch``, ``height``, ``width`` and
    ``depth``. Arithmetic operators act element-wise between two Shape4D
    objects and return a new Shape4D.
    """

    def __new__(cls, n=1, h=1, w=1, c=1):
        """
        Create a shape either from four scalars or from one sequence.

        :param n: batch size, or a list/tuple of up to 4 dimensions (in which
            case h, w and c must be left at their defaults).
        :param h: height.
        :param w: width.
        :param c: depth (channels).
        """
        assert n is not None
        # Accept tuples as well as lists; previously a tuple (including another
        # Shape4D) fell through to the scalar branch and became the batch field.
        if isinstance(n, (list, tuple)):
            assert h == 1 and w == 1 and c == 1
            # full_shape presumably left-pads to 4 dims with 1s (see
            # numeric_util). list() keeps tuple input safe for list concatenation.
            # NOTE(review): a sequence longer than 4 is silently truncated to its
            # first four entries here; kept for backward compatibility.
            dims = full_shape(4, list(n), 1)
            return super().__new__(cls, dims[0], dims[1], dims[2], dims[3])
        return super().__new__(cls, n, h, w, c)

    @classmethod
    def from_list(cls, shape, base=1):
        """
        Build a shape from a list/tuple, padding missing dims with ``base``.

        :param shape: sequence of up to 4 dimensions.
        :param base: fill value handed to full_shape for missing leading dims.
        """
        dims = full_shape(4, list(shape), base)
        return cls(dims[0], dims[1], dims[2], dims[3])

    @classmethod
    def from_hwc(cls, h, w, c):
        """Build a shape with batch 1 from height, width and depth."""
        return cls(1, h, w, c)

    def with_batch(self, new_batch):
        """Return a copy with ``batch`` replaced."""
        return self._replace(batch=new_batch)

    def with_height(self, new_height):
        """Return a copy with ``height`` replaced."""
        return self._replace(height=new_height)

    def with_width(self, new_width):
        """Return a copy with ``width`` replaced."""
        return self._replace(width=new_width)

    def with_hw(self, new_height, new_width):
        """Return a copy with ``height`` and ``width`` replaced."""
        return self._replace(height=new_height, width=new_width)

    def with_depth(self, new_depth):
        """Return a copy with ``depth`` replaced."""
        return self._replace(depth=new_depth)

    def add(self, n, h, w, c):
        """Return a new shape with the given per-dimension offsets added."""
        return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)

    def __add__(self, rhs):
        # Overrides tuple concatenation: element-wise sum with another Shape4D.
        return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth)

    def __sub__(self, rhs):
        # Element-wise difference; components may become negative.
        return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)

    def __floordiv__(self, rhs):
        # Element-wise floor division.
        return Shape4D(
            self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
        )

    def __mod__(self, rhs):
        # Element-wise modulo.
        return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)

    def __str__(self):
        return f"<Shape4D {list(self)}>"

    def div_round_up(self, rhs):
        """Element-wise division rounded up, via numeric_util.round_up_divide."""
        return Shape4D(
            round_up_divide(self.batch, rhs.batch),
            round_up_divide(self.height, rhs.height),
            round_up_divide(self.width, rhs.width),
            round_up_divide(self.depth, rhs.depth),
        )

    def elements(self):
        """Return the total element count (product of all four dimensions)."""
        return self.batch * self.width * self.height * self.depth

    def elements_wh(self):
        """Return the number of elements in one height x width plane."""
        return self.width * self.height

    def is_empty(self):
        """
        Return True when every dimension is zero.

        Checks each component explicitly rather than summing, so that a shape
        with mixed-sign components (e.g. the result of a subtraction such as
        (1, -1, 0, 0)) is not misreported as empty.
        """
        return all(dim == 0 for dim in self)

    def as_list(self):
        """Return the dimensions as a list in NHWC order."""
        return list(self)

    def get_hw_as_list(self):
        """Return ``[height, width]``."""
        return [self.height, self.width]