diff options
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index a2f744d3..7458b907 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -152,7 +152,8 @@ class TFLiteSubgraph: activation_function_to_split_out = None if op_type.startswith("DepthwiseConv2d") or op_type.startswith("Conv2D"): - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0)) + if inputs[1].values is not None: + inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0)) if len(inputs) < 3 or (len(inputs) < 4 and "Backprop" in op_type): # No Bias tensor inputs.append(None) @@ -160,7 +161,8 @@ class TFLiteSubgraph: inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,)) if op_type.startswith("FullyConnected"): - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0)) + if inputs[1].values is not None: + inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0)) if len(inputs) < 3: # No Bias tensor inputs.append(None) @@ -174,6 +176,13 @@ class TFLiteSubgraph: # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape op.attrs["new_shape"] = outputs[0].shape + if op_type == "Cast": + # Cast op should have "in/out_data_type" attribs add if missing + if "in_data_type" not in op.attrs: + op.attrs["in_data_type"] = inputs[0].dtype + if "out_data_type" not in op.attrs: + op.attrs["out_data_type"] = outputs[0].dtype + if "stride_w" in op.attrs: op.attrs["strides"] = (1, op.attrs["stride_h"], op.attrs["stride_w"], 1) if "filter_width" in op.attrs: |