aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--reference_model/src/ops/image.cc116
-rw-r--r--reference_model/src/ops/image.h2
-rw-r--r--reference_model/src/ops/op_factory.cc20
-rw-r--r--reference_model/src/ops/op_factory.h53
-rw-r--r--reference_model/src/ops/scatter_gather.cc210
-rw-r--r--reference_model/src/ops/scatter_gather.h45
6 files changed, 293 insertions, 153 deletions
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index d3352ce..829a6e0 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -51,6 +51,8 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
stride = this->attribute->stride();
offset = this->attribute->offset();
shift = this->attribute->shift();
+ stride_fp = this->attribute->stride_fp();
+ offset_fp = this->attribute->offset_fp();
mode = this->attribute->mode();
int output_height = outputs[0]->getShape()[1];
@@ -58,7 +60,7 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48)
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -66,7 +68,7 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16)
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -79,18 +81,6 @@ int OpResize<InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- if (shift < 1 || shift > 11)
- {
- printNodeValidationError("OpResize: attribute shift should be within [1, 11]");
- return 1;
- }
-
- if (stride[0] <= 0 || stride[1] <= 0)
- {
- printNodeValidationError("OpResize: invalid attribute stride");
- return 1;
- }
-
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -112,6 +102,8 @@ int OpResize<InDtype, OutDtype>::eval()
int out_width = out->getShape()[2];
int out_channels = out->getShape()[3];
+ ASSERT_MSG_NODE(shift > 0 && shift <= 11, "OpResize: attribute shift should be within [1, 11]");
+ ASSERT_MSG_NODE(stride[0] > 0 && stride[1] > 0, "OpResize: invalid attribute stride");
ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
@@ -120,30 +112,30 @@ int OpResize<InDtype, OutDtype>::eval()
for (int oy = 0; oy < out_height; oy++)
for (int ox = 0; ox < out_width; ox++)
{
- int y = oy * stride[0] + offset[0];
- int x = ox * stride[1] + offset[1];
+ int32_t y = oy * stride[0] + offset[0];
+ int32_t x = ox * stride[1] + offset[1];
- int iy = y >> shift;
- int dy = y - (iy << shift);
- int ix = x >> shift;
- int dx = x - (ix << shift);
+ int32_t iy = y >> shift;
+ int32_t dy = y - (iy << shift);
+ int32_t ix = x >> shift;
+ int32_t dx = x - (ix << shift);
- int iy0 = MAX(iy, 0);
- int iy1 = MIN(iy + 1, in_height - 1);
- int ix0 = MAX(ix, 0);
- int ix1 = MIN(ix + 1, in_width - 1);
+ int32_t iy0 = MAX(iy, 0);
+ int32_t iy1 = MIN(iy + 1, in_height - 1);
+ int32_t ix0 = MAX(ix, 0);
+ int32_t ix1 = MIN(ix + 1, in_width - 1);
ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
iy0, iy1, ix0, ix1);
- InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
- InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
- InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
- InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
-
OutEigenType acc;
if (mode == ResizeMode_BILINEAR)
{
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx);
acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx;
acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx);
@@ -162,8 +154,74 @@ int OpResize<InDtype, OutDtype>::eval()
return GraphNode::eval();
}
+template <>
+int OpResize<DType_FLOAT, DType_FLOAT>::eval()
+{
+ int in_batch = in->getShape()[0];
+ int in_height = in->getShape()[1];
+ int in_width = in->getShape()[2];
+ int in_channels = in->getShape()[3];
+
+ int out_batch = out->getShape()[0];
+ int out_height = out->getShape()[1];
+ int out_width = out->getShape()[2];
+ int out_channels = out->getShape()[3];
+
+ ASSERT_MSG_NODE(shift == 0, "OpResize: float mode must have 0 shift");
+ ASSERT_MSG_NODE(stride_fp[0] > 0.0f && stride_fp[1] > 0.0f, "OpResize: invalid attribute stride");
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
+ ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+
+ for (int b = 0; b < out_batch; b++)
+ for (int c = 0; c < out_channels; c++)
+ for (int oy = 0; oy < out_height; oy++)
+ for (int ox = 0; ox < out_width; ox++)
+ {
+ float y = oy * stride_fp[0] + offset_fp[0];
+ float x = ox * stride_fp[1] + offset_fp[1];
+
+ int32_t iy = static_cast<int32_t>(std::floor(y));
+ float dy = y - static_cast<float>(iy);
+ int32_t ix = static_cast<int32_t>(std::floor(x));
+ float dx = x - static_cast<float>(ix);
+
+ int32_t iy0 = MAX(iy, 0);
+ int32_t iy1 = MIN(iy + 1, in_height - 1);
+ int32_t ix0 = MAX(ix, 0);
+ int32_t ix1 = MIN(ix + 1, in_width - 1);
+
+ ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
+ iy0, iy1, ix0, ix1);
+
+ OutEigenType acc;
+ if (mode == ResizeMode_BILINEAR)
+ {
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
+ acc = (OutEigenType)v00 * (1.0 - dy) * (1.0 - dx);
+ acc = acc + (OutEigenType)v01 * (1.0 - dy) * dx;
+ acc = acc + (OutEigenType)v10 * dy * (1.0 - dx);
+ acc = acc + (OutEigenType)v11 * dy * dx;
+ }
+ else
+ {
+ iy = (dy >= 0.5) ? iy1 : iy0;
+ ix = (dx >= 0.5) ? ix1 : ix0;
+ acc = in->getTensor()(b, iy, ix, c);
+ }
+
+ out->getTensor()(b, oy, ox, c) = acc;
+ }
+
+ return GraphNode::eval();
+}
+
// template explicit instantiation
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48);
DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, FLOAT, FLOAT);
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
index 9d15d49..5dd14c8 100644
--- a/reference_model/src/ops/image.h
+++ b/reference_model/src/ops/image.h
@@ -43,6 +43,8 @@ protected:
std::vector<int32_t> stride;
std::vector<int32_t> offset;
int32_t shift;
+ std::vector<float> stride_fp;
+ std::vector<float> offset_fp;
ResizeMode mode;
TosaReference::TensorTemplate<TIn>* in;
TosaReference::TensorTemplate<TOut>* out;
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index bad0c40..4a06248 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -334,14 +334,17 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
// scatter_gather
case Op_GATHER:
- {
- // output.rank = input.rank - 1 + index.rank
- int32_t index_rank = outputRank - inputRank + 1;
- DEF_FACTORY_GATHER(OpGather, AINT8);
- DEF_FACTORY_GATHER(OpGather, INT16);
- DEF_FACTORY_GATHER(OpGather, INT32);
- }
- break;
+ DEF_FACTORY_ONE_TYPE(OpGather, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpGather, INT16);
+ DEF_FACTORY_ONE_TYPE(OpGather, INT32);
+ DEF_FACTORY_ONE_TYPE(OpGather, FLOAT);
+ break;
+ case Op_SCATTER:
+ DEF_FACTORY_ONE_TYPE(OpScatter, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
+ DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT);
+ break;
// image
case Op_RESIZE:
@@ -349,6 +352,7 @@ GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8);
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48);
DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, FLOAT, FLOAT);
break;
// data_nodes
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index cde6841..0c116b6 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -218,59 +218,6 @@
} \
}
-#define DEF_FACTORY_GATHER(OP, DTYPE) \
- if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
- { \
- switch (inputRank) \
- { \
- case 1: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE); \
- } \
- case 2: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE); \
- } \
- case 3: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE); \
- } \
- case 4: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE); \
- } \
- case 5: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE); \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE); \
- } \
- case 6: \
- switch (index_rank) \
- { \
- DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE); \
- } \
- } \
- }
-
namespace TosaReference
{
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index c54204a..2d1026f 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -20,31 +20,61 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int InRank, int IndexRank, DType Dtype>
-OpGather<InRank, IndexRank, Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+template <DType Dtype>
+OpGather<Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
: GraphNode(Op_GATHER, id_)
{
setRequiredOperands(2, 1);
- setRequiredRank(1, 6);
-
- INIT_ATTRIBUTE(Axis);
}
-template <int InRank, int IndexRank, DType Dtype>
-OpGather<InRank, IndexRank, Dtype>::~OpGather()
-{
- if (attribute)
- delete attribute;
-}
+template <DType Dtype>
+OpGather<Dtype>::~OpGather()
+{}
-template <int InRank, int IndexRank, DType Dtype>
-int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
+template <DType Dtype>
+int OpGather<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ if (inputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: values needs to be rank 3 tensor");
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 2)
{
+ printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
+ return 1;
+ }
+
+ if (outputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: output needs to be rank 3 tensor");
+ return 1;
+ }
+
+ K = inputs[0]->getShape()[1];
+ N = outputs[0]->getShape()[0];
+ W = outputs[0]->getShape()[1];
+ C = outputs[0]->getShape()[2];
+
+ if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0])
+ {
+ printNodeValidationError("OpGather: dimension N mismatch");
+ return 1;
+ }
+
+ if (W != inputs[1]->getShape()[1])
+ {
+ printNodeValidationError("OpGather: dimension W mismatch");
+ return 1;
+ }
+
+ if (C != inputs[0]->getShape()[2])
+ {
+ printNodeValidationError("OpGather: dimension C mismatch");
return 1;
}
@@ -55,59 +85,133 @@ int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
return 1;
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- index = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
- out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ values = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
+ indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
- ASSERT_MEM(in && index && out);
+ ASSERT_MEM(values && indices && output);
return 0;
}
-template <int InRank, int IndexRank, DType Dtype>
-int OpGather<InRank, IndexRank, Dtype>::eval()
+template <DType Dtype>
+int OpGather<Dtype>::eval()
+{
+ for (int32_t n = 0; n < N; n++)
+ {
+ for (int32_t w = 0; w < W; w++)
+ {
+ int32_t k = this->indices->getTensor()(n, w);
+ ASSERT_MSG_NODE(k >= 0 && k < K, "OpGather: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
+ for (int32_t c = 0; c < C; c++)
+ {
+ EigenType value = this->values->getTensor()(n, k, c);
+ this->output->getTensor()(n, w, c) = value;
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpScatter<Dtype>::OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SCATTER, id_)
{
- int axis = attribute->axis();
+ setRequiredOperands(3, 1);
+}
+
+template <DType Dtype>
+OpScatter<Dtype>::~OpScatter()
+{}
- // calculate size left and right to axis
- int left_size = 1;
- for (int i = 0; i < axis; ++i)
+template <DType Dtype>
+int OpScatter<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (inputs[0]->getRank() != 3)
+ {
+ printNodeValidationError("OpGather: values_in needs to be rank 3 tensor");
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 2)
+ {
+ printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
+ return 1;
+ }
+
+ if (inputs[2]->getRank() != 3)
{
- left_size *= in->getShape()[i];
+ printNodeValidationError("OpGather: input needs to be rank 3 tensor");
+ return 1;
}
- int right_size = 1;
- for (int i = axis + 1; i < in->getRank(); ++i)
+ if (outputs[0]->getRank() != 3)
{
- right_size *= in->getShape()[i];
+ printNodeValidationError("OpGather: values_out needs to be rank 3 tensor");
+ return 1;
}
- InEigenType* input_data = in->getTensor().data();
- int32_t* index_data = index->getTensor().data();
- OutEigenType* output_data = out->getTensor().data();
+ W = inputs[2]->getShape()[1];
+ N = outputs[0]->getShape()[0];
+ K = outputs[0]->getShape()[1];
+ C = outputs[0]->getShape()[2];
- int32_t axis_size = in->getShape()[axis];
- int32_t index_count = index->getElementCount();
+ if (N != inputs[0]->getShape()[0] || N != inputs[1]->getShape()[0] || N != inputs[2]->getShape()[0])
+ {
+ printNodeValidationError("OpScatter: dimension N mismatch");
+ return 1;
+ }
- // sanity check if index is valid
- // need to check until this point since index is not known until runtime
- for (size_t i = 0; i < index->getElementCount(); i++)
+ if (W != inputs[1]->getShape()[1])
{
- if (index_data[i] >= axis_size)
- {
- FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size);
- }
+ printNodeValidationError("OpGather: dimension W mismatch");
+ return 1;
}
- // Eigen stores tensor in column-major
- // so we iterate through dimension right to axis and the index array
- // do memory copy with size of left size each time
- for (int right = 0; right < right_size; ++right)
+ if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2])
+ {
+ printNodeValidationError("OpGather: dimension C mismatch");
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ values_in = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[0]);
+ indices = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ input = dynamic_cast<TosaReference::TensorTemplate<TValue>*>(inputs[2]);
+ values_out = dynamic_cast<TosaReference::TensorTemplate<TOutput>*>(outputs[0]);
+
+ ASSERT_MEM(values_in && indices && input && values_out);
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpScatter<Dtype>::eval()
+{
+ // Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
+ this->values_out->getTensor() = this->values_in->getTensor();
+
+ for (int n = 0; n < N; n++)
{
- for (int i = 0; i < index_count; ++i)
+ for (int w = 0; w < W; w++)
{
- std::memcpy(output_data + (right * index_count + i) * left_size,
- input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size);
+ int32_t k = this->indices->getTensor()(n, w);
+ ASSERT_MSG_NODE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
+ for (int c = 0; c < C; c++)
+ {
+ EigenType value = this->input->getTensor()(n, w, c);
+ this->values_out->getTensor()(n, k, c) = value;
+ }
}
}
@@ -115,6 +219,12 @@ int OpGather<InRank, IndexRank, Dtype>::eval()
}
// template explicit instantiation
-DEF_INSTANTIATE_GATHER(OpGather, AINT8);
-DEF_INSTANTIATE_GATHER(OpGather, INT16);
-DEF_INSTANTIATE_GATHER(OpGather, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT);
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
index d9b1263..17ea723 100644
--- a/reference_model/src/ops/scatter_gather.h
+++ b/reference_model/src/ops/scatter_gather.h
@@ -23,9 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-// input and index can have different rank
-// and infer OutRank statically
-template <int InRank, int IndexRank, DType Dtype>
+template <DType Dtype>
class OpGather : public GraphNode
{
public:
@@ -35,18 +33,39 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr int OutRank = InRank - 1 + IndexRank;
- using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<Dtype>::type;
- using TIn = Eigen::Tensor<InEigenType, InRank>;
- using TIndex = Eigen::Tensor<int32_t, IndexRank>;
- using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+ using EigenType = typename GetEigenType<Dtype>::type;
+ using TValue = Eigen::Tensor<EigenType, 3>;
+ using TIndex = Eigen::Tensor<int32_t, 2>;
+ using TOutput = Eigen::Tensor<EigenType, 3>;
protected:
- TosaAxisAttribute* attribute;
- TosaReference::TensorTemplate<TIn>* in;
- TosaReference::TensorTemplate<TIndex>* index;
- TosaReference::TensorTemplate<TOut>* out;
+ int32_t N, W, K, C;
+ TosaReference::TensorTemplate<TValue>* values;
+ TosaReference::TensorTemplate<TIndex>* indices;
+ TosaReference::TensorTemplate<TOutput>* output;
+};
+
+template <DType Dtype>
+class OpScatter : public GraphNode
+{
+public:
+ OpScatter(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpScatter();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using EigenType = typename GetEigenType<Dtype>::type;
+ using TValue = Eigen::Tensor<EigenType, 3>;
+ using TIndex = Eigen::Tensor<int32_t, 2>;
+ using TOutput = Eigen::Tensor<EigenType, 3>;
+
+protected:
+ int32_t N, W, K, C;
+ TosaReference::TensorTemplate<TValue>* values_in;
+ TosaReference::TensorTemplate<TIndex>* indices;
+ TosaReference::TensorTemplate<TValue>* input;
+ TosaReference::TensorTemplate<TOutput>* values_out;
};
}; // namespace TosaReference