aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/graph_optimiser_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/graph_optimiser_util.py')
-rw-r--r--ethosu/vela/graph_optimiser_util.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 8b24eaf9..e1341d82 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -30,6 +30,7 @@ from .operation import Op
from .shape4d import Shape4D
from .tensor import create_const_tensor
from .tensor import QuantizationParameters
+from .tensor import Tensor
memory_only_ops = (
Op.Reshape,
@@ -90,7 +91,9 @@ def _avoid_nhcwb16_for_memory_only(tens):
# Check if non linear format can be used
-def check_format_restrictions(tens, arch):
+def check_format_restrictions(tens: Tensor, arch):
+ if tens.force_linear_format:
+ return
if len(tens.ops) < 1:
return
if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
@@ -161,7 +164,7 @@ def check_format_restrictions(tens, arch):
else:
return
- tens.needs_linear_format = False
+ tens.force_linear_format = False
def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]: