MLIR  19.0.0git
DimLvlMapParser.cpp
Go to the documentation of this file.
1 //===- DimLvlMapParser.cpp - `DimLvlMap` parser implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DimLvlMapParser.h"
10 
11 using namespace mlir;
12 using namespace mlir::sparse_tensor;
13 using namespace mlir::sparse_tensor::ir_detail;
14 
/// Returns `failure()` from the enclosing function when `RES` has failed.
/// The expansion is a complete `if` statement, which is why call sites in
/// this file are written without a trailing semicolon.
#define FAILURE_IF_FAILED(RES)                                                 \
  if (failed(RES)) {                                                           \
    return failure();                                                          \
  }
19 
20 /// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating
21 /// its `RES` parameter.
22 static inline bool didntSucceed(OptionalParseResult res) {
23  return !res.has_value() || failed(*res);
24 }
25 
/// Returns `failure()` from the enclosing function when the optional parse
/// result `RES` is either empty or failed (see `didntSucceed`). Like the
/// other macros here, it expands to a full `if` statement and is used
/// without a trailing semicolon.
#define FAILURE_IF_NULLOPT_OR_FAILED(RES)                                      \
  if (didntSucceed(RES)) {                                                     \
    return failure();                                                          \
  }
30 
/// Emits the diagnostic `MSG` at `loc` and returns it from the enclosing
/// function whenever `COND` holds.
///
/// NOTE: the expansion refers to `parser` and `loc` by name, so both an
/// `AsmParser parser` and an `SMLoc loc` must be in scope at the call site.
#define ERROR_IF(COND, MSG)                                                    \
  if (COND) {                                                                  \
    return parser.emitError(loc, MSG);                                         \
  }
36 
37 //===----------------------------------------------------------------------===//
38 // `DimLvlMapParser` implementation for variable parsing.
39 //===----------------------------------------------------------------------===//
40 
41 // Our variation on `AffineParser::{parseBareIdExpr,parseIdentifierDefinition}`
42 OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
43  Policy creationPolicy,
44  VarInfo::ID &varID,
45  bool &didCreate) {
46  // Save the current location so that we can have error messages point to
47  // the right place.
48  const auto loc = parser.getCurrentLocation();
49  StringRef name;
50  if (failed(parser.parseOptionalKeyword(&name))) {
51  ERROR_IF(!isOptional, "expected bare identifier")
52  return std::nullopt;
53  }
54 
55  if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
56  varID = res->first;
57  didCreate = res->second;
58  return success();
59  }
60 
61  switch (creationPolicy) {
62  case Policy::MustNot:
63  return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
64  case Policy::May:
65  llvm_unreachable("got nullopt for Policy::May");
66  case Policy::Must:
67  return parser.emitError(loc, "redefinition of identifier '" + name + "'");
68  }
69  llvm_unreachable("unknown Policy");
70 }
71 
72 FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
73  bool requireKnown) {
74  VarInfo::ID id;
75  bool didCreate;
76  const bool isOptional = false;
77  const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May;
78  const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
80  assert(requireKnown ? !didCreate : true);
81  return id;
82 }
83 
84 FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
85  bool requireKnown) {
86  const auto loc = parser.getCurrentLocation();
87  VarInfo::ID id;
88  bool didCreate;
89  const bool isOptional = false;
90  const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
91  const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
93  assert(requireKnown ? !didCreate : didCreate);
94  bindVar(loc, id);
95  return id;
96 }
97 
99 DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
100  const auto loc = parser.getCurrentLocation();
101  VarInfo::ID id;
102  bool didCreate;
103  const bool isOptional = true;
104  const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
105  const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
106  if (res.has_value()) {
107  FAILURE_IF_FAILED(*res)
108  assert(didCreate);
109  return std::make_pair(bindVar(loc, id), true);
110  }
111  assert(!didCreate);
112  return std::make_pair(env.bindUnusedVar(vk), false);
113 }
114 
115 Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
116  MLIRContext *context = parser.getContext();
117  const auto var = env.bindVar(id);
118  const auto &info = std::as_const(env).access(id);
119  const auto name = info.getName();
120  const auto num = *info.getNum();
121  switch (info.getKind()) {
122  case VarKind::Symbol: {
123  const auto affine = getAffineSymbolExpr(num, context);
124  dimsAndSymbols.emplace_back(name, affine);
125  lvlsAndSymbols.emplace_back(name, affine);
126  return var;
127  }
128  case VarKind::Dimension:
129  dimsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
130  return var;
131  case VarKind::Level:
132  lvlsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
133  return var;
134  }
135  llvm_unreachable("unknown VarKind");
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // `DimLvlMapParser` implementation for `DimLvlMap` per se.
140 //===----------------------------------------------------------------------===//
141 
143  FAILURE_IF_FAILED(parseSymbolBindingList())
144  FAILURE_IF_FAILED(parseLvlVarBindingList())
145  FAILURE_IF_FAILED(parseDimSpecList())
146  FAILURE_IF_FAILED(parser.parseArrow())
147  FAILURE_IF_FAILED(parseLvlSpecList())
148  InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
149  if (failed(ifd))
150  return ifd;
151  return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
152 }
153 
154 ParseResult DimLvlMapParser::parseSymbolBindingList() {
155  return parser.parseCommaSeparatedList(
157  [this]() { return ParseResult(parseVarBinding(VarKind::Symbol)); },
158  " in symbol binding list");
159 }
160 
161 ParseResult DimLvlMapParser::parseLvlVarBindingList() {
162  return parser.parseCommaSeparatedList(
164  [this]() { return ParseResult(parseVarBinding(VarKind::Level)); },
165  " in level declaration list");
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // `DimLvlMapParser` implementation for `DimSpec`.
170 //===----------------------------------------------------------------------===//
171 
172 ParseResult DimLvlMapParser::parseDimSpecList() {
173  return parser.parseCommaSeparatedList(
175  [this]() -> ParseResult { return parseDimSpec(); },
176  " in dimension-specifier list");
177 }
178 
179 ParseResult DimLvlMapParser::parseDimSpec() {
180  // Parse the requisite dim-var binding.
181  const auto varID = parseVarBinding(VarKind::Dimension);
182  FAILURE_IF_FAILED(varID)
183  const DimVar var = env.getVar(*varID).cast<DimVar>();
184 
185  // Parse an optional dimension expression.
186  AffineExpr affine;
187  if (succeeded(parser.parseOptionalEqual())) {
188  // Parse the dim affine expr, with only any lvl-vars in scope.
189  FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
190  }
191  DimExpr expr{affine};
192 
193  // Parse an optional slice.
194  SparseTensorDimSliceAttr slice;
195  if (succeeded(parser.parseOptionalColon())) {
196  const auto loc = parser.getCurrentLocation();
197  Attribute attr;
198  FAILURE_IF_FAILED(parser.parseAttribute(attr))
199  slice = llvm::dyn_cast<SparseTensorDimSliceAttr>(attr);
200  ERROR_IF(!slice, "expected SparseTensorDimSliceAttr")
201  }
202 
203  dimSpecs.emplace_back(var, expr, slice);
204  return success();
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // `DimLvlMapParser` implementation for `LvlSpec`.
209 //===----------------------------------------------------------------------===//
210 
211 ParseResult DimLvlMapParser::parseLvlSpecList() {
212  // This method currently only supports two syntaxes:
213  //
214  // (1) There are no forward-declarations, and no lvl-var bindings:
215  // (d0, d1) -> (d0 : dense, d1 : compressed)
216  // Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus
217  // `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that
218  // the level-rank is correct at the end of parsing.
219  //
220  // (2) There are forward-declarations, and every lvl-spec must have
221  // a lvl-var binding:
222  // {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
223  // However, this introduces duplicate information since the order of
224  // the lvl-vars in `parseLvlVarBindingList` must agree with their order
225  // in the list of lvl-specs. Therefore, `parseLvlSpec` will not call
226  // `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so),
227  // and must also validate the consistency between the two lvl-var orders.
228  const auto declaredLvlRank = env.getRanks().getLvlRank();
229  const bool requireLvlVarBinding = declaredLvlRank != 0;
230  // Have `ERROR_IF` point to the start of the list.
231  const auto loc = parser.getCurrentLocation();
232  const auto res = parser.parseCommaSeparatedList(
234  [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
235  " in level-specifier list");
236  FAILURE_IF_FAILED(res)
237  const auto specLvlRank = lvlSpecs.size();
238  ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank,
239  "Level-rank mismatch between forward-declarations and specifiers. "
240  "Declared " +
241  Twine(declaredLvlRank) + " level-variables; but got " +
242  Twine(specLvlRank) + " level-specifiers.")
243  return success();
244 }
245 
246 static inline Twine nth(Var::Num n) {
247  switch (n) {
248  case 1:
249  return "1st";
250  case 2:
251  return "2nd";
252  default:
253  return Twine(n) + "th";
254  }
255 }
256 
258 DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
259  // Nothing to parse, just bind an unnamed variable.
260  if (!requireLvlVarBinding)
261  return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
262 
263  const auto loc = parser.getCurrentLocation();
264  // NOTE: Calling `parseVarUsage` here is semantically inappropriate,
265  // since the thing we're parsing is supposed to be a variable *binding*
266  // rather than a variable *use*. However, the call to `VarEnv::bindVar`
267  // (and its corresponding call to `DimLvlMapParser::recordVarBinding`)
268  // already occured in `parseLvlVarBindingList`, and therefore we must
269  // use `parseVarUsage` here in order to operationally do the right thing.
270  const auto varID = parseVarUsage(VarKind::Level, /*requireKnown=*/true);
271  FAILURE_IF_FAILED(varID)
272  const auto &info = std::as_const(env).access(*varID);
273  const auto var = info.getVar().cast<LvlVar>();
274  const auto forwardNum = var.getNum();
275  const auto specNum = lvlSpecs.size();
276  ERROR_IF(forwardNum != specNum,
277  "Level-variable ordering mismatch. The variable '" + info.getName() +
278  "' was forward-declared as the " + nth(forwardNum) +
279  " level; but is bound by the " + nth(specNum) +
280  " specification.")
281  FAILURE_IF_FAILED(parser.parseEqual())
282  return var;
283 }
284 
285 ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
286  // Parse the optional lvl-var binding. `requireLvlVarBinding`
287  // specifies whether that "optional" is actually Must or MustNot.
288  const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
289  FAILURE_IF_FAILED(varRes)
290  const LvlVar var = *varRes;
291 
292  // Parse the lvl affine expr, with only the dim-vars in scope.
293  AffineExpr affine;
294  FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
295  LvlExpr expr{affine};
296 
297  FAILURE_IF_FAILED(parser.parseColon())
298  const auto type = lvlTypeParser.parseLvlType(parser);
299  FAILURE_IF_FAILED(type)
300 
301  lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
302  return success();
303 }
304 
305 //===----------------------------------------------------------------------===//
static bool didntSucceed(OptionalParseResult res)
Helper function for FAILURE_IF_NULLOPT_OR_FAILED to avoid duplicating its RES parameter.
#define FAILURE_IF_FAILED(RES)
#define ERROR_IF(COND, MSG)
static Twine nth(Var::Num n)
#define FAILURE_IF_NULLOPT_OR_FAILED(RES)
Base type for affine expression.
Definition: AffineExpr.h:69
@ Paren
Parens surrounding zero or more operands.
@ OptionalBraces
{} brackets surrounding zero or more operands, or nothing.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseAffineExpr(ArrayRef< std::pair< StringRef, AffineExpr >> symbolSet, AffineExpr &expr)=0
Parse an affine expr instance into 'expr' using the already computed mapping from symbols to affine expressions in 'symbolSet'.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain together failable operations.
Parses the Sparse Tensor Encoding Attribute (STEA).
constexpr unsigned getLvlRank() const
Definition: Var.h:231
constexpr unsigned getSymRank() const
Definition: Var.h:229
Var bindUnusedVar(VarKind vk)
Creates a new variable of the given kind and immediately binds it.
Definition: Var.cpp:221
Ranks getRanks() const
Returns the current ranks of bound variables.
Definition: Var.h:390
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const
Definition: Var.cpp:229
std::optional< std::pair< VarInfo::ID, bool > > lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc, VarKind vk)
Looks up or creates a variable according to the given Policy.
Definition: Var.cpp:197
Var getVar(VarInfo::ID id) const
Gets the Var identified by the VarInfo::ID, raising an assertion failure if the variable is not bound...
Definition: Var.h:394
Var bindVar(VarInfo::ID id)
Binds the given variable to the next free Var::Num for its VarKind.
Definition: Var.cpp:222
ID
Newtype for unique identifiers of VarInfo records, to ensure they aren't confused with Var::Num.
Definition: Var.h:279
A concrete variable, to be used in our variant of AffineExpr.
Definition: Var.h:61
constexpr Num getNum() const
Definition: Var.h:125
constexpr U cast() const
Definition: Var.h:188
VarKind
The three kinds of variables that Var can be.
Definition: Var.h:32
Include the generated interface declarations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:603
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:613
This enum defines all the sparse representations supportable by the SparseTensor dialect.
Definition: Enums.h:238