aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index ecca0e0e..312e8f35 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -54,16 +54,17 @@ class MemArea(enum.IntFlag):
Dram = 2
OnChipFlash = 3
OffChipFlash = 4
- Size = OffChipFlash + 1
+ Shram = 5 # for LUT
+ Size = Shram + 1
def display_name(self):
- return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "Size")[self.value]
+ return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value]
def identifier_name(self):
- return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "size")[self.value]
+ return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value]
def all():
- return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash)
+ return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram)
def __str__(self):
return self.name
@@ -728,6 +729,9 @@ class Tensor:
return True
return False
+ def equivalent(self, tens):
+ return self.equivalence_id == tens.equivalence_id
+
def set_all_shapes(self, shape):
self.shape = shape
self.storage_shape = shape