diff options
author | Louis Verhaard <louis.verhaard@arm.com> | 2020-12-17 13:54:09 +0100 |
---|---|---|
committer | Louis Verhaard <louis.verhaard@arm.com> | 2020-12-18 08:33:36 +0100 |
commit | 6c74c3bcaa733aa062c15d606726722b19c0dfdb (patch) | |
tree | 5daed3a395396003147df2f58d2f6f019a9bb113 /ethosu/vela | |
parent | 168954814fb6a1cc5e7b2d44784b24402ef30199 (diff) | |
download | ethos-u-vela-6c74c3bcaa733aa062c15d606726722b19c0dfdb.tar.gz |
MLBEDSW-3487: Support '<' for tensors
Added __lt__ for Tensor to avoid errors when sorting tensors.
Change-Id: I19bb591ef17aa0d4a3389da411bd8863c2218d55
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/tensor.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 257cb5ff..c1443b3b 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -21,6 +21,7 @@ import uuid from collections import defaultdict from enum import auto from functools import lru_cache +from functools import total_ordering from typing import Dict from typing import List from typing import Optional @@ -342,6 +343,7 @@ class TensorAddressMap: cls.address_map[tens_id][mem_type] = address +@total_ordering class Tensor: __slots__ = ( "shape", @@ -841,6 +843,9 @@ class Tensor: return (self.dtype.type & BaseType.Int) != 0 and self.quantization.is_valid() + def __lt__(self, other: "Tensor") -> bool: + return self.equivalence_id < other.equivalence_id + def __str__(self): return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype) |