diff options
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index f7a95098..8dec379d 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -179,6 +179,27 @@ input and output tensors, as well as an attribute dictionary.""" return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor + def get_ifm_ifm2_weights_biases_ofm(self): + ifm_tensor = None + ifm2_tensor = None + weight_tensor = None + bias_tensor = None + ofm_tensor = None + + ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices() + if ifm_idx != -1: + ifm_tensor = self.inputs[ifm_idx] + if ifm2_idx != -1: + ifm2_tensor = self.inputs[ifm2_idx] + if weight_idx != -1: + weight_tensor = self.inputs[weight_idx] + if bias_idx != -1: + bias_tensor = self.inputs[bias_idx] + if ofm_idx != -1: + ofm_tensor = self.outputs[ofm_idx] + + return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor + def is_concat_op(self): return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped") |