diff options
Diffstat (limited to 'reference_model/src/tensor.h')
-rw-r--r-- | reference_model/src/tensor.h | 45 |
1 files changed, 24 insertions, 21 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index d39cc7c..6c0622e 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -33,9 +33,7 @@ class GraphNode; class Tensor { public: - Tensor(std::string tensorName_, - DType tensorDtype__, - std::vector<int> shape_); + Tensor(std::string tensorName_, DType tensorDtype__, std::vector<int> shape_); virtual ~Tensor(); @@ -240,9 +238,7 @@ template <class T> class TensorTemplate : public Tensor { public: - TensorTemplate(std::string tensorName_, - DType tensorDtype_, - std::vector<int> shape_) + TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_) : Tensor(tensorName_, tensorDtype_, shape_) { tensor = nullptr; @@ -606,11 +602,15 @@ int Tensor6<bool>::dumpTensor(FILE* out) const; class TensorFactory { public: - static Tensor* newTensor(std::string tensorName_, - DType tensorDtype_, - std::vector<int> shape_, - const uint32_t rank) + static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank) { + // Bail out if any dimension is invalid. + for (auto& dim : shape_) + { + if (dim <= 0) + goto done; + } + switch (tensorDtype_) { case DType_FLOAT: @@ -630,9 +630,8 @@ public: return new Tensor5<float>(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6<float>(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_INT32: case DType_UINT8: case DType_INT4: @@ -654,9 +653,8 @@ public: return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_INT48: switch (rank) { @@ -674,9 +672,8 @@ public: return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; case DType_BOOL: switch (rank) { @@ -694,16 +691,22 @@ public: return new Tensor5<bool>(tensorName_, tensorDtype_, shape_); case 6: return new Tensor6<bool>(tensorName_, tensorDtype_, shape_); - default: - goto done; } + break; default: - goto done; + break; } done: - FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_], - rank); + std::string shape_str("["); + for (auto& dim : shape_) + { + shape_str += (std::to_string(dim) + ", "); + } + shape_str.append("]"); + + FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d, shape=%s", tensorName_.c_str(), + EnumNamesDType()[tensorDtype_], rank, shape_str.c_str()); } static Tensor* newTensor(DType type, const std::vector<int> shape); |