23 fprintf(stderr,
"Already opened file %s\n", filename);
26 file = fopen(filename,
"r");
28 fprintf(stderr,
"Cannot find file %s\n", filename);
42 void SparseTensorReader::readLine() {
43 if (!fgets(line, kColWidth, file)) {
44 fprintf(stderr,
"Cannot read next line of %s\n", filename);
51 assert(file &&
"Attempt to readHeader() before openFile()");
52 if (strstr(filename,
".mtx")) {
54 }
else if (strstr(filename,
".tns")) {
55 readExtFROSTTHeader();
57 fprintf(stderr,
"Unknown format %s\n", filename);
60 assert(
isValid() &&
"Failed to read the header");
66 const uint64_t *shape)
const {
67 assert(rank ==
getRank() &&
"Rank mismatch");
68 for (uint64_t r = 0; r < rank; r++)
69 assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
70 "Dimension size mismatch");
76 assert(
false &&
"Must readHeader() before calling canReadAs()");
98 fprintf(stderr,
"Unknown ValueKind: %d\n",
static_cast<uint8_t
>(valueKind_));
104 for (
char *c = token; *c; c++)
109 static inline bool streq(
const char *lhs,
const char *rhs) {
110 return strcmp(lhs, rhs) == 0;
114 static inline bool strne(
const char *lhs,
const char *rhs) {
115 return strcmp(lhs, rhs);
119 void SparseTensorReader::readMMEHeader() {
126 if (fscanf(file,
"%63s %63s %63s %63s %63s\n", header,
object, format, field,
128 fprintf(stderr,
"Corrupt header in %s\n", filename);
138 if (
streq(field,
"pattern")) {
140 }
else if (
streq(field,
"real")) {
142 }
else if (
streq(field,
"integer")) {
144 }
else if (
streq(field,
"complex")) {
147 fprintf(stderr,
"Unexpected header field value in %s\n", filename);
151 isSymmetric_ =
streq(symmetry,
"symmetric");
153 if (
strne(header,
"%%matrixmarket") ||
strne(
object,
"matrix") ||
154 strne(format,
"coordinate") ||
155 (
strne(symmetry,
"general") && !isSymmetric_)) {
156 fprintf(stderr,
"Cannot find a general sparse matrix in %s\n", filename);
167 if (sscanf(line,
"%" PRIu64
"%" PRIu64
"%" PRIu64
"\n", idata + 2, idata + 3,
169 fprintf(stderr,
"Cannot find size in %s\n", filename);
178 void SparseTensorReader::readExtFROSTTHeader() {
186 if (sscanf(line,
"%" PRIu64
"%" PRIu64
"\n", idata, idata + 1) != 2) {
187 fprintf(stderr,
"Cannot find metadata in %s\n", filename);
191 for (uint64_t r = 0; r < idata[0]; r++) {
192 if (fscanf(file,
"%" PRIu64, idata + 2 + r) != 1) {
193 fprintf(stderr,
"Cannot find dimension size %s\n", filename);
static bool streq(const char *lhs, const char *rhs)
Idiomatic name for checking string equality.
static bool strne(const char *lhs, const char *rhs)
Idiomatic name for checking string inequality.
static void toLower(char *token)
Helper to convert C-style strings (i.e., '\0' terminated) to lower case.
void assertMatchesShape(uint64_t rank, const uint64_t *shape) const
Asserts the shape subsumes the actual dimension sizes.
void closeFile()
Closes the file.
void readHeader()
Reads and parses the file's header.
bool canReadAs(PrimaryType valTy) const
Checks if the file's ValueKind can be converted into the given tensor PrimaryType.
bool isValid() const
Checks if a header has been successfully read.
uint64_t getRank() const
Gets the dimension-rank of the tensor.
void openFile()
Opens the file for reading.
PrimaryType
Encoding of the elemental type, for "overloading" @newSparseTensor.
constexpr bool isRealPrimaryType(PrimaryType valTy)
constexpr bool isComplexPrimaryType(PrimaryType valTy)
constexpr bool isFloatingPrimaryType(PrimaryType valTy)