aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2023-05-23 20:59:32 +0000
committerDominic Symes <dominic.symes@arm.com>2023-06-15 18:25:54 +0000
commit135c95544fda260e8ce622cff7835b886a97663f (patch)
tree5d46f8f48978112abff037309a827b5844ee80de
parentcb7201e173961760c042cade591afe763c949c8f (diff)
downloadreference_model-135c95544fda260e8ce622cff7835b886a97663f.tar.gz
Add ERROR_IF to incorrect broadcast shapes
Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I7460ad9eed3ed5c7cec6e855a0303753ed28eb1c
-rw-r--r--reference_model/src/ops/ewise_binary.cc19
-rw-r--r--reference_model/src/ops/ewise_binary.h2
-rw-r--r--reference_model/src/ops/ewise_ternary.cc26
-rw-r--r--reference_model/src/ops/ewise_ternary.h2
-rw-r--r--verif/generator/tosa_arg_gen.py13
-rw-r--r--verif/generator/tosa_error_if.py72
-rw-r--r--verif/generator/tosa_test_gen.py32
7 files changed, 145 insertions, 21 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 1e873e7..2bc894d 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -85,25 +85,40 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
}
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
-int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
+int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast(std::vector<int>& calculated_shape)
{
const std::vector<int>& a_shape = a->getShape();
const std::vector<int>& b_shape = b->getShape();
const std::vector<int>& output_shape = result->getShape();
+ // calculates the multipliers for Eigen
for (int i = 0; i < Rank; i++)
{
bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
}
+ // calculates the broadcasted output shape
+ calculated_shape = a_shape;
+ for (size_t i = 0; i < calculated_shape.size(); i++) {
+ if (calculated_shape[i] == 1) {
+ calculated_shape[i] = b_shape[i];
+ } else {
+ ERROR_IF(b_shape[i] != 1 && b_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible");
+ }
+ }
+
return 0;
}
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNode<Rank, InDtype, OutDtype>::eval()
{
- this->broadcast();
+ std::vector<int> calculated_shape;
+ this->broadcast(calculated_shape);
+
+ auto result_shape = this->result->getShape();
+ ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match");
Eigen::array<int, Rank> reshaper;
reshaper.fill(1);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 5f6e531..3a6f24c 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -55,7 +55,7 @@ public:
using TOut = Eigen::Tensor<OutEigenType, Rank>;
protected:
- int broadcast();
+ int broadcast(std::vector<int>& calculated_shape);
protected:
std::function<OutEigenType(InEigenType, InEigenType)> fcn;
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 16554b5..fd2510f 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -66,13 +66,14 @@ int OpSelectBase<Rank, Dtype>::eval()
}
template <int Rank, TOSA_REF_TYPE Dtype>
-int OpSelect<Rank, Dtype>::broadcast()
+int OpSelect<Rank, Dtype>::broadcast(std::vector<int>& calculated_shape)
{
const std::vector<int>& cond_shape = this->cond->getShape();
const std::vector<int>& then_shape = this->then_val->getShape();
const std::vector<int>& else_shape = this->else_val->getShape();
const std::vector<int>& output_shape = this->out->getShape();
+ // calculates the multipliers for Eigen
for (int i = 0; i < Rank; i++)
{
this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1;
@@ -80,13 +81,34 @@ int OpSelect<Rank, Dtype>::broadcast()
this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1;
}
+ // calculates the broadcasted output shape
+ calculated_shape = cond_shape;
+ for (size_t i = 0; i < calculated_shape.size(); i++) {
+ if (calculated_shape[i] == 1) {
+ calculated_shape[i] = then_shape[i];
+ } else {
+ ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible");
+ }
+
+ if (calculated_shape[i] == 1) {
+ calculated_shape[i] = else_shape[i];
+ } else {
+ ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible");
+ }
+ }
+
return 0;
}
template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelect<Rank, Dtype>::eval()
{
- this->broadcast();
+ std::vector<int> calculated_shape;
+ this->broadcast(calculated_shape);
+
+ auto result_shape = this->out->getShape();
+ ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match");
+
this->out->getTensor() = this->cond->getTensor()
.broadcast(this->bcast_cond)
.select(this->then_val->getTensor().broadcast(this->bcast_then),
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
index c6970cb..f24dfbe 100644
--- a/reference_model/src/ops/ewise_ternary.h
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -63,7 +63,7 @@ public:
: OpSelectBase<Rank, Dtype>(sgt_, attribute_, id_)
{}
virtual int eval();
- int broadcast();
+ int broadcast(std::vector<int>& calculated_shape);
using InEigenType = typename OpSelectBase<Rank, Dtype>::InEigenType;
};
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 9386ec2..97ff237 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -246,15 +246,18 @@ class TosaTensorGen:
# Choose one of the inputs to broadcast
# Note: Simplifies OutputShaper code if we don't change first shape for errors
bcast_idx = testGen.randInt(0 if error_name is None else 1, pl + const)
+ fuzz_idx = testGen.randInt(0, rank)
+
for i in range(pl + const):
shape_bcast = shape.copy()
+ # To test broadcasting, the chosen fuzz index dimension should not be 1
+ if shape_bcast[fuzz_idx] == 1:
+ shape_bcast[fuzz_idx] += 1
+
# If the chosen input, pick a random index to broadcast
if i == bcast_idx:
- fuzz_idx = testGen.randInt(0, rank)
- if error_name == ErrorIf.DimensionMismatch:
- shape_bcast[fuzz_idx] += 1
- elif error_name == ErrorIf.RankMismatch:
+ if error_name == ErrorIf.RankMismatch:
# Add one rank to the shape (or more for rank of 1)
extra_ranks = testGen.rng.choice([1, 2, 3]) if rank == 1 else 1
shape_bcast = np.concatenate(
@@ -264,6 +267,8 @@ class TosaTensorGen:
# Either keep the extra rank, or remove it
new_len = testGen.rng.choice([-2, len(shape_bcast)])
shape_bcast = shape_bcast[:new_len]
+ elif error_name == ErrorIf.BroadcastShapesMismatch:
+ shape_bcast[fuzz_idx] += 2
else:
shape_bcast[fuzz_idx] = 1
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a0a9203..d490cf2 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -83,6 +83,7 @@ class ErrorIf(object):
FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
ReshapeOutputSizeMultiInference = "ReshapeOutputSizeMultiInference"
ReshapeOutputSizeNonInteger = "ReshapeOutputSizeNonInteger"
+ BroadcastShapesMismatch = "BroadcastShapesMismatch"
class TosaErrorIfArgGen:
@@ -1109,17 +1110,19 @@ class TosaErrorValidator:
kwargs["input3"].shape if "input3" in kwargs else input2_shape
)
- for output in kwargs["result_tensors"]:
- output_shape = output.shape
- for i in range(
- min(len(input1_shape), len(input2_shape), len(input3_shape))
- ):
- if (
- (input1_shape[i] != 1 and input1_shape[i] != output_shape[i])
- or (input2_shape[i] != 1 and input2_shape[i] != output_shape[i])
- or (input3_shape[i] != 1 and input3_shape[i] != output_shape[i])
- ):
- error_result = True
+ if len(input1_shape) == len(input2_shape) == len(input3_shape):
+ calculated_shape = TosaErrorValidator.calculateBroadcastShape(
+ input3_shape,
+ TosaErrorValidator.calculateBroadcastShape(
+ input1_shape, input2_shape
+ ),
+ )
+ if calculated_shape is not None:
+ # Valid inputs - check for output mismatch
+ for output in kwargs["result_tensors"]:
+ output_shape = output.shape
+ if calculated_shape != output_shape:
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -2566,6 +2569,53 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def calculateBroadcastShape(input_shape_a, input_shape_b):
+ if input_shape_a is not None and input_shape_b is not None:
+ calculated_shape = input_shape_a.copy()
+ for idx in range(len(calculated_shape)):
+ if calculated_shape[idx] == 1:
+ calculated_shape[idx] = input_shape_b[idx]
+ elif (
+ input_shape_b[idx] != 1
+ and input_shape_b[idx] != calculated_shape[idx]
+ ):
+ return None
+ return calculated_shape
+ else:
+ return None
+
+ @staticmethod
+ def evBroadcastShapesMismatch(check=False, **kwargs):
+ error_name = ErrorIf.BroadcastShapesMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Broadcast shape calculating failed"
+
+ if check:
+ input_shape_a = kwargs["input1"].shape
+ input_shape_b = kwargs["input2"].shape
+ input_shape_c = (
+ kwargs["input3"].shape if "input3" in kwargs else input_shape_b
+ )
+
+ if len(input_shape_a) == len(input_shape_b) == len(input_shape_c):
+ calculated_shape = TosaErrorValidator.calculateBroadcastShape(
+ input_shape_c,
+ TosaErrorValidator.calculateBroadcastShape(
+ input_shape_a, input_shape_b
+ ),
+ )
+ error_result = calculated_shape is None
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7691fdd..66084b4 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -2925,6 +2925,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"arithmetic_right_shift": {
@@ -2944,6 +2945,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_and": {
@@ -2963,6 +2965,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_or": {
@@ -2982,6 +2985,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"bitwise_xor": {
@@ -3001,6 +3005,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"intdiv": {
@@ -3020,6 +3025,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_and": {
@@ -3039,6 +3045,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_left_shift": {
@@ -3058,6 +3065,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_right_shift": {
@@ -3077,6 +3085,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_or": {
@@ -3096,6 +3105,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"logical_xor": {
@@ -3115,6 +3125,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"maximum": {
@@ -3134,6 +3145,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"minimum": {
@@ -3153,6 +3165,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"mul": {
@@ -3172,6 +3185,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"pow": {
@@ -3191,6 +3205,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"sub": {
@@ -3210,6 +3225,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"table": {
@@ -3441,6 +3457,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
# Comparison operators
@@ -3461,6 +3478,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"greater_equal": {
@@ -3480,6 +3498,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
"greater": {
@@ -3499,6 +3518,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList,
TosaErrorValidator.evDimensionMismatch,
+ TosaErrorValidator.evBroadcastShapesMismatch,
),
},
# Reduction operators
@@ -4078,6 +4098,10 @@ class OutputShaper:
else:
shape.append(a.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4139,6 +4163,10 @@ class OutputShaper:
else:
shape.append(cond.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [
DType.INT8,
@@ -4170,6 +4198,10 @@ class OutputShaper:
else:
shape.append(a.shape[i])
+ fuzz_idx = rng.integers(0, len(a.shape))
+ if error_name == ErrorIf.DimensionMismatch:
+ shape[fuzz_idx] += 1
+
if error_name == ErrorIf.WrongOutputType:
wrong_dtypes = [
DType.INT8,