MLIR 22.0.0git
File.cpp
Go to the documentation of this file.
1//===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements reading and writing sparse tensor files.
10//
11//===----------------------------------------------------------------------===//
12
14
#include <cctype>
#include <cstring>
#include <iterator>
17
18using namespace mlir::sparse_tensor;
19
20/// Opens the file for reading.
22 if (file) {
23 fprintf(stderr, "Already opened file %s\n", filename);
24 exit(1);
25 }
26 file = fopen(filename, "r");
27 if (!file) {
28 fprintf(stderr, "Cannot find file %s\n", filename);
29 exit(1);
30 }
31}
32
33/// Closes the file.
35 if (file) {
36 fclose(file);
37 file = nullptr;
38 }
39}
40
41/// Attempts to read a line from the file.
42void SparseTensorReader::readLine() {
43 if (!fgets(line, kColWidth, file)) {
44 fprintf(stderr, "Cannot read next line of %s\n", filename);
45 exit(1);
46 }
47}
48
49/// Reads and parses the file's header.
51 assert(file && "Attempt to readHeader() before openFile()");
52 if (strstr(filename, ".mtx")) {
53 readMMEHeader();
54 } else if (strstr(filename, ".tns")) {
55 readExtFROSTTHeader();
56 } else {
57 fprintf(stderr, "Unknown format %s\n", filename);
58 exit(1);
59 }
60 assert(isValid() && "Failed to read the header");
61}
62
63/// Asserts the shape subsumes the actual dimension sizes. Is only
64/// valid after parsing the header.
66 const uint64_t *shape) const {
67 assert(rank == getRank() && "Rank mismatch");
68 for (uint64_t r = 0; r < rank; r++)
69 assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
70 "Dimension size mismatch");
71}
72
74 switch (valueKind_) {
76 assert(false && "Must readHeader() before calling canReadAs()");
77 return false; // In case assertions are disabled.
79 return true;
81 // When the file is specified to store integer values, we still
82 // allow implicitly converting those to floating primary-types.
83 return isRealPrimaryType(valTy);
85 // When the file is specified to store real/floating values, then
86 // we disallow implicit conversion to integer primary-types.
87 return isFloatingPrimaryType(valTy);
89 // When the file is specified to store complex values, then we
90 // require a complex primary-type.
91 return isComplexPrimaryType(valTy);
93 // The "extended" FROSTT format doesn't specify a ValueKind.
94 // So we allow implicitly converting the stored values to both
95 // integer and floating primary-types.
96 return isRealPrimaryType(valTy);
97 }
98 fprintf(stderr, "Unknown ValueKind: %d\n", static_cast<uint8_t>(valueKind_));
99 return false;
100}
101
/// Converts a C-style (NUL-terminated) string to lower case, in place.
/// Each byte is passed to `std::tolower` as an `unsigned char`: passing a
/// negative `char` (e.g. a non-ASCII byte where `char` is signed) is
/// undefined behavior.
static inline void toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = static_cast<char>(std::tolower(static_cast<unsigned char>(*c)));
}
107
/// Returns true iff the two NUL-terminated strings have identical contents.
static inline bool streq(const char *lhs, const char *rhs) {
  const int order = strcmp(lhs, rhs);
  return !order;
}
112
/// Returns true iff the two NUL-terminated strings differ in contents.
static inline bool strne(const char *lhs, const char *rhs) {
  const bool differ = strcmp(lhs, rhs) != 0;
  return differ;
}
117
118/// Read the MME header of a general sparse matrix of type real.
119void SparseTensorReader::readMMEHeader() {
120 char header[64];
121 char object[64];
122 char format[64];
123 char field[64];
124 char symmetry[64];
125 // Read header line.
126 if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
127 symmetry) != 5) {
128 fprintf(stderr, "Corrupt header in %s\n", filename);
129 exit(1);
130 }
131 // Convert all to lowercase up front (to avoid accidental redundancy).
132 toLower(header);
133 toLower(object);
134 toLower(format);
135 toLower(field);
136 toLower(symmetry);
137 // Process `field`, which specify pattern or the data type of the values.
138 if (streq(field, "pattern")) {
139 valueKind_ = ValueKind::kPattern;
140 } else if (streq(field, "real")) {
141 valueKind_ = ValueKind::kReal;
142 } else if (streq(field, "integer")) {
143 valueKind_ = ValueKind::kInteger;
144 } else if (streq(field, "complex")) {
145 valueKind_ = ValueKind::kComplex;
146 } else {
147 fprintf(stderr, "Unexpected header field value in %s\n", filename);
148 exit(1);
149 }
150 // Set properties.
151 isSymmetric_ = streq(symmetry, "symmetric");
152 // Make sure this is a general sparse matrix.
153 if (strne(header, "%%matrixmarket") || strne(object, "matrix") ||
154 strne(format, "coordinate") ||
155 (strne(symmetry, "general") && !isSymmetric_)) {
156 fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
157 exit(1);
158 }
159 // Skip comments.
160 while (true) {
161 readLine();
162 if (line[0] != '%')
163 break;
164 }
165 // Next line contains M N NNZ.
166 idata[0] = 2; // rank
167 if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
168 idata + 1) != 3) {
169 fprintf(stderr, "Cannot find size in %s\n", filename);
170 exit(1);
171 }
172}
173
174/// Read the "extended" FROSTT header. Although not part of the documented
175/// format, we assume that the file starts with optional comments followed
176/// by two lines that define the rank, the number of nonzeros, and the
177/// dimensions sizes (one per rank) of the sparse tensor.
178void SparseTensorReader::readExtFROSTTHeader() {
179 // Skip comments.
180 while (true) {
181 readLine();
182 if (line[0] != '#')
183 break;
184 }
185 // Next line contains RANK and NNZ.
186 if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
187 fprintf(stderr, "Cannot find metadata in %s\n", filename);
188 exit(1);
189 }
190 // Followed by a line with the dimension sizes (one per rank).
191 for (uint64_t r = 0; r < idata[0]; r++) {
192 if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
193 fprintf(stderr, "Cannot find dimension size %s\n", filename);
194 exit(1);
195 }
196 }
197 readLine(); // end of line
198 // The FROSTT format does not define the data type of the nonzero elements.
199 valueKind_ = ValueKind::kUndefined;
200}
static bool streq(const char *lhs, const char *rhs)
Idiomatic name for checking string equality.
Definition File.cpp:109
static bool strne(const char *lhs, const char *rhs)
Idiomatic name for checking string inequality.
Definition File.cpp:114
static void toLower(char *token)
Helper to convert C-style strings (i.e., '\0' terminated) to lower case.
Definition File.cpp:103
lhs
void assertMatchesShape(uint64_t rank, const uint64_t *shape) const
Asserts the shape subsumes the actual dimension sizes.
Definition File.cpp:65
void closeFile()
Closes the file.
Definition File.cpp:34
void readHeader()
Reads and parses the file's header.
Definition File.cpp:50
bool canReadAs(PrimaryType valTy) const
Checks if the file's ValueKind can be converted into the given tensor PrimaryType.
Definition File.cpp:73
bool isValid() const
Checks if a header has been successfully read.
Definition File.h:146
uint64_t getRank() const
Gets the dimension-rank of the tensor.
Definition File.h:168
void openFile()
Opens the file for reading.
Definition File.cpp:21
PrimaryType
Encoding of the elemental type, for "overloading" @newSparseTensor.
Definition Enums.h:82
constexpr bool isRealPrimaryType(PrimaryType valTy)
Definition Enums.h:137
constexpr bool isComplexPrimaryType(PrimaryType valTy)
Definition Enums.h:141
constexpr bool isFloatingPrimaryType(PrimaryType valTy)
Definition Enums.h:129