MLIR  21.0.0git
IRDL.cpp
Go to the documentation of this file.
1 //===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/IR/Metadata.h"
26 #include "llvm/Support/Casting.h"
27 
28 using namespace mlir;
29 using namespace mlir::irdl;
30 
31 //===----------------------------------------------------------------------===//
32 // IRDL dialect.
33 //===----------------------------------------------------------------------===//
34 
35 #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
36 
37 #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
38 
/// Registers all IRDL operations, types, and attributes with the dialect.
///
/// Each `GET_*_LIST` macro selects the part of the corresponding TableGen
/// generated `.inc` file that expands to the template argument list of the
/// surrounding `add*<...>()` call.
void IRDLDialect::initialize() {
  // Operations generated from the IRDL ODS definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
      >();
  // Types generated from the IRDL type definitions.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
      >();
  // Attributes generated from the IRDL attribute definitions.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
      >();
}
53 
54 //===----------------------------------------------------------------------===//
55 // Parsing/Printing/Verifying
56 //===----------------------------------------------------------------------===//
57 
58 /// Parse a region, and add a single block if the region is empty.
59 /// If no region is parsed, create a new region with a single empty block.
60 static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
61  auto regionParseRes = p.parseOptionalRegion(region);
62  if (regionParseRes.has_value() && failed(regionParseRes.value()))
63  return failure();
64 
65  // If the region is empty, add a single empty block.
66  if (region.empty())
67  region.push_back(new Block());
68 
69  return success();
70 }
71 
73  Region &region) {
74  if (!region.getBlocks().front().empty())
75  p.printRegion(region);
76 }
77 static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc,
78  const Twine &label) {
79  if (in.empty())
80  return loc->emitError("name of ") << label << " is empty";
81 
82  bool allowUnderscore = false;
83  for (auto &elem : in) {
84  if (elem == '_') {
85  if (!allowUnderscore)
86  return loc->emitError("name of ")
87  << label << " should not contain leading or double underscores";
88  } else {
89  if (!isalnum(elem))
90  return loc->emitError("name of ")
91  << label
92  << " must contain only lowercase letters, digits and "
93  "underscores";
94 
95  if (llvm::isUpper(elem))
96  return loc->emitError("name of ")
97  << label << " should not contain uppercase letters";
98  }
99 
100  allowUnderscore = elem != '_';
101  }
102 
103  return success();
104 }
105 
106 LogicalResult DialectOp::verify() {
107  if (!Dialect::isValidNamespace(getName()))
108  return emitOpError("invalid dialect name");
109  if (failed(isValidName(getSymName(), getOperation(), "dialect")))
110  return failure();
111 
112  return success();
113 }
114 
115 LogicalResult OperationOp::verify() {
116  return isValidName(getSymName(), getOperation(), "operation");
117 }
118 
119 LogicalResult TypeOp::verify() {
120  auto symName = getSymName();
121  if (symName.front() == '!')
122  symName = symName.substr(1);
123  return isValidName(symName, getOperation(), "type");
124 }
125 
126 LogicalResult AttributeOp::verify() {
127  auto symName = getSymName();
128  if (symName.front() == '#')
129  symName = symName.substr(1);
130  return isValidName(symName, getOperation(), "attribute");
131 }
132 
133 LogicalResult OperationOp::verifyRegions() {
134  // Stores pairs of value kinds and the list of names of values of this kind in
135  // the operation.
137 
138  auto insertNames = [&](StringRef kind, ArrayAttr names) {
139  llvm::SmallDenseSet<StringRef> nameSet;
140  nameSet.reserve(names.size());
141  for (auto name : names)
142  nameSet.insert(llvm::cast<StringAttr>(name).getValue());
143  valueNames.emplace_back(kind, std::move(nameSet));
144  };
145 
146  for (Operation &op : getBody().getOps()) {
148  .Case<OperandsOp>(
149  [&](OperandsOp op) { insertNames("operands", op.getNames()); })
150  .Case<ResultsOp>(
151  [&](ResultsOp op) { insertNames("results", op.getNames()); })
152  .Case<RegionsOp>(
153  [&](RegionsOp op) { insertNames("regions", op.getNames()); });
154  }
155 
156  // Verify that no two operand, result or region share the same name.
157  // The absence of duplicates within each value kind is checked by the
158  // associated operation's verifier.
159  for (size_t i : llvm::seq(valueNames.size())) {
160  for (size_t j : llvm::seq(i + 1, valueNames.size())) {
161  auto [lhs, lhsSet] = valueNames[i];
162  auto &[rhs, rhsSet] = valueNames[j];
163  llvm::set_intersect(lhsSet, rhsSet);
164  if (!lhsSet.empty())
165  return emitOpError("contains a value named '")
166  << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs;
167  }
168  }
169 
170  return success();
171 }
172 
173 static LogicalResult verifyNames(Operation *op, StringRef kindName,
174  ArrayAttr names, size_t numOperands) {
175  if (numOperands != names.size())
176  return op->emitOpError()
177  << "the number of " << kindName
178  << "s and their names must be "
179  "the same, but got "
180  << numOperands << " and " << names.size() << " respectively";
181 
183  for (auto [i, name] : llvm::enumerate(names)) {
184  StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
185 
186  if (failed(isValidName(nameRef, op, Twine(kindName) + " #" + Twine(i))))
187  return failure();
188 
189  if (nameMap.contains(nameRef))
190  return op->emitOpError() << "name of " << kindName << " #" << i
191  << " is a duplicate of the name of " << kindName
192  << " #" << nameMap[nameRef];
193  nameMap.insert({nameRef, i});
194  }
195 
196  return success();
197 }
198 
199 LogicalResult ParametersOp::verify() {
200  return verifyNames(*this, "parameter", getNames(), getNumOperands());
201 }
202 
203 template <typename ValueListOp>
204 static LogicalResult verifyOperandsResultsCommon(ValueListOp op,
205  StringRef kindName) {
206  size_t numVariadicities = op.getVariadicity().size();
207  size_t numOperands = op.getNumOperands();
208 
209  if (numOperands != numVariadicities)
210  return op.emitOpError()
211  << "the number of " << kindName
212  << "s and their variadicities must be "
213  "the same, but got "
214  << numOperands << " and " << numVariadicities << " respectively";
215 
216  return verifyNames(op, kindName, op.getNames(), numOperands);
217 }
218 
219 LogicalResult OperandsOp::verify() {
220  return verifyOperandsResultsCommon(*this, "operand");
221 }
222 
223 LogicalResult ResultsOp::verify() {
224  return verifyOperandsResultsCommon(*this, "result");
225 }
226 
227 LogicalResult AttributesOp::verify() {
228  size_t namesSize = getAttributeValueNames().size();
229  size_t valuesSize = getAttributeValues().size();
230 
231  if (namesSize != valuesSize)
232  return emitOpError()
233  << "the number of attribute names and their constraints must be "
234  "the same but got "
235  << namesSize << " and " << valuesSize << " respectively";
236 
237  return success();
238 }
239 
240 LogicalResult BaseOp::verify() {
241  std::optional<StringRef> baseName = getBaseName();
242  std::optional<SymbolRefAttr> baseRef = getBaseRef();
243  if (baseName.has_value() == baseRef.has_value())
244  return emitOpError() << "the base type or attribute should be specified by "
245  "either a name or a reference";
246 
247  if (baseName &&
248  (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
249  return emitOpError() << "the base type or attribute name should start with "
250  "'!' or '#'";
251 
252  return success();
253 }
254 
255 /// Finds whether the provided symbol is an IRDL type or attribute definition.
256 /// The source operation must be within a DialectOp.
257 static LogicalResult
259  Operation *source, SymbolRefAttr symbol) {
260  Operation *targetOp =
261  irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
262 
263  if (!targetOp)
264  return source->emitOpError() << "symbol '" << symbol << "' not found";
265 
266  if (!isa<TypeOp, AttributeOp>(targetOp))
267  return source->emitOpError() << "symbol '" << symbol
268  << "' does not refer to a type or attribute "
269  "definition (refers to '"
270  << targetOp->getName() << "')";
271 
272  return success();
273 }
274 
275 LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
276  std::optional<SymbolRefAttr> baseRef = getBaseRef();
277  if (!baseRef)
278  return success();
279 
280  return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
281 }
282 
283 LogicalResult
284 ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
285  std::optional<SymbolRefAttr> baseRef = getBaseType();
286  if (!baseRef)
287  return success();
288 
289  return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
290 }
291 
292 /// Parse a value with its variadicity first. By default, the variadicity is
293 /// single.
294 ///
295 /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
296 static ParseResult
299  VariadicityAttr &variadicityAttr) {
300  MLIRContext *ctx = p.getBuilder().getContext();
301 
302  // Parse the variadicity, if present
303  if (p.parseOptionalKeyword("single").succeeded()) {
304  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
305  } else if (p.parseOptionalKeyword("optional").succeeded()) {
306  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
307  } else if (p.parseOptionalKeyword("variadic").succeeded()) {
308  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
309  } else {
310  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
311  }
312 
313  // Parse the value
314  if (p.parseOperand(operand))
315  return failure();
316  return success();
317 }
318 
319 static ParseResult parseNamedValueListImpl(
321  ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
322  Builder &builder = p.getBuilder();
323  MLIRContext *ctx = builder.getContext();
324  SmallVector<Attribute> valueNames;
325  SmallVector<VariadicityAttr> variadicities;
326 
327  // Parse a single value with its variadicity
328  auto parseOne = [&] {
329  StringRef name;
331  VariadicityAttr variadicity;
332  if (p.parseKeyword(&name) || p.parseColon())
333  return failure();
334 
335  if (variadicityAttr) {
336  if (parseValueWithVariadicity(p, operand, variadicity))
337  return failure();
338  variadicities.push_back(variadicity);
339  } else {
340  if (p.parseOperand(operand))
341  return failure();
342  }
343 
344  valueNames.push_back(StringAttr::get(ctx, name));
345  operands.push_back(operand);
346  return success();
347  };
348 
350  return failure();
351  valueNamesAttr = ArrayAttr::get(ctx, valueNames);
352  if (variadicityAttr)
353  *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
354  return success();
355 }
356 
357 /// Parse a list of named values.
358 ///
359 /// values ::=
360 /// `(` (named-value (`,` named-value)*)? `)`
361 /// named-value := bare-id `:` ssa-value
362 static ParseResult
365  ArrayAttr &valueNamesAttr) {
366  return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr);
367 }
368 
369 /// Parse a list of named values with their variadicities first. By default, the
370 /// variadicity is single.
371 ///
372 /// values-with-variadicity ::=
373 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
374 /// value-with-variadicity
375 /// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value
378  ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
379  return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr);
380 }
381 
383  OperandRange operands,
384  ArrayAttr valueNamesAttr,
385  VariadicityArrayAttr variadicityAttr) {
386  p << "(";
387  interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
388  p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": ";
389  if (variadicityAttr) {
390  Variadicity variadicity = variadicityAttr[i].getValue();
391  if (variadicity != Variadicity::single) {
392  p << stringifyVariadicity(variadicity) << " ";
393  }
394  }
395  p << operands[i];
396  });
397  p << ")";
398 }
399 
400 /// Print a list of named values.
401 ///
402 /// values ::=
403 /// `(` (named-value (`,` named-value)*)? `)`
404 /// named-value := bare-id `:` ssa-value
406  OperandRange operands,
407  ArrayAttr valueNamesAttr) {
408  printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr);
409 }
410 
411 /// Print a list of named values with their variadicities first. By default, the
412 /// variadicity is single.
413 ///
414 /// values-with-variadicity ::=
415 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
416 /// value-with-variadicity ::=
417 /// bare-id `:` ("single" | "optional" | "variadic")? ssa-value
419  OpAsmPrinter &p, Operation *op, OperandRange operands,
420  ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
421  printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr);
422 }
423 
424 static ParseResult
427  ArrayAttr &attrNamesAttr) {
428  Builder &builder = p.getBuilder();
429  SmallVector<Attribute> attrNames;
430  if (succeeded(p.parseOptionalLBrace())) {
431  auto parseOperands = [&]() {
432  if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() ||
433  p.parseOperand(attrOperands.emplace_back()))
434  return failure();
435  return success();
436  };
437  if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
438  return failure();
439  }
440  attrNamesAttr = builder.getArrayAttr(attrNames);
441  return success();
442 }
443 
444 static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
445  OperandRange attrArgs, ArrayAttr attrNames) {
446  if (attrNames.empty())
447  return;
448  p << "{";
449  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
450  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
451  p << '}';
452 }
453 
454 LogicalResult RegionOp::verify() {
455  if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
456  if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
457  return emitOpError("the number of blocks is expected to be >= 1 but got ")
458  << number;
459  }
460  return success();
461 }
462 
463 LogicalResult RegionsOp::verify() {
464  return verifyNames(*this, "region", getNames(), getNumOperands());
465 }
466 
467 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
468 
469 #define GET_TYPEDEF_CLASSES
470 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
471 
472 #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
473 
474 #define GET_ATTRDEF_CLASSES
475 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
476 
477 #define GET_OP_CLASSES
478 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
static LogicalResult verifyNames(Operation *op, StringRef kindName, ArrayAttr names, size_t numOperands)
Definition: IRDL.cpp:173
static ParseResult parseValueWithVariadicity(OpAsmParser &p, OpAsmParser::UnresolvedOperand &operand, VariadicityAttr &variadicityAttr)
Parse a value with its variadicity first.
Definition: IRDL.cpp:297
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition: IRDL.cpp:444
static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc, const Twine &label)
Definition: IRDL.cpp:77
static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Finds whether the provided symbol is an IRDL type or attribute definition.
Definition: IRDL.cpp:258
static LogicalResult verifyOperandsResultsCommon(ValueListOp op, StringRef kindName)
Definition: IRDL.cpp:204
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region)
Parse a region, and add a single block if the region is empty.
Definition: IRDL.cpp:60
static ParseResult parseAttributesOp(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition: IRDL.cpp:425
static void printNamedValueList(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr)
Print a list of named values.
Definition: IRDL.cpp:405
static void printNamedValueListWithVariadicity(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Print a list of named values with their variadicities first.
Definition: IRDL.cpp:418
static ParseResult parseNamedValueList(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr)
Parse a list of named values.
Definition: IRDL.cpp:363
static ParseResult parseNamedValueListImpl(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr)
Definition: IRDL.cpp:319
static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Definition: IRDL.cpp:382
static ParseResult parseNamedValueListWithVariadicity(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr)
Parse a list of named values with their variadicities first.
Definition: IRDL.cpp:376
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, Region &region)
Definition: IRDL.cpp:72
union mlir::linalg::@1205::ArityGroupAndKind::Kind kind
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
Definition: Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:98
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
Definition: IRDLSymbols.cpp:28
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.