aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-02-22 11:53:48 +0000
committerEric Kunze <eric.kunze@arm.com>2023-02-28 20:08:57 +0000
commita4e48ca7b032992ca0110900935c08d7cf860cd3 (patch)
treea58c8617390225ecc107721d9b5ff87c2bdb01b0
parent2226f90d5a6c48a975045bc9e0419113ce764aaf (diff)
downloadreference_model-a4e48ca7b032992ca0110900935c08d7cf860cd3.tar.gz
Update rank limits for SLICE, TILE and TRANSPOSE
Updated to align with corresponding changes to the spec. In addition, some ERROR_IF tests have been updated to match the checks specified by the spec, including: PAD, SLICE, TILE, TRANSPOSE. Signed-off-by: Luke Hutton <luke.hutton@arm.com> Change-Id: Ie2c5f48e79a5610eb82739170e25057a63dac1d8
-rw-r--r--reference_model/src/ops/data_layout.cc142
-rw-r--r--reference_model/src/ops/data_layout.h4
-rw-r--r--reference_model/src/ops/op_factory.cc42
-rw-r--r--verif/generator/tosa_arg_gen.py1
-rw-r--r--verif/generator/tosa_error_if.py16
-rw-r--r--verif/generator/tosa_test_gen.py48
-rw-r--r--verif/generator/tosa_utils.py11
7 files changed, 186 insertions, 78 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 909c567..ce5b5af 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -127,7 +127,7 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_PAD, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Pad);
}
@@ -374,7 +374,7 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_SLICE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(1, 4);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Slice);
}
@@ -398,15 +398,17 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
}
// output and input must be the same types
- if (inputs[0]->matchType(*outputs[0]))
+ if (inputs[0]->matchRankType(*outputs[0]))
{
- printNodeValidationError("Failure to match input and output type");
+ printNodeValidationError("Failure to match input and output rank or type");
return 1;
}
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ ASSERT_MEM(in && out);
+
ERROR_IF((int32_t)attribute->start().size() != in->getRank(),
"OpSlice: begin array length needs to be rank(input)");
ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)");
@@ -441,7 +443,7 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_TILE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Tile);
}
@@ -474,6 +476,8 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ ASSERT_MEM(in && out);
+
if (attribute->multiples().size() != Rank)
{
printNodeValidationError("1D list 'multiples' must have size equal to input rank");
@@ -568,6 +572,68 @@ int OpTile<4, Dtype>::eval()
return GraphNode::eval();
}
+template <DType Dtype>
+int OpTile<5, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
+ {
+ int32_t id4 = od4 % this->in->getShape()[4];
+ this->out->getTensor()(od0, od1, od2, od3, od4) =
+ this->in->getTensor()(id0, id1, id2, id3, id4);
+ }
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<6, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++)
+ {
+ int32_t id4 = od4 % this->in->getShape()[4];
+ for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++)
+ {
+ int32_t id5 = od5 % this->in->getShape()[5];
+ this->out->getTensor()(od0, od1, od2, od3, od4, od5) =
+ this->in->getTensor()(id0, id1, id2, id3, id4, id5);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
template <int Rank, DType Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
@@ -575,7 +641,7 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
: GraphNode(sgt_, Op_TRANSPOSE, id_)
{
setRequiredOperands(1, 1);
- setRequiredRank(0, 6);
+ setRequiredRank(1, 6);
INIT_ATTRIBUTE(Transpose);
}
@@ -673,34 +739,34 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
-
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index c6513ae..3a6cb0d 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -186,6 +186,8 @@ DEF_OP_TILE_RANK(1)
DEF_OP_TILE_RANK(2)
DEF_OP_TILE_RANK(3)
DEF_OP_TILE_RANK(4)
+DEF_OP_TILE_RANK(5)
+DEF_OP_TILE_RANK(6)
#undef DEF_OP_TILE_RANK
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 8d84135..1db3974 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -394,31 +394,31 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
break;
case Op_SLICE:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
break;
case Op_TILE:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
break;
case Op_TRANSPOSE:
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
break;
// scatter_gather
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 75ca634..9209d9c 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -176,6 +176,7 @@ class TosaTensorGen:
for i in range(pl + const):
shape_list.append(shape.copy())
+ # Generates an input rank mismatch for operators with more than one input
if error_name == ErrorIf.RankMismatch:
if rank == 1 and i != 1:
shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3]))
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index ee227b3..b19d5e9 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1067,7 +1067,9 @@ class TosaErrorValidator:
if check:
input1_shape = kwargs["input1"].shape
- input2_shape = kwargs["input2"].shape
+ input2_shape = (
+ kwargs["input2"].shape if "input2" in kwargs else input1_shape
+ )
# In case of SELECT op
input3_shape = (
kwargs["input3"].shape if "input3" in kwargs else input2_shape
@@ -1921,11 +1923,13 @@ class TosaErrorValidator:
input_shape = kwargs["input_shape"]
output_shape = kwargs["output_shape"]
size = kwargs["size"]
- rank = len(input_shape)
- if len(size) == rank:
- for index in range(rank):
- if size[index] != output_shape[index]:
- error_result = True
+
+ if len(input_shape) == len(output_shape):
+ rank = len(input_shape)
+ if len(size) == rank:
+ for index in range(rank):
+ if size[index] != output_shape[index]:
+ error_result = True
info_dict = {
"error_name": error_name,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index a768da0..7fef942 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
+from generator.tosa_utils import get_rank_mismatch_shape
from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
@@ -1263,6 +1264,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1369,6 +1371,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1404,6 +1407,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -1438,6 +1442,7 @@ class TosaTestGen:
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
+ input1=a,
):
return None
@@ -3657,6 +3662,8 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongRank,
),
},
"reshape": {
@@ -3699,7 +3706,7 @@ class TosaTestGen:
"slice": {
"op": Op.SLICE,
"operands": (1, 0),
- "rank": (1, 4),
+ "rank": (1, 6),
"build_fcn": (
build_slice,
TosaTensorGen.tgBasic,
@@ -3718,11 +3725,13 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
),
},
"tile": {
"op": Op.TILE,
"operands": (1, 0),
+ "rank": (1, 6),
"build_fcn": (
build_tile,
TosaTensorGen.tgBasic,
@@ -3735,12 +3744,14 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evWrongRank,
),
},
"transpose": {
"op": Op.TRANSPOSE,
"operands": (1, 0),
- "rank": (1, 4),
+ "rank": (1, 6),
"build_fcn": (
build_transpose,
TosaTensorGen.tgBasic,
@@ -3755,6 +3766,9 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputType,
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evRankMismatch,
+ TosaErrorValidator.evTensorSizeInputOutputMismatch,
),
},
# Data nodes
@@ -4539,6 +4553,8 @@ class OutputShaper:
if error_name == ErrorIf.PadOutputShapeMismatch:
bad_dim = rng.choice(range(len(output_shape)))
output_shape[bad_dim] -= rng.choice([1, 2])
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4583,7 +4599,7 @@ class OutputShaper:
return ser.addOutput(output_shape, outputDType)
@staticmethod
- def sliceOp(ser, rng, a, start, size, error_name=None):
+ def sliceOp(ser, rng, input, start, size, error_name=None):
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
@@ -4595,13 +4611,13 @@ class OutputShaper:
DType.FP16,
DType.BF16,
]
- wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
+ wrong_dtypes = list(set(all_dtypes) - set([input.dtype]))
outputDType = rng.choice(wrong_dtypes)
else:
- outputDType = a.dtype
+ outputDType = input.dtype
+ output_shape = size.copy()
if error_name == ErrorIf.SizeOutputShapeMismatch:
- output_shape = size.copy()
for index in range(len(output_shape)):
if output_shape[index] <= 2:
output_shape[index] = output_shape[index] + rng.choice([1, 2])
@@ -4609,8 +4625,10 @@ class OutputShaper:
output_shape[index] = output_shape[index] + rng.choice(
[-2, -1, 1, 2]
)
- else:
- output_shape = size.copy()
+ elif error_name == ErrorIf.InputSizeStartLengthMismatch:
+ output_shape = input.shape.copy()
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
return ser.addOutput(output_shape, outputDType)
@@ -4623,6 +4641,9 @@ class OutputShaper:
for i in range(len(output_shape)):
output_shape[i] = a.shape[i] * multiples[i]
+ if error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4646,13 +4667,16 @@ class OutputShaper:
assert len(perms) == len(output_shape)
- if error_name == ErrorIf.IndexOutsideBounds:
- for i in range(len(output_shape)):
- output_shape[i] = a.shape[0]
- else:
+ if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]:
for i in range(len(output_shape)):
output_shape[i] = a.shape[perms[i]]
+ if error_name == ErrorIf.TensorSizeInputOutputMismatch:
+ for i in range(len(output_shape)):
+ output_shape[i] += rng.integers(1, 10)
+ elif error_name == ErrorIf.RankMismatch:
+ output_shape = get_rank_mismatch_shape(rng, output_shape)
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 29ae898..8ff62f1 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -148,6 +148,17 @@ def get_wrong_output_type(op_name, rng, input_dtype):
return rng.choice(a=incorrect_types)
+def get_rank_mismatch_shape(rng, output_shape):
+ """
+ Extends the rank of the provided output_shape by
+ an arbitrary amount but ensures the total element
+ count remains the same.
+ """
+ rank_modifier = rng.choice([1, 2, 3])
+ output_shape += [1] * rank_modifier
+ return output_shape
+
+
def float32_is_valid_bfloat16(f):
"""Return True if float value is valid bfloat16."""
f32_bits = get_float32_bitstring(f)