From eeb85154b00a9864d0d63e382e9c80ca8e294d5d Mon Sep 17 00:00:00 2001 From: "patrik.gustavsson" Date: Mon, 21 Dec 2020 17:10:40 +0000 Subject: Revert "Revert "MLBEDSW-3645 4D class for op ifm/ofm shapes"" This reverts commit df0a5905177f3a1b836076bc3f9f39b2e86f1794. Reason for revert: Change-Id: I891c66fb29db9d25e942947e8d1c29a10610de51 --- ethosu/vela/shape4d.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 ethosu/vela/shape4d.py (limited to 'ethosu/vela/shape4d.py') diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py new file mode 100644 index 00000000..a1b4feaa --- /dev/null +++ b/ethosu/vela/shape4d.py @@ -0,0 +1,77 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Description: +# Defines the class Shape4D. +from .numeric_util import full_shape + + +class Shape4D: + """ + 4D Shape (in NHWC format) + """ + + def __init__(self, shape, base=1): + assert shape is not None + assert len(shape) <= 4 + self._shape4D = tuple(full_shape(4, shape, base)) + + def __str__(self): + return f"" + + def __eq__(self, other): + return self._shape4D == other._shape4D + + def clone(self): + return Shape4D(self.as_list()) + + @property + def batch(self): + return self._shape4D[0] + + @property + def height(self): + return self._shape4D[1] + + @property + def width(self): + return self._shape4D[2] + + @property + def depth(self): + return self._shape4D[3] + + @batch.setter + def batch(self, new_batch): + self._shape4D = (new_batch, self._shape4D[1], self._shape4D[2], self._shape4D[3]) + + @height.setter + def height(self, new_height): + self._shape4D = (self._shape4D[0], new_height, self._shape4D[2], self._shape4D[3]) + + @width.setter + def width(self, new_width): + self._shape4D = (self._shape4D[0], self._shape4D[1], new_width, self._shape4D[3]) + + @depth.setter + def depth(self, new_depth): + self._shape4D = (self._shape4D[0], self._shape4D[1], self._shape4D[2], new_depth) + + def get_dim(self, dim): + assert -4 <= dim < 4 + return self._shape4D[dim] + + def as_list(self): + return list(self._shape4D) -- cgit v1.2.1