diff options
Diffstat (limited to 'ethosu/vela/tosa_reader.py')
-rw-r--r-- | ethosu/vela/tosa_reader.py | 71 |
1 files changed, 63 insertions, 8 deletions
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py index 2f37478f..9ffda801 100644 --- a/ethosu/vela/tosa_reader.py +++ b/ethosu/vela/tosa_reader.py @@ -131,13 +131,65 @@ class TosaSubgraph: # TODO Transpose_conv and conv3d if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected: - if inputs[1].values is not None: - if op.type == Op.FullyConnected: - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0), False) - elif op.type.is_conv2d_op(): - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False) - elif op.type.is_depthwise_conv2d_op(): - inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False) + + def _remove_producing_identity_op(prod_op): + # find the producing op that is not an identity op and return it + while prod_op.type == Op.Identity: + prod_op = prod_op.inputs[0].ops[0] # get previous op + return prod_op + + def _check_and_get_connection(prod_op, tens): + # check weight producing op can be connected to the weight tensor + assert len(prod_op.outputs) == 1 + assert tens.shape == prod_op.outputs[0].shape + # only need to connect the current op connection as the tensor consuming connections haven't been + # initialised yet + return prod_op.outputs[0] + + # remove identity ops directly connected to the weight input of conv like ops + weights_producer_op = _remove_producing_identity_op(inputs[1].ops[0]) + inputs[1] = _check_and_get_connection(weights_producer_op, inputs[1]) # update connection + + if weights_producer_op.type == Op.Transpose: + # remove transpose op such that the weight op will a const op + transpose_op = weights_producer_op + # remove identity ops directly connected to the input of the transpose op + transpose_producer_op = _remove_producing_identity_op(transpose_op.inputs[0].ops[0]) + transpose_op.inputs[0] = _check_and_get_connection( + transpose_producer_op, transpose_op.inputs[0] + ) # update connection + + perms = transpose_op.attrs["perms"] + inputs[1] = clone_and_reshape_tensor(transpose_op.inputs[0], perms, False) + + if weights_producer_op.type == Op.Reshape: + # remove reshape op such that the weight op will a const op + reshape_op = weights_producer_op + # remove identity ops directly connected to the input of the reshape op + reshape_producer_op = _remove_producing_identity_op(reshape_op.inputs[0].ops[0]) + reshape_op.inputs[0] = _check_and_get_connection( + reshape_producer_op, reshape_op.inputs[0] + ) # update connection + + tens = reshape_op.inputs[0].clone("_reshape", False) + tens.values = np.reshape(tens.values, reshape_op.ofm.shape) + tens.shape = reshape_op.ofm.shape + tens._original_shape = tens.shape + tens.bandwidth_shape = tens.shape + tens.storage_shape = tens.shape + + tmp_op = Operation(Op.Const, tens.name) + tmp_op.set_output_tensor(tens) + inputs[1] = tens + + assert inputs[1].values is not None + + if op.type == Op.FullyConnected: + inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 0), False) + elif op.type.is_conv2d_op(): + inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 3, 0), False) + elif op.type.is_depthwise_conv2d_op(): + inputs[1] = clone_and_reshape_tensor(inputs[1], (1, 2, 0, 3), False) if op.type.needs_bias() and len(inputs) <= op_type.info.indices.biases[0]: # No Bias tensor inputs.append(None) @@ -146,10 +198,13 @@ class TosaSubgraph: # a clone with a unique equivalence_id is needed inputs[-1] = clone_and_reshape_tensor(inputs[-1], (0,), True) + op.explicit_scaling = ExplicitScaling(False, [0], [1]) # no scaling + if attr_serializer is not None: op.attrs = attr_serializer.deserialize(op_data) - if "padding" in op.attrs: + if "pad" in op.attrs: + op.attrs["padding"] = op.attrs["pad"] # attribute was renamed to padding padding = op.attrs["padding"] # [top, bottom, left, right] op.attrs["explicit_padding"] = ( padding[0], |