aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorAndreas Nevalainen <andreas.nevalainen@arm.com>2020-09-11 10:25:09 +0200
committerAndreas Nevalainen <andreas.nevalainen@arm.com>2020-09-22 14:02:26 +0200
commitd8c032d4be2a641946507b63023456312e333cb8 (patch)
tree4f55312012f3cdaf536364601f3fb7f1b2511846 /ethosu/vela/tflite_reader.py
parentd9e38fe2bc0458fdca83dd4932abee6554fe2eb2 (diff)
downloadethos-u-vela-d8c032d4be2a641946507b63023456312e333cb8.tar.gz
MLBEDSW-2813: Handle non-const weights and check shapes
- Added check for non-constant weights in supported operators - Added check ifm & ifm2 shapes - Handle None tensors for CPU operators - Handle missing attributes for Cast operator Signed-off-by: Andreas Nevalainen <andreas.nevalainen@arm.com> Change-Id: I2f16d3d44d0c6da5237550b39273cdb9cc3c7607
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index a2f744d3..7458b907 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -152,7 +152,8 @@ class TFLiteSubgraph:
activation_function_to_split_out = None
if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"):
- inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
+ if inputs[1].values is not None:
+ inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0))
if len(inputs) < 3 or (len(inputs) < 4 and "Backprop" in op_type):
# No Bias tensor
inputs.append(None)
@@ -160,7 +161,8 @@ class TFLiteSubgraph:
inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,))
if op_type.startswith("FullyConnected"):
- inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
+ if inputs[1].values is not None:
+ inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0))
if len(inputs) < 3:
# No Bias tensor
inputs.append(None)
@@ -174,6 +176,13 @@ class TFLiteSubgraph:
# Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
op.attrs["new_shape"] = outputs[0].shape
+ if op_type == "Cast":
+ # Cast op should have "in/out_data_type" attribs add if missing
+ if "in_data_type" not in op.attrs:
+ op.attrs["in_data_type"] = inputs[0].dtype
+ if "out_data_type" not in op.attrs:
+ op.attrs["out_data_type"] = outputs[0].dtype
+
if "stride_w" in op.attrs:
op.attrs["strides"] = (1, op.attrs["stride_h"], op.attrs["stride_w"], 1)
if "filter_width" in op.attrs: