diff options
author | Jerry Ge <jerry.ge@arm.com> | 2023-05-23 20:59:32 +0000 |
---|---|---|
committer | Dominic Symes <dominic.symes@arm.com> | 2023-06-15 18:25:54 +0000 |
commit | 135c95544fda260e8ce622cff7835b886a97663f (patch) | |
tree | 5d46f8f48978112abff037309a827b5844ee80de | |
parent | cb7201e173961760c042cade591afe763c949c8f (diff) | |
download | reference_model-135c95544fda260e8ce622cff7835b886a97663f.tar.gz |
Add ERROR_IF to incorrect broadcast shapes
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I7460ad9eed3ed5c7cec6e855a0303753ed28eb1c
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 19 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_binary.h | 2 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_ternary.cc | 26 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_ternary.h | 2 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 13 | ||||
-rw-r--r-- | verif/generator/tosa_error_if.py | 72 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 32 |
7 files changed, 145 insertions, 21 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 1e873e7..2bc894d 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -85,25 +85,40 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() } template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> -int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast() +int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast(std::vector<int>& calculated_shape) { const std::vector<int>& a_shape = a->getShape(); const std::vector<int>& b_shape = b->getShape(); const std::vector<int>& output_shape = result->getShape(); + // calculates the multipliers for Eigen for (int i = 0; i < Rank; i++) { bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1; bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1; } + // calculates the broadcasted output shape + calculated_shape = a_shape; + for (size_t i = 0; i < calculated_shape.size(); i++) { + if (calculated_shape[i] == 1) { + calculated_shape[i] = b_shape[i]; + } else { + ERROR_IF(b_shape[i] != 1 && b_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + } + return 0; } template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int BinaryNode<Rank, InDtype, OutDtype>::eval() { - this->broadcast(); + std::vector<int> calculated_shape; + this->broadcast(calculated_shape); + + auto result_shape = this->result->getShape(); + ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match"); Eigen::array<int, Rank> reshaper; reshaper.fill(1); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 5f6e531..3a6f24c 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -55,7 +55,7 @@ public: using TOut = Eigen::Tensor<OutEigenType, Rank>; protected: - int broadcast(); + int broadcast(std::vector<int>& calculated_shape); protected: std::function<OutEigenType(InEigenType, InEigenType)> fcn; diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 16554b5..fd2510f 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -66,13 +66,14 @@ int OpSelectBase<Rank, Dtype>::eval() } template <int Rank, TOSA_REF_TYPE Dtype> -int OpSelect<Rank, Dtype>::broadcast() +int OpSelect<Rank, Dtype>::broadcast(std::vector<int>& calculated_shape) { const std::vector<int>& cond_shape = this->cond->getShape(); const std::vector<int>& then_shape = this->then_val->getShape(); const std::vector<int>& else_shape = this->else_val->getShape(); const std::vector<int>& output_shape = this->out->getShape(); + // calculates the multipliers for Eigen for (int i = 0; i < Rank; i++) { this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1; @@ -80,13 +81,34 @@ int OpSelect<Rank, Dtype>::broadcast() this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1; } + // calculates the broadcasted output shape + calculated_shape = cond_shape; + for (size_t i = 0; i < calculated_shape.size(); i++) { + if (calculated_shape[i] == 1) { + calculated_shape[i] = then_shape[i]; + } else { + ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + + if (calculated_shape[i] == 1) { + calculated_shape[i] = else_shape[i]; + } else { + ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + } + return 0; } template <int Rank, TOSA_REF_TYPE Dtype> int OpSelect<Rank, Dtype>::eval() { - this->broadcast(); + std::vector<int> calculated_shape; + this->broadcast(calculated_shape); + + auto result_shape = this->out->getShape(); + ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match"); + this->out->getTensor() = this->cond->getTensor() .broadcast(this->bcast_cond) .select(this->then_val->getTensor().broadcast(this->bcast_then), diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h index c6970cb..f24dfbe 100644 --- a/reference_model/src/ops/ewise_ternary.h +++ b/reference_model/src/ops/ewise_ternary.h @@ -63,7 +63,7 @@ public: : OpSelectBase<Rank, Dtype>(sgt_, attribute_, id_) {} virtual int eval(); - int broadcast(); + int broadcast(std::vector<int>& calculated_shape); using InEigenType = typename OpSelectBase<Rank, Dtype>::InEigenType; }; diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 9386ec2..97ff237 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -246,15 +246,18 @@ class TosaTensorGen: # Choose one of the inputs to broadcast # Note: Simplifies OutputShaper code if we don't change first shape for errors bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const) + fuzz_idx = testGen.randInt(0, rank) + for i in range(pl + const): shape_bcast = shape.copy() + # To test broadcasting, the chosen fuzz index dimension should not be 1 + if shape_bcast[fuzz_idx] == 1: + shape_bcast[fuzz_idx] += 1 + # If the chosen input, pick a random index to broadcast if i == bcast_idx: - fuzz_idx = testGen.randInt(0, rank) - if error_name == ErrorIf.DimensionMismatch: - shape_bcast[fuzz_idx] += 1 - elif error_name == ErrorIf.RankMismatch: + if error_name == ErrorIf.RankMismatch: # Add one rank to the shape (or more for rank of 1) extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1 shape_bcast = np.concatenate( @@ -264,6 +267,8 @@ class TosaTensorGen: # Either keep the extra rank, or remove it new_len = testGen.rng.choice([-2, len(shape_bcast)]) shape_bcast = shape_bcast[:new_len] + elif error_name == ErrorIf.BroadcastShapesMismatch: + shape_bcast[fuzz_idx] += 2 else: shape_bcast[fuzz_idx] = 1 diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index a0a9203..d490cf2 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -83,6 +83,7 @@ class ErrorIf(object): FFTOutputShapeMismatch = "FFTOutputShapeMismatch" ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference" ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger" + BroadcastShapesMismatch = "BroadcastShapesMismatch" class TosaErrorIfArgGen: @@ -1109,17 +1110,19 @@ class TosaErrorValidator: kwargs["input3"].shape if "input3" in kwargs else input2_shape ) - for output in kwargs["result_tensors"]: - output_shape = output.shape - for i in range( - min(len(input1_shape), len(input2_shape), len(input3_shape)) - ): - if ( - (input1_shape[i] != 1 and input1_shape[i] != output_shape[i]) - or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i]) - or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i]) - ): - error_result = True + if len(input1_shape) == len(input2_shape) == len(input3_shape): + calculated_shape = TosaErrorValidator.calculateBroadcastShape( + input3_shape, + TosaErrorValidator.calculateBroadcastShape( + input1_shape, input2_shape + ), + ) + if calculated_shape is not None: + # Valid inputs - check for output mismatch + for output in kwargs["result_tensors"]: + output_shape = output.shape + if calculated_shape != output_shape: + error_result = True info_dict = { "error_name": error_name, @@ -2566,6 +2569,53 @@ class TosaErrorValidator: } return info_dict + @staticmethod + def calculateBroadcastShape(input_shape_a, input_shape_b): + if input_shape_a is not None and input_shape_b is not None: + calculated_shape = input_shape_a.copy() + for idx in range(len(calculated_shape)): + if calculated_shape[idx] == 1: + calculated_shape[idx] = input_shape_b[idx] + elif ( + input_shape_b[idx] != 1 + and input_shape_b[idx] != calculated_shape[idx] + ): + return None + return calculated_shape + else: + return None + + @staticmethod + def evBroadcastShapesMismatch(check=False, **kwargs): + error_name = ErrorIf.BroadcastShapesMismatch + param_reqs = {"rank": None, "dtype": None, "shape": None} + error_result = False + error_reason = "Broadcast shape calculating failed" + + if check: + input_shape_a = kwargs["input1"].shape + input_shape_b = kwargs["input2"].shape + input_shape_c = ( + kwargs["input3"].shape if "input3" in kwargs else input_shape_b + ) + + if len(input_shape_a) == len(input_shape_b) == len(input_shape_c): + calculated_shape = TosaErrorValidator.calculateBroadcastShape( + input_shape_c, + TosaErrorValidator.calculateBroadcastShape( + input_shape_a, input_shape_b + ), + ) + error_result = calculated_shape is None + + info_dict = { + "error_name": error_name, + "error_result": error_result, + "error_reason": error_reason, + "param_reqs": param_reqs, + } + return info_dict + class TosaInvalidValidator: @staticmethod diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7691fdd..66084b4 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -2925,6 +2925,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "arithmetic_right_shift": { @@ -2944,6 +2945,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_and": { @@ -2963,6 +2965,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_or": { @@ -2982,6 +2985,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "bitwise_xor": { @@ -3001,6 +3005,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "intdiv": { @@ -3020,6 +3025,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_and": { @@ -3039,6 +3045,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_left_shift": { @@ -3058,6 +3065,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_right_shift": { @@ -3077,6 +3085,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_or": { @@ -3096,6 +3105,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "logical_xor": { @@ -3115,6 +3125,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "maximum": { @@ -3134,6 +3145,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "minimum": { @@ -3153,6 +3165,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "mul": { @@ -3172,6 +3185,7 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evRankMismatch, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "pow": { @@ -3191,6 +3205,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "sub": { @@ -3210,6 +3225,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "table": { @@ -3441,6 +3457,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, # Comparison operators @@ -3461,6 +3478,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "greater_equal": { @@ -3480,6 +3498,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, "greater": { @@ -3499,6 +3518,7 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evDimensionMismatch, + TosaErrorValidator.evBroadcastShapesMismatch, ), }, # Reduction operators @@ -4078,6 +4098,10 @@ class OutputShaper: else: shape.append(a.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4139,6 +4163,10 @@ class OutputShaper: else: shape.append(cond.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4170,6 +4198,10 @@ class OutputShaper: else: shape.append(a.shape[i]) + fuzz_idx = rng.integers(0, len(a.shape)) + if error_name == ErrorIf.DimensionMismatch: + shape[fuzz_idx] += 1 + if error_name == ErrorIf.WrongOutputType: wrong_dtypes = [ DType.INT8, |