MLIR  20.0.0git
File.cpp
Go to the documentation of this file.
1 //===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements reading and writing sparse tensor files.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include <cctype>
16 #include <cstring>
17 
18 using namespace mlir::sparse_tensor;
19 
20 /// Opens the file for reading.
22  if (file) {
23  fprintf(stderr, "Already opened file %s\n", filename);
24  exit(1);
25  }
26  file = fopen(filename, "r");
27  if (!file) {
28  fprintf(stderr, "Cannot find file %s\n", filename);
29  exit(1);
30  }
31 }
32 
33 /// Closes the file.
35  if (file) {
36  fclose(file);
37  file = nullptr;
38  }
39 }
40 
41 /// Attempts to read a line from the file.
42 void SparseTensorReader::readLine() {
43  if (!fgets(line, kColWidth, file)) {
44  fprintf(stderr, "Cannot read next line of %s\n", filename);
45  exit(1);
46  }
47 }
48 
49 /// Reads and parses the file's header.
51  assert(file && "Attempt to readHeader() before openFile()");
52  if (strstr(filename, ".mtx")) {
53  readMMEHeader();
54  } else if (strstr(filename, ".tns")) {
55  readExtFROSTTHeader();
56  } else {
57  fprintf(stderr, "Unknown format %s\n", filename);
58  exit(1);
59  }
60  assert(isValid() && "Failed to read the header");
61 }
62 
63 /// Asserts the shape subsumes the actual dimension sizes. Is only
64 /// valid after parsing the header.
66  const uint64_t *shape) const {
67  assert(rank == getRank() && "Rank mismatch");
68  for (uint64_t r = 0; r < rank; r++)
69  assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
70  "Dimension size mismatch");
71 }
72 
/// Checks if the file's ValueKind (`valueKind_`) can be converted into the
/// given tensor PrimaryType `valTy`.
// NOTE(review): this excerpt has lost the function's definition line and the
// `case` labels that precede most of the branches below (only
// `case ValueKind::kReal:` survives). Restore them from the upstream file
// before building; the comments on each branch describe the intended kind.
74  switch (valueKind_) {
// Branch for a reader whose header has not been read yet (per the assertion
// message); the label itself is missing from this excerpt.
76  assert(false && "Must readHeader() before calling canReadAs()");
77  return false; // In case assertions are disabled.
// NOTE(review): label missing; this branch accepts every `valTy` —
// presumably the pattern kind, confirm against upstream.
79  return true;
81  // When the file is specified to store integer values, we still
82  // allow implicitly converting those to floating primary-types.
83  return isRealPrimaryType(valTy);
84  case ValueKind::kReal:
85  // When the file is specified to store real/floating values, then
86  // we disallow implicit conversion to integer primary-types.
87  return isFloatingPrimaryType(valTy);
89  // When the file is specified to store complex values, then we
90  // require a complex primary-type.
91  return isComplexPrimaryType(valTy);
93  // The "extended" FROSTT format doesn't specify a ValueKind.
94  // So we allow implicitly converting the stored values to both
95  // integer and floating primary-types.
96  return isRealPrimaryType(valTy);
97  }
// Reached only when `valueKind_` matched none of the switch cases: report
// the raw value and answer "cannot read".
98  fprintf(stderr, "Unknown ValueKind: %d\n", static_cast<uint8_t>(valueKind_));
99  return false;
100 }
101 
/// Helper to convert C-style strings (i.e., '\0' terminated) to lower case,
/// in place.
static inline void toLower(char *token) {
  for (char *c = token; *c; c++) {
    // `tolower` requires an argument representable as `unsigned char` (or
    // EOF). Plain `char` may be signed, so a byte >= 0x80 would be passed as
    // a negative value, which is undefined behavior; convert first.
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
  }
}
107 
/// Returns true iff the two '\0'-terminated strings compare equal under
/// `strcmp`.
static inline bool streq(const char *lhs, const char *rhs) {
  // `strcmp` yields 0 exactly when the strings are identical.
  return !strcmp(lhs, rhs);
}
112 
/// Returns true iff the two '\0'-terminated strings differ under `strcmp`.
static inline bool strne(const char *lhs, const char *rhs) {
  // Any nonzero `strcmp` result (negative or positive) means "not equal".
  return strcmp(lhs, rhs) != 0;
}
117 
118 /// Read the MME header of a general sparse matrix of type real.
119 void SparseTensorReader::readMMEHeader() {
120  char header[64];
121  char object[64];
122  char format[64];
123  char field[64];
124  char symmetry[64];
125  // Read header line.
126  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
127  symmetry) != 5) {
128  fprintf(stderr, "Corrupt header in %s\n", filename);
129  exit(1);
130  }
131  // Convert all to lowercase up front (to avoid accidental redundancy).
132  toLower(header);
133  toLower(object);
134  toLower(format);
135  toLower(field);
136  toLower(symmetry);
137  // Process `field`, which specify pattern or the data type of the values.
138  if (streq(field, "pattern")) {
139  valueKind_ = ValueKind::kPattern;
140  } else if (streq(field, "real")) {
141  valueKind_ = ValueKind::kReal;
142  } else if (streq(field, "integer")) {
143  valueKind_ = ValueKind::kInteger;
144  } else if (streq(field, "complex")) {
145  valueKind_ = ValueKind::kComplex;
146  } else {
147  fprintf(stderr, "Unexpected header field value in %s\n", filename);
148  exit(1);
149  }
150  // Set properties.
151  isSymmetric_ = streq(symmetry, "symmetric");
152  // Make sure this is a general sparse matrix.
153  if (strne(header, "%%matrixmarket") || strne(object, "matrix") ||
154  strne(format, "coordinate") ||
155  (strne(symmetry, "general") && !isSymmetric_)) {
156  fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
157  exit(1);
158  }
159  // Skip comments.
160  while (true) {
161  readLine();
162  if (line[0] != '%')
163  break;
164  }
165  // Next line contains M N NNZ.
166  idata[0] = 2; // rank
167  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
168  idata + 1) != 3) {
169  fprintf(stderr, "Cannot find size in %s\n", filename);
170  exit(1);
171  }
172 }
173 
174 /// Read the "extended" FROSTT header. Although not part of the documented
175 /// format, we assume that the file starts with optional comments followed
176 /// by two lines that define the rank, the number of nonzeros, and the
177 /// dimensions sizes (one per rank) of the sparse tensor.
178 void SparseTensorReader::readExtFROSTTHeader() {
179  // Skip comments.
180  while (true) {
181  readLine();
182  if (line[0] != '#')
183  break;
184  }
185  // Next line contains RANK and NNZ.
186  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
187  fprintf(stderr, "Cannot find metadata in %s\n", filename);
188  exit(1);
189  }
190  // Followed by a line with the dimension sizes (one per rank).
191  for (uint64_t r = 0; r < idata[0]; r++) {
192  if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
193  fprintf(stderr, "Cannot find dimension size %s\n", filename);
194  exit(1);
195  }
196  }
197  readLine(); // end of line
198  // The FROSTT format does not define the data type of the nonzero elements.
199  valueKind_ = ValueKind::kUndefined;
200 }
static bool streq(const char *lhs, const char *rhs)
Idiomatic name for checking string equality.
Definition: File.cpp:109
static bool strne(const char *lhs, const char *rhs)
Idiomatic name for checking string inequality.
Definition: File.cpp:114
static void toLower(char *token)
Helper to convert C-style strings (i.e., '\0' terminated) to lower case.
Definition: File.cpp:103
void assertMatchesShape(uint64_t rank, const uint64_t *shape) const
Asserts the shape subsumes the actual dimension sizes.
Definition: File.cpp:65
void closeFile()
Closes the file.
Definition: File.cpp:34
void readHeader()
Reads and parses the file's header.
Definition: File.cpp:50
bool canReadAs(PrimaryType valTy) const
Checks if the file's ValueKind can be converted into the given tensor PrimaryType.
Definition: File.cpp:73
bool isValid() const
Checks if a header has been successfully read.
Definition: File.h:146
uint64_t getRank() const
Gets the dimension-rank of the tensor.
Definition: File.h:168
void openFile()
Opens the file for reading.
Definition: File.cpp:21
PrimaryType
Encoding of the elemental type, for "overloading" @newSparseTensor.
Definition: Enums.h:82
constexpr bool isRealPrimaryType(PrimaryType valTy)
Definition: Enums.h:137
constexpr bool isComplexPrimaryType(PrimaryType valTy)
Definition: Enums.h:141
constexpr bool isFloatingPrimaryType(PrimaryType valTy)
Definition: Enums.h:129