aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/lut.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/lut.py')
-rw-r--r--ethosu/vela/lut.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index 0e8dcc95..e3373ca2 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -42,9 +42,9 @@ class LUTState:
self.tensors = []
def get_equivalent(self, lut_tens):
- # Returns existing lut with same equivalence id, None if not found
+ # Returns existing lut with the same values, None if not found
for t in self.tensors:
- if t.equivalent(lut_tens):
+ if np.array_equal(t.values, lut_tens.values):
return t
return None
@@ -60,6 +60,7 @@ class LUTState:
end2 = start2 + tens.storage_size()
if not numeric_util.overlaps(start, end, start2, end2):
new_state.tensors.append(tens)
+
return new_state
def find_best_address(self, start, stop, step):
@@ -129,6 +130,7 @@ def optimize_high_level_cmd_stream(sg, arch):
# Place the LUT in the last 2 blocks of SHRAM
# Alignment is always on the size of the LUT, 256 for 256-byte LUT, 1K for 1K LUT, etc
address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size())
+ lut_tens.equivalence_id = uuid.uuid4()
lut_tens.address = address
cmd.ps.primary_op.attrs["lut_index"] = (address - lut_start) // slot_size
lut_state = lut_state.put(lut_tens)