aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/mark_tensors.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/mark_tensors.py')
-rw-r--r--ethosu/vela/mark_tensors.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 9b1824b5..c42a28df 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -21,7 +21,7 @@
from . import rewrite_graph
from . import weight_compressor
from .architecture_features import Block
-from .nn_graph import TensorPurpose, TensorFormat, PassPlacement
+from .tensor import TensorPurpose, TensorFormat
from .operation import NpuBlockType
@@ -55,6 +55,7 @@ def inputs_from_output(op, idx):
print("Warning: Propagating unknown tensor purpose", op)
return res
+
tensor_purposes = [ # ops, input_purpose
(
set(
@@ -327,7 +328,7 @@ def mark_tensor_format(nng, arch, verbose_tensor_format=False):
return NpuBlockType.Default
def visit_tens(tens, ps):
- if not tens in formats_for_tensor:
+ if tens not in formats_for_tensor:
fmt = init_tens(tens)
else:
fmt = formats_for_tensor[tens]