diff options
Diffstat (limited to 'ethosu/vela/mark_tensors.py')
-rw-r--r-- | ethosu/vela/mark_tensors.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py index 9b1824b5..c42a28df 100644 --- a/ethosu/vela/mark_tensors.py +++ b/ethosu/vela/mark_tensors.py @@ -21,7 +21,7 @@ from . import rewrite_graph from . import weight_compressor from .architecture_features import Block -from .nn_graph import TensorPurpose, TensorFormat, PassPlacement +from .tensor import TensorPurpose, TensorFormat from .operation import NpuBlockType @@ -55,6 +55,7 @@ def inputs_from_output(op, idx): print("Warning: Propagating unknown tensor purpose", op) return res + tensor_purposes = [ # ops, input_purpose ( set( @@ -327,7 +328,7 @@ def mark_tensor_format(nng, arch, verbose_tensor_format=False): return NpuBlockType.Default def visit_tens(tens, ps): - if not tens in formats_for_tensor: + if tens not in formats_for_tensor: fmt = init_tens(tens) else: fmt = formats_for_tensor[tens] |