aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorerik.andersson@arm.com <erik.andersson@arm.com>2021-02-22 15:47:07 +0100
committererik.andersson@arm.com <erik.andersson@arm.com>2021-02-25 16:30:25 +0100
commit0cbb166cd032c779bd4681afef8097f0831ac8be (patch)
tree662bb70133ecf22a3f9aa52a668150692ce5c06a
parent460c689603b6cb713dbad07451278fe208b3f2f6 (diff)
downloadethos-u-vela-0cbb166cd032c779bd4681afef8097f0831ac8be.tar.gz
MLBEDSW-3571: Sum and FC should not crash when asking for keep_dims.
Previously the keep_dims or keep_num_dims attribute was not supported for Sum and Fully Connected operators and would thus crash for certain tests. With this update, the attribute is extracted correctly and saved to the optimised tflite file. Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com> Change-Id: If33487f6d299bb99788bb3d13332b842ba961641
-rw-r--r--ethosu/vela/supported_operators.py9
-rw-r--r--ethosu/vela/test/test_supported_operators.py8
-rw-r--r--ethosu/vela/tflite_mapping.py6
3 files changed, 21 insertions, 2 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 84432c7..8b759be 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -252,6 +252,7 @@ class SupportedOperators:
# FullyConnected specific checks:
self.specific_constraints[Op.FullyConnected].append(SupportedOperators.constraint_fc_output_2d)
+ self.specific_constraints[Op.FullyConnected].append(SupportedOperators.constraint_keep_dim_ifm_ofm)
# Pad specific checks:
self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_matching_in_out_types)
@@ -1068,3 +1069,11 @@ class SupportedOperators:
alpha = op.attrs["alpha"]
valid = alpha >= 0
return valid, f"Op has alpha={alpha}"
+
+ @staticmethod
+ def constraint_keep_dim_ifm_ofm(op):
+ "The IFM and OFM must have the same number of dimensions if keep_num_dims is set to true"
+ valid = True
+ if op.attrs.get("keep_num_dims"):
+ valid = len(op.ifm.shape) == len(op.ofm.shape)
+ return valid, f"Op has ifm shape={op.ifm.shape} and ofm shape={op.ofm.shape}"
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 6401d29..832d60f 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -945,3 +945,11 @@ def test_constraint_hardswish_dtype():
out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
assert not support.is_operator_supported(op)
+
+
+def test_constraint_keep_dims_ifm_ofm():
+ op = testutil.create_op_with_quant_tensors(Op.FullyConnected, [4, 8, 8, 4], [32, 32], weights_shape=[4, 8, 8, 4])
+ op.attrs["keep_num_dims"] = True
+ assert not support.is_operator_supported(op)
+ op.attrs["keep_num_dims"] = False
+ assert support.is_operator_supported(op)
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index 40e3090..41d57c0 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -501,7 +501,9 @@ builtin_operator_map = {
BuiltinOperator.FLOOR: (Op.Floor, None),
BuiltinOperator.FULLY_CONNECTED: (
Op.FullyConnected,
- OptionsSerializer("FullyConnectedOptions", (fused_act, "weights_format", "asymmetric_quantize_inputs")),
+ OptionsSerializer(
+ "FullyConnectedOptions", (fused_act, "weights_format", "asymmetric_quantize_inputs", "keep_num_dims")
+ ),
),
BuiltinOperator.HASHTABLE_LOOKUP: (Op.HashtableLookup, None),
BuiltinOperator.L2_NORMALIZATION: (Op.L2Norm, OptionsSerializer("L2NormOptions", (fused_act,))),
@@ -618,7 +620,7 @@ builtin_operator_map = {
BuiltinOperator.EQUAL: (Op.Equal, OptionsSerializer("EqualOptions")),
BuiltinOperator.NOT_EQUAL: (Op.NotEqual, OptionsSerializer("NotEqualOptions")),
BuiltinOperator.LOG: (Op.Log, None),
- BuiltinOperator.SUM: (Op.Sum, None),
+ BuiltinOperator.SUM: (Op.Sum, reducer_opts),
BuiltinOperator.SQRT: (Op.Sqrt, None),
BuiltinOperator.RSQRT: (Op.Rsqrt, None),
BuiltinOperator.SHAPE: (