diff options
Diffstat (limited to 'ethosu/vela/shape4d.py')
-rw-r--r-- | ethosu/vela/shape4d.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py new file mode 100644 index 00000000..a1b4feaa --- /dev/null +++ b/ethosu/vela/shape4d.py @@ -0,0 +1,77 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Description: +# Defines the class Shape4D. +from .numeric_util import full_shape + + +class Shape4D: + """ + 4D Shape (in NHWC format) + """ + + def __init__(self, shape, base=1): + assert shape is not None + assert len(shape) <= 4 + self._shape4D = tuple(full_shape(4, shape, base)) + + def __str__(self): + return f"<Shape4D {self.as_list()}>" + + def __eq__(self, other): + return self._shape4D == other._shape4D + + def clone(self): + return Shape4D(self.as_list()) + + @property + def batch(self): + return self._shape4D[0] + + @property + def height(self): + return self._shape4D[1] + + @property + def width(self): + return self._shape4D[2] + + @property + def depth(self): + return self._shape4D[3] + + @batch.setter + def batch(self, new_batch): + self._shape4D = (new_batch, self._shape4D[1], self._shape4D[2], self._shape4D[3]) + + @height.setter + def height(self, new_height): + self._shape4D = (self._shape4D[0], new_height, self._shape4D[2], self._shape4D[3]) + + @width.setter + def width(self, new_width): + self._shape4D = (self._shape4D[0], self._shape4D[1], new_width, self._shape4D[3]) + + @depth.setter + def depth(self, new_depth): + self._shape4D = (self._shape4D[0], self._shape4D[1], self._shape4D[2], new_depth) + + def get_dim(self, dim): + assert -4 <= dim < 4 + return self._shape4D[dim] + + def as_list(self): + return list(self._shape4D) |