diff options
author | erik.andersson@arm.com <erik.andersson@arm.com> | 2021-02-11 14:02:08 +0100 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2021-02-12 13:30:07 +0000 |
commit | 42b94edb8bcd71057ebeac2fa255d44dbf56a0f3 (patch) | |
tree | 777566e204fca798d9f993a13c2a028ac247d730 /ethosu/vela | |
parent | 2a58530de2686a3dc1cbe791f1f951b9ea7c39aa (diff) | |
download | ethos-u-vela-42b94edb8bcd71057ebeac2fa255d44dbf56a0f3.tar.gz |
MLBEDSW-3902: Fixes invalid op when cloning LeakyReLU operator
When running specific networks containing LeakyReLU operators, Vela would crash when cloning an ofm of a LeakyReLU operator.
In this procedure a deepcopy usage would try to copy an OperatorInfo object, which caused an error.
This was fixed by replacing the deepcopy usage with a copy and then manually referencing new instances of sensitive variables.
Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com>
Change-Id: I46917858896fbdf52245dac6c6d9c18bc7ecdd0d
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/tensor.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index ef8a28fc..b7d4307f 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -437,15 +437,13 @@ class Tensor: # Depending on set_unique, the copy is shallow, or deep # For set_unique==True, a new equivalence_id will be set def clone(self, suffix="_clone", set_unique: bool = False) -> "Tensor": + res = copy.copy(self) if set_unique: - res = copy.deepcopy(self) res.equivalence_id = uuid.uuid4() - else: - res = copy.copy(self) - res.storage_shape = list(self.storage_shape) - res.bandwidth_shape = list(self.bandwidth_shape) - if self.quantization is not None: - res.quantization = self.quantization.clone() + res.storage_shape = list(self.storage_shape) + res.bandwidth_shape = list(self.bandwidth_shape) + if self.quantization is not None: + res.quantization = self.quantization.clone() res.name = res.name + suffix res.ops = [] |