aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 3ac78b25..62648914 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -110,6 +110,10 @@ class TFLiteSemantic:
# Conv-like checks:
for op_type in TFLiteSemantic.convolution_like_ops:
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_stride_type)
+ if op_type in TFLiteSemantic.convolution_ops:
+ # Only Conv has groups
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_conv_groups_ifm_depth)
+ self.specific_constraints[op_type].append(TFLiteSemantic.constraint_conv_groups_num_filters)
if op_type not in TFLiteSemantic.transpose_convolution_ops:
# Transpose Conv does not contain dilation
self.specific_constraints[op_type].append(TFLiteSemantic.constraint_dilation_type)
@@ -373,6 +377,36 @@ class TFLiteSemantic:
return valid, f"Op has stride WxH as: {repr(w)}x{repr(h)}"
@staticmethod
+ def constraint_conv_groups_ifm_depth(op):
+ """IFM depth must be a whole multiple of the filter kernel depth"""
+ ifm_depth = op.ifm.shape[-1] # nhwc
+ kernel_ic = op.weights.shape[-2] # hwio
+ num_conv_groups = ifm_depth // kernel_ic
+
+ if ifm_depth % kernel_ic == 0:
+ op.attrs["num_conv_groups"] = num_conv_groups
+ valid = True
+ else:
+ valid = False
+
+ return valid, f"IFM depth = {ifm_depth} and filter kernel depth = {kernel_ic}"
+
+ @staticmethod
+ def constraint_conv_groups_num_filters(op):
+ """Number of filter kernels must be equally divisible by the number of convolution groups"""
+ ifm_depth = op.ifm.shape[-1] # nhwc
+ kernel_ic = op.weights.shape[-2] # hwio
+ kernel_oc = op.weights.shape[-1] # hwio
+ num_conv_groups = ifm_depth // kernel_ic
+
+ if kernel_oc % num_conv_groups == 0:
+ valid = True
+ else:
+ valid = False
+
+ return valid, f"Filter kernels = {kernel_oc} and convolution groups = {num_conv_groups}"
+
+ @staticmethod
def constraint_dilation_type(op):
"Dilation factor values for both width and height must be integer types"
w, h = op.get_kernel_dilation()