aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-12-17 13:54:09 +0100
committerLouis Verhaard <louis.verhaard@arm.com>2020-12-18 08:33:36 +0100
commit6c74c3bcaa733aa062c15d606726722b19c0dfdb (patch)
tree5daed3a395396003147df2f58d2f6f019a9bb113
parent168954814fb6a1cc5e7b2d44784b24402ef30199 (diff)
downloadethos-u-vela-6c74c3bcaa733aa062c15d606726722b19c0dfdb.tar.gz
MLBEDSW-3487: Support '<' for tensors
Added __lt__ for Tensor to avoid errors when sorting tensors. Change-Id: I19bb591ef17aa0d4a3389da411bd8863c2218d55 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
-rw-r--r--ethosu/vela/tensor.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 257cb5ff..c1443b3b 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -21,6 +21,7 @@ import uuid
from collections import defaultdict
from enum import auto
from functools import lru_cache
+from functools import total_ordering
from typing import Dict
from typing import List
from typing import Optional
@@ -342,6 +343,7 @@ class TensorAddressMap:
cls.address_map[tens_id][mem_type] = address
+@total_ordering
class Tensor:
__slots__ = (
"shape",
@@ -841,6 +843,9 @@ class Tensor:
return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid()
+ def __lt__(self, other: "Tensor") -> bool:
+ return self.equivalence_id < other.equivalence_id
+
def __str__(self):
return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)