diff options
author | Ayaan Masood <Ayaan.Masood@arm.com> | 2022-04-21 14:28:03 +0100 |
---|---|---|
committer | Ayaan Masood <Ayaan.Masood@arm.com> | 2022-04-21 14:28:03 +0100 |
commit | a2ec5aa747633da72b6310ce7e5552c39f7f54bb (patch) | |
tree | 04886e488a7059064653c836c32f0e0ec69a4a74 /ethosu/vela/tflite_model_semantic.py | |
parent | f9267da3ad6251a7e04f501218380ac9a89953b7 (diff) | |
download | ethos-u-vela-a2ec5aa747633da72b6310ce7e5552c39f7f54bb.tar.gz |
MLBEDSW-5384 FC layers run on NPU if underlying shape is 2D
*Added generic function which checks if underlying shape of
FullyConnected operation is 2D and performs shape reduction
*Fully connected operation >2 dimensions now run on NPU if the above
case is satisfied
*constraint_fc_output_2d and rewrite_fully_connected_input refactored
*Added unit test to confirm this functionality
Signed-off-by: Ayaan Masood <Ayaan.Masood@arm.com>
Change-Id: I0e29c767e5b84841eb53bbc44464b36a454f7b38
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 13 |
1 files changed, 5 insertions, 8 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index b2644791..c811a0d4 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -295,14 +295,11 @@ class TFLiteSemantic: @staticmethod def constraint_fc_output_2d(op): - "The output tensor(s) must have 2D shape" - valid = True - extra = [] - for tens in op.outputs: - if len(tens.shape) != 2: - valid = False - extra.append(f"Tensor '{tens.name}' is {len(tens.shape)}D") - return valid, ", ".join(extra) + """The output tensor(s) must have 2D shape""" + valid = op.ifm.get_shape_as_2d(op.weights.shape[-2]) is not None + extra = f"Op has non-2D output tensor '{op.ofm.name}'" if not valid else "" + + return valid, extra @staticmethod def constraint_stride_type(op): |