aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index b07b4dc3..f6e628c8 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -557,8 +557,10 @@ class Tensor:
return self.consumer_list
def get_address_ranges_for_coordinates(self, start_coord, end_coord):
- if self.sub_purpose in set(
- (TensorSubPurpose.RollingBufferX, TensorSubPurpose.RollingBufferY, TensorSubPurpose.RollingBufferXY)
+ if self.sub_purpose in (
+ TensorSubPurpose.RollingBufferX,
+ TensorSubPurpose.RollingBufferY,
+ TensorSubPurpose.RollingBufferXY,
):
# build dummy coordinates that cover the entire buffer
start_coord = [0] * len(start_coord)
@@ -637,7 +639,7 @@ class Tensor:
augmented_shape[1] = 1
else:
- assert self.format in set((TensorFormat.Unknown, TensorFormat.WeightsCompressed))
+ assert self.format in (TensorFormat.Unknown, TensorFormat.WeightsCompressed)
return None, None
strides = [0] * len(augmented_shape)
@@ -774,9 +776,7 @@ class Tensor:
return address_offset
def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area):
- if self.mem_area == scratch_tensor_mem_area and (self.mem_type in set((MemType.Scratch, MemType.Scratch_fast))):
- return True
- return False
+ return (self.mem_area == scratch_tensor_mem_area) and (self.mem_type in (MemType.Scratch, MemType.Scratch_fast))
def equivalent(self, tens):
return self.equivalence_id == tens.equivalence_id