diff options
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 142 | ||||
-rw-r--r-- | reference_model/src/ops/data_layout.h | 4 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 42 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 1 | ||||
-rw-r--r-- | verif/generator/tosa_error_if.py | 16 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 48 | ||||
-rw-r--r-- | verif/generator/tosa_utils.py | 11 |
7 files changed, 186 insertions, 78 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 909c567..ce5b5af 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -127,7 +127,7 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_PAD, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 6); INIT_ATTRIBUTE(Pad); } @@ -374,7 +374,7 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_SLICE, id_) { setRequiredOperands(1, 1); - setRequiredRank(1, 4); + setRequiredRank(1, 6); INIT_ATTRIBUTE(Slice); } @@ -398,15 +398,17 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes() } // output and input must be the same types - if (inputs[0]->matchType(*outputs[0])) + if (inputs[0]->matchRankType(*outputs[0])) { - printNodeValidationError("Failure to match input and output type"); + printNodeValidationError("Failure to match input and output rank or type"); return 1; } in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + ASSERT_MEM(in && out); + ERROR_IF((int32_t)attribute->start().size() != in->getRank(), "OpSlice: begin array length needs to be rank(input)"); ERROR_IF((int32_t)attribute->size().size() != in->getRank(), "OpSlice: size array length needs to be rank(input)"); @@ -441,7 +443,7 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_TILE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 6); INIT_ATTRIBUTE(Tile); } @@ -474,6 +476,8 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes() in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + ASSERT_MEM(in && out); + if (attribute->multiples().size() != Rank) { printNodeValidationError("1D list 'multiples' must have size equal to input rank"); @@ -568,6 +572,68 @@ int OpTile<4, Dtype>::eval() return GraphNode::eval(); } +template <DType Dtype> +int OpTile<5, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++) + { + int32_t id2 = od2 % this->in->getShape()[2]; + for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++) + { + int32_t id3 = od3 % this->in->getShape()[3]; + for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++) + { + int32_t id4 = od4 % this->in->getShape()[4]; + this->out->getTensor()(od0, od1, od2, od3, od4) = + this->in->getTensor()(id0, id1, id2, id3, id4); + } + } + } + } + } + + return GraphNode::eval(); +} + +template <DType Dtype> +int OpTile<6, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++) + { + int32_t id2 = od2 % this->in->getShape()[2]; + for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++) + { + int32_t id3 = od3 % this->in->getShape()[3]; + for (int32_t od4 = 0; od4 < this->out->getShape()[4]; od4++) + { + int32_t id4 = od4 % this->in->getShape()[4]; + for (int32_t od5 = 0; od5 < this->out->getShape()[5]; od5++) + { + int32_t id5 = od5 % this->in->getShape()[5]; + this->out->getTensor()(od0, od1, od2, od3, od4, od5) = + this->in->getTensor()(id0, id1, id2, id3, id4, id5); + } + } + } + } + } + } + + return GraphNode::eval(); +} + template <int Rank, DType Dtype> OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, @@ -575,7 +641,7 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, : GraphNode(sgt_, Op_TRANSPOSE, id_) { setRequiredOperands(1, 1); - setRequiredRank(0, 6); + setRequiredRank(1, 6); INIT_ATTRIBUTE(Transpose); } @@ -673,34 +739,34 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); - -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); - -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); - -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); -DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index c6513ae..3a6cb0d 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -186,6 +186,8 @@ DEF_OP_TILE_RANK(1) DEF_OP_TILE_RANK(2) DEF_OP_TILE_RANK(3) DEF_OP_TILE_RANK(4) +DEF_OP_TILE_RANK(5) +DEF_OP_TILE_RANK(6) #undef DEF_OP_TILE_RANK diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 8d84135..1db3974 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -394,31 +394,31 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); break; case Op_SLICE: - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); break; case Op_TILE: - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); break; case Op_TRANSPOSE: - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); - DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); break; // scatter_gather diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 75ca634..9209d9c 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -176,6 +176,7 @@ class TosaTensorGen: for i in range(pl + const): shape_list.append(shape.copy()) + # Generates an input rank mismatch for operators with more than one input if error_name == ErrorIf.RankMismatch: if rank == 1 and i != 1: shape = testGen.makeShape(rank + testGen.rng.choice([1, 2, 3])) diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index ee227b3..b19d5e9 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -1067,7 +1067,9 @@ class TosaErrorValidator: if check: input1_shape = kwargs["input1"].shape - input2_shape = kwargs["input2"].shape + input2_shape = ( + kwargs["input2"].shape if "input2" in kwargs else input1_shape + ) # In case of SELECT op input3_shape = ( kwargs["input3"].shape if "input3" in kwargs else input2_shape @@ -1921,11 +1923,13 @@ class TosaErrorValidator: input_shape = kwargs["input_shape"] output_shape = kwargs["output_shape"] size = kwargs["size"] - rank = len(input_shape) - if len(size) == rank: - for index in range(rank): - if size[index] != output_shape[index]: - error_result = True + + if len(input_shape) == len(output_shape): + rank = len(input_shape) + if len(size) == rank: + for index in range(rank): + if size[index] != output_shape[index]: + error_result = True info_dict = { "error_name": error_name, diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index a768da0..7fef942 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -14,6 +14,7 @@ from generator.tosa_error_if import TosaErrorIfArgGen from generator.tosa_error_if import TosaErrorValidator from generator.tosa_error_if import TosaInvalidValidator from generator.tosa_utils import DTYPE_ATTRIBUTES +from generator.tosa_utils import get_rank_mismatch_shape from generator.tosa_utils import get_wrong_output_type from generator.tosa_utils import MAX_RESIZE_DIMENSION from generator.tosa_utils import usableDTypes @@ -1263,6 +1264,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1369,6 +1371,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1404,6 +1407,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -1438,6 +1442,7 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + input1=a, ): return None @@ -3657,6 +3662,8 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongRank, ), }, "reshape": { @@ -3699,7 +3706,7 @@ class TosaTestGen: "slice": { "op": Op.SLICE, "operands": (1, 0), - "rank": (1, 4), + "rank": (1, 6), "build_fcn": ( build_slice, TosaTensorGen.tgBasic, @@ -3718,11 +3725,13 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, ), }, "tile": { "op": Op.TILE, "operands": (1, 0), + "rank": (1, 6), "build_fcn": ( build_tile, TosaTensorGen.tgBasic, @@ -3735,12 +3744,14 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evWrongRank, ), }, "transpose": { "op": Op.TRANSPOSE, "operands": (1, 0), - "rank": (1, 4), + "rank": (1, 6), "build_fcn": ( build_transpose, TosaTensorGen.tgBasic, @@ -3755,6 +3766,9 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evRankMismatch, + TosaErrorValidator.evTensorSizeInputOutputMismatch, ), }, # Data nodes @@ -4539,6 +4553,8 @@ class OutputShaper: if error_name == ErrorIf.PadOutputShapeMismatch: bad_dim = rng.choice(range(len(output_shape))) output_shape[bad_dim] -= rng.choice([1, 2]) + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4583,7 +4599,7 @@ class OutputShaper: return ser.addOutput(output_shape, outputDType) @staticmethod - def sliceOp(ser, rng, a, start, size, error_name=None): + def sliceOp(ser, rng, input, start, size, error_name=None): if error_name == ErrorIf.WrongOutputType: all_dtypes = [ @@ -4595,13 +4611,13 @@ class OutputShaper: DType.FP16, DType.BF16, ] - wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) + wrong_dtypes = list(set(all_dtypes) - set([input.dtype])) outputDType = rng.choice(wrong_dtypes) else: - outputDType = a.dtype + outputDType = input.dtype + output_shape = size.copy() if error_name == ErrorIf.SizeOutputShapeMismatch: - output_shape = size.copy() for index in range(len(output_shape)): if output_shape[index] <= 2: output_shape[index] = output_shape[index] + rng.choice([1, 2]) @@ -4609,8 +4625,10 @@ class OutputShaper: output_shape[index] = output_shape[index] + rng.choice( [-2, -1, 1, 2] ) - else: - output_shape = size.copy() + elif error_name == ErrorIf.InputSizeStartLengthMismatch: + output_shape = input.shape.copy() + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) return ser.addOutput(output_shape, outputDType) @@ -4623,6 +4641,9 @@ class OutputShaper: for i in range(len(output_shape)): output_shape[i] = a.shape[i] * multiples[i] + if error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4646,13 +4667,16 @@ class OutputShaper: assert len(perms) == len(output_shape) - if error_name == ErrorIf.IndexOutsideBounds: - for i in range(len(output_shape)): - output_shape[i] = a.shape[0] - else: + if error_name not in [ErrorIf.IndexOutsideBounds, ErrorIf.IndexUsedTwice]: for i in range(len(output_shape)): output_shape[i] = a.shape[perms[i]] + if error_name == ErrorIf.TensorSizeInputOutputMismatch: + for i in range(len(output_shape)): + output_shape[i] += rng.integers(1, 10) + elif error_name == ErrorIf.RankMismatch: + output_shape = get_rank_mismatch_shape(rng, output_shape) + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 29ae898..8ff62f1 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -148,6 +148,17 @@ def get_wrong_output_type(op_name, rng, input_dtype): return rng.choice(a=incorrect_types) +def get_rank_mismatch_shape(rng, output_shape): + """ + Extends the rank of the provided output_shape by + an arbitrary amount but ensures the total element + count remains the same. + """ + rank_modifier = rng.choice([1, 2, 3]) + output_shape += [1] * rank_modifier + return output_shape + + def float32_is_valid_bfloat16(f): """Return True if float value is valid bfloat16.""" f32_bits = get_float32_bitstring(f) |