aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py22
1 files changed, 15 insertions, 7 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 56dce14f..3ac78b25 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -702,7 +702,7 @@ class TFLiteSemantic:
When IFM tensor is 3D or 4D:
- Reduction in Batch axis is only supported if batch size is 1.
- Reduction in both Height and Width axes is supported.
- - Reduction in Depth axis is only supported if depth is 1."""
+ - Reduction in Depth axis is supported if at least one of H,W,C are of size 1."""
input_shape = op.inputs[0].shape
dims = len(input_shape)
if op.inputs[1].shape == []:
@@ -714,14 +714,22 @@ class TFLiteSemantic:
for ax in axis:
if ax < 0 or ax >= dims:
return False, "Axis parameter is out of bounds. axis: {axis}, dims: {dims}. "
- elif dims == 3:
- # depth is only supported if size is 1
- if ax == 2 and input_shape[ax] != 1:
+
+ # Batch is only supported if batch shape is 1
+ if dims == 4 and ax == 0:
+ if input_shape[0] != 1:
valid = False
break
- else: # 4D
- # batch and depth are only supported if sizes are 1
- if ax in [0, 3] and input_shape[ax] != 1:
+
+ # Depth is supported if any of h,w,c == 1
+ if dims == 3:
+ if ax == 2 and not any([s == 1 for s in input_shape]):
+ valid = False
+ break
+
+ # Depth is supported if any of h,w,c == 1
+ if dims == 4:
+ if ax == 3 and not any([s == 1 for s in input_shape[1:]]):
valid = False
break