aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/src/ops/ewise_binary.cc19
-rw-r--r--reference_model/src/ops/ewise_binary.h2
-rw-r--r--reference_model/src/ops/ewise_ternary.cc26
-rw-r--r--reference_model/src/ops/ewise_ternary.h2
4 files changed, 43 insertions, 6 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 1e873e7..2bc894d 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -85,25 +85,40 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
}
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
-int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
+int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast(std::vector<int>& calculated_shape)
{
const std::vector<int>& a_shape = a->getShape();
const std::vector<int>& b_shape = b->getShape();
const std::vector<int>& output_shape = result->getShape();
+ // calculates the multipliers for Eigen
for (int i = 0; i < Rank; i++)
{
bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1;
bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1;
}
+ // calculates the broadcasted output shape
+ calculated_shape = a_shape;
+ for (size_t i = 0; i < calculated_shape.size(); i++) {
+ if (calculated_shape[i] == 1) {
+ calculated_shape[i] = b_shape[i];
+ } else {
+ ERROR_IF(b_shape[i] != 1 && b_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible");
+ }
+ }
+
return 0;
}
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNode<Rank, InDtype, OutDtype>::eval()
{
- this->broadcast();
+ std::vector<int> calculated_shape;
+ this->broadcast(calculated_shape);
+
+ auto result_shape = this->result->getShape();
+ ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match");
Eigen::array<int, Rank> reshaper;
reshaper.fill(1);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 5f6e531..3a6f24c 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -55,7 +55,7 @@ public:
using TOut = Eigen::Tensor<OutEigenType, Rank>;
protected:
- int broadcast();
+ int broadcast(std::vector<int>& calculated_shape);
protected:
std::function<OutEigenType(InEigenType, InEigenType)> fcn;
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 16554b5..fd2510f 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -66,13 +66,14 @@ int OpSelectBase<Rank, Dtype>::eval()
}
template <int Rank, TOSA_REF_TYPE Dtype>
-int OpSelect<Rank, Dtype>::broadcast()
+int OpSelect<Rank, Dtype>::broadcast(std::vector<int>& calculated_shape)
{
const std::vector<int>& cond_shape = this->cond->getShape();
const std::vector<int>& then_shape = this->then_val->getShape();
const std::vector<int>& else_shape = this->else_val->getShape();
const std::vector<int>& output_shape = this->out->getShape();
+ // calculates the multipliers for Eigen
for (int i = 0; i < Rank; i++)
{
this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1;
@@ -80,13 +81,34 @@ int OpSelect<Rank, Dtype>::broadcast()
this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1;
}
+ // calculates the broadcasted output shape
+ calculated_shape = cond_shape;
+ for (size_t i = 0; i < calculated_shape.size(); i++) {
+ if (calculated_shape[i] == 1) {
+ calculated_shape[i] = then_shape[i];
+ } else {
+ ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible");
+ }
+
+ if (calculated_shape[i] == 1) {
+ calculated_shape[i] = else_shape[i];
+ } else {
+ ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible");
+ }
+ }
+
return 0;
}
template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelect<Rank, Dtype>::eval()
{
- this->broadcast();
+ std::vector<int> calculated_shape;
+ this->broadcast(calculated_shape);
+
+ auto result_shape = this->out->getShape();
+ ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match");
+
this->out->getTensor() = this->cond->getTensor()
.broadcast(this->bcast_cond)
.select(this->then_val->getTensor().broadcast(this->bcast_then),
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
index c6970cb..f24dfbe 100644
--- a/reference_model/src/ops/ewise_ternary.h
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -63,7 +63,7 @@ public:
: OpSelectBase<Rank, Dtype>(sgt_, attribute_, id_)
{}
virtual int eval();
- int broadcast();
+ int broadcast(std::vector<int>& calculated_shape);
using InEigenType = typename OpSelectBase<Rank, Dtype>::InEigenType;
};