aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 07dc7e5..fbf240d 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -404,10 +404,12 @@ class TosaTensorGen:
def tgConcatConstInput(testGen, shapeList, axis):
# Split concat shape along axis to allow for multiple const inputs
# without making too many large tensors
- shape = shapeList[0]
- if len(shapeList) == 2 or shape[axis] < len(shapeList):
+ if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
return shapeList
+ # Create copy of shape we are going to split (so we don't alter shapeList)
+ shape = shapeList[0].copy()
+ # Add original shape as first input
new_shapeList = [shape.copy()]
length_on_axis = shape[axis]
remaining_length = length_on_axis