aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/scatter_gather.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/scatter_gather.cc')
-rw-r--r--reference_model/src/ops/scatter_gather.cc210
1 files changed, 160 insertions, 50 deletions
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index c54204a..2d1026f 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -20,31 +20,61 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int InRank, int IndexRank, DType Dtype>
-OpGather<InRank, IndexRank, Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+template <DType Dtype>
+OpGather<Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
: GraphNode(Op_GATHER, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(1, 6);
-
- INIT_ATTRIBUTE(Axis);
}
-template <int InRank, int IndexRank, DType Dtype>
-OpGather<InRank, IndexRank, Dtype>::~OpGather()
-{
- if (attribute)
- delete attribute;
-}
+template <DType Dtype>
+OpGather<Dtype>::~OpGather()
+{}
-template <int InRank, int IndexRank, DType Dtype>
-int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
+template <DType Dtype>
+int OpGather<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ if (inputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: values needs to be rank 3 tensor");
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 2)
{
+ printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
+ return 1;
+ }
+
+ if (outputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: output needs to be rank 3 tensor");
+ return 1;
+ }
+
+ K = inputs[0]->getShape()[1];
+ N = outputs[0]->getShape()[0];
+ W = outputs[0]->getShape()[1];
+ C = outputs[0]->getShape()[2];
+
+ if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0])
+ {
+ printNodeValidationError("OpGather: dimension N mismatch");
+ return 1;
+ }
+
+ if (W != inputs[1]->getShape()[1])
+ {
+ printNodeValidationError("OpGather: dimension W mismatch");
+ return 1;
+ }
+
+ if (C != inputs[0]->getShape()[2])
+ {
+ printNodeValidationError("OpGather: dimension C mismatch");
return 1;
}
@@ -55,59 +85,133 @@ int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
return 1;
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- index = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ values = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
+ indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
- ASSERT_MEM(in && index && out);
+ ASSERT_MEM(values && indices && output);
return 0;
}
-template <int InRank, int IndexRank, DType Dtype>
-int OpGather<InRank, IndexRank, Dtype>::eval()
+template <DType Dtype>
+int OpGather<Dtype>::eval()
+{
+ for (int32_t n = 0; n < N; n++)
+ {
+ for (int32_t w = 0; w < W; w++)
+ {
+ int32_t k = this->indices->getTensor()(n, w);
+ ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
+ for (int32_t c = 0; c < C; c++)
+ {
+ EigenType value = this->values->getTensor()(n, k, c);
+ this->output->getTensor()(n, w, c) = value;
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpScatter<Dtype>::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SCATTER, id_)
{
- int axis = attribute->axis();
+ setRequiredOperands(3, 1);
+}
+
+template <DType Dtype>
+OpScatter<Dtype>::~OpScatter()
+{}
- // calculate size left and right to axis
- int left_size = 1;
- for (int i = 0; i < axis; ++i)
+template <DType Dtype>
+int OpScatter<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (inputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: values_in needs to be rank 3 tensor");
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 2)
+ {
+ printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
+ return 1;
+ }
+
+ if (inputs[2]->getRank() != 3)
{
- left_size *= in->getShape()[i];
+ printNodeValidationError("OpGather: input needs to be rank 3 tensor");
+ return 1;
}
- int right_size = 1;
- for (int i = axis + 1; i < in->getRank(); ++i)
+ if (outputs[0]->getRank() != 3)
{
- right_size *= in->getShape()[i];
+ printNodeValidationError("OpGather: values_out needs to be rank 3 tensor");
+ return 1;
}
- InEigenType* input_data = in->getTensor().data();
- int32_t* index_data = index->getTensor().data();
- OutEigenType* output_data = out->getTensor().data();
+ W = inputs[2]->getShape()[1];
+ N = outputs[0]->getShape()[0];
+ K = outputs[0]->getShape()[1];
+ C = outputs[0]->getShape()[2];
- int32_t axis_size = in->getShape()[axis];
- int32_t index_count = index->getElementCount();
+ if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0])
+ {
+ printNodeValidationError("OpScatter: dimension N mismatch");
+ return 1;
+ }
- // sanity check if index is valid
- // need to check until this point since index is not known until runtime
- for (size_t i = 0; i < index->getElementCount(); i++)
+ if (W != inputs[1]->getShape()[1])
{
- if (index_data[i] >= axis_size)
- {
- FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size);
- }
+ printNodeValidationError("OpGather: dimension W mismatch");
+ return 1;
}
- // Eigen stores tensor in column-major
- // so we iterate through dimension right to axis and the index array
- // do memory copy with size of left size each time
- for (int right = 0; right < right_size; ++right)
+ if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2])
+ {
+ printNodeValidationError("OpGather: dimension C mismatch");
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ values_in = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
+ indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ input = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[2]);
+ values_out = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
+
+ ASSERT_MEM(values_in && indices && input && values_out);
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpScatter<Dtype>::eval()
+{
+ // Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
+ this->values_out->getTensor() = this->values_in->getTensor();
+
+ for (int n = 0; n < N; n++)
{
- for (int i = 0; i < index_count; ++i)
+ for (int w = 0; w < W; w++)
{
- std::memcpy(output_data + (right * index_count + i) * left_size,
- input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size);
+ int32_t k = this->indices->getTensor()(n, w);
+ ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
+ for (int c = 0; c < C; c++)
+ {
+ EigenType value = this->input->getTensor()(n, w, c);
+ this->values_out->getTensor()(n, k, c) = value;
+ }
}
}
@@ -115,6 +219,12 @@ int OpGather<InRank, IndexRank, Dtype>::eval()
}
// template explicit instantiation
-DEF_INSTANTIATE_GATHER(OpGather, AINT8);
-DEF_INSTANTIATE_GATHER(OpGather, INT16);
-DEF_INSTANTIATE_GATHER(OpGather, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT);