MLIR  22.0.0git
IRDL.cpp
Go to the documentation of this file.
1 //===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/Casting.h"
25 
26 using namespace mlir;
27 using namespace mlir::irdl;
28 
29 //===----------------------------------------------------------------------===//
30 // IRDL dialect.
31 //===----------------------------------------------------------------------===//
32 
33 #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
34 
35 #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
36 
37 void IRDLDialect::initialize() {
38  addOperations<
39 #define GET_OP_LIST
40 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
41  >();
42  addTypes<
43 #define GET_TYPEDEF_LIST
44 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
45  >();
46  addAttributes<
47 #define GET_ATTRDEF_LIST
48 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
49  >();
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // Parsing/Printing/Verifying
54 //===----------------------------------------------------------------------===//
55 
56 /// Parse a region, and add a single block if the region is empty.
57 /// If no region is parsed, create a new region with a single empty block.
58 static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
59  auto regionParseRes = p.parseOptionalRegion(region);
60  if (regionParseRes.has_value() && failed(regionParseRes.value()))
61  return failure();
62 
63  // If the region is empty, add a single empty block.
64  if (region.empty())
65  region.push_back(new Block());
66 
67  return success();
68 }
69 
71  Region &region) {
72  if (!region.getBlocks().front().empty())
73  p.printRegion(region);
74 }
75 static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc,
76  const Twine &label) {
77  if (in.empty())
78  return loc->emitError("name of ") << label << " is empty";
79 
80  bool allowUnderscore = false;
81  for (auto &elem : in) {
82  if (elem == '_') {
83  if (!allowUnderscore)
84  return loc->emitError("name of ")
85  << label << " should not contain leading or double underscores";
86  } else {
87  if (!isalnum(elem))
88  return loc->emitError("name of ")
89  << label
90  << " must contain only lowercase letters, digits and "
91  "underscores";
92 
93  if (llvm::isUpper(elem))
94  return loc->emitError("name of ")
95  << label << " should not contain uppercase letters";
96  }
97 
98  allowUnderscore = elem != '_';
99  }
100 
101  return success();
102 }
103 
104 LogicalResult DialectOp::verify() {
105  if (!Dialect::isValidNamespace(getName()))
106  return emitOpError("invalid dialect name");
107  if (failed(isValidName(getSymName(), getOperation(), "dialect")))
108  return failure();
109 
110  return success();
111 }
112 
113 LogicalResult OperationOp::verify() {
114  return isValidName(getSymName(), getOperation(), "operation");
115 }
116 
117 LogicalResult TypeOp::verify() {
118  auto symName = getSymName();
119  if (symName.front() == '!')
120  symName = symName.substr(1);
121  return isValidName(symName, getOperation(), "type");
122 }
123 
124 LogicalResult AttributeOp::verify() {
125  auto symName = getSymName();
126  if (symName.front() == '#')
127  symName = symName.substr(1);
128  return isValidName(symName, getOperation(), "attribute");
129 }
130 
131 LogicalResult OperationOp::verifyRegions() {
132  // Stores pairs of value kinds and the list of names of values of this kind in
133  // the operation.
135 
136  auto insertNames = [&](StringRef kind, ArrayAttr names) {
137  llvm::SmallDenseSet<StringRef> nameSet;
138  nameSet.reserve(names.size());
139  for (auto name : names)
140  nameSet.insert(llvm::cast<StringAttr>(name).getValue());
141  valueNames.emplace_back(kind, std::move(nameSet));
142  };
143 
144  for (Operation &op : getBody().getOps()) {
146  .Case<OperandsOp>(
147  [&](OperandsOp op) { insertNames("operands", op.getNames()); })
148  .Case<ResultsOp>(
149  [&](ResultsOp op) { insertNames("results", op.getNames()); })
150  .Case<RegionsOp>(
151  [&](RegionsOp op) { insertNames("regions", op.getNames()); });
152  }
153 
154  // Verify that no two operand, result or region share the same name.
155  // The absence of duplicates within each value kind is checked by the
156  // associated operation's verifier.
157  for (size_t i : llvm::seq(valueNames.size())) {
158  for (size_t j : llvm::seq(i + 1, valueNames.size())) {
159  auto [lhs, lhsSet] = valueNames[i];
160  auto &[rhs, rhsSet] = valueNames[j];
161  llvm::set_intersect(lhsSet, rhsSet);
162  if (!lhsSet.empty())
163  return emitOpError("contains a value named '")
164  << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs;
165  }
166  }
167 
168  return success();
169 }
170 
171 static LogicalResult verifyNames(Operation *op, StringRef kindName,
172  ArrayAttr names, size_t numOperands) {
173  if (numOperands != names.size())
174  return op->emitOpError()
175  << "the number of " << kindName
176  << "s and their names must be "
177  "the same, but got "
178  << numOperands << " and " << names.size() << " respectively";
179 
181  for (auto [i, name] : llvm::enumerate(names)) {
182  StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
183 
184  if (failed(isValidName(nameRef, op, Twine(kindName) + " #" + Twine(i))))
185  return failure();
186 
187  if (nameMap.contains(nameRef))
188  return op->emitOpError() << "name of " << kindName << " #" << i
189  << " is a duplicate of the name of " << kindName
190  << " #" << nameMap[nameRef];
191  nameMap.insert({nameRef, i});
192  }
193 
194  return success();
195 }
196 
197 LogicalResult ParametersOp::verify() {
198  return verifyNames(*this, "parameter", getNames(), getNumOperands());
199 }
200 
201 template <typename ValueListOp>
202 static LogicalResult verifyOperandsResultsCommon(ValueListOp op,
203  StringRef kindName) {
204  size_t numVariadicities = op.getVariadicity().size();
205  size_t numOperands = op.getNumOperands();
206 
207  if (numOperands != numVariadicities)
208  return op.emitOpError()
209  << "the number of " << kindName
210  << "s and their variadicities must be "
211  "the same, but got "
212  << numOperands << " and " << numVariadicities << " respectively";
213 
214  return verifyNames(op, kindName, op.getNames(), numOperands);
215 }
216 
217 LogicalResult OperandsOp::verify() {
218  return verifyOperandsResultsCommon(*this, "operand");
219 }
220 
221 LogicalResult ResultsOp::verify() {
222  return verifyOperandsResultsCommon(*this, "result");
223 }
224 
225 LogicalResult AttributesOp::verify() {
226  size_t namesSize = getAttributeValueNames().size();
227  size_t valuesSize = getAttributeValues().size();
228 
229  if (namesSize != valuesSize)
230  return emitOpError()
231  << "the number of attribute names and their constraints must be "
232  "the same but got "
233  << namesSize << " and " << valuesSize << " respectively";
234 
235  return success();
236 }
237 
238 LogicalResult BaseOp::verify() {
239  std::optional<StringRef> baseName = getBaseName();
240  std::optional<SymbolRefAttr> baseRef = getBaseRef();
241  if (baseName.has_value() == baseRef.has_value())
242  return emitOpError() << "the base type or attribute should be specified by "
243  "either a name or a reference";
244 
245  if (baseName &&
246  (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
247  return emitOpError() << "the base type or attribute name should start with "
248  "'!' or '#'";
249 
250  return success();
251 }
252 
253 /// Finds whether the provided symbol is an IRDL type or attribute definition.
254 /// The source operation must be within a DialectOp.
255 static LogicalResult
257  Operation *source, SymbolRefAttr symbol) {
258  Operation *targetOp =
259  irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
260 
261  if (!targetOp)
262  return source->emitOpError() << "symbol '" << symbol << "' not found";
263 
264  if (!isa<TypeOp, AttributeOp>(targetOp))
265  return source->emitOpError() << "symbol '" << symbol
266  << "' does not refer to a type or attribute "
267  "definition (refers to '"
268  << targetOp->getName() << "')";
269 
270  return success();
271 }
272 
273 LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
274  std::optional<SymbolRefAttr> baseRef = getBaseRef();
275  if (!baseRef)
276  return success();
277 
278  return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
279 }
280 
281 LogicalResult
282 ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
283  std::optional<SymbolRefAttr> baseRef = getBaseType();
284  if (!baseRef)
285  return success();
286 
287  return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
288 }
289 
290 /// Parse a value with its variadicity first. By default, the variadicity is
291 /// single.
292 ///
293 /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
294 static ParseResult
297  VariadicityAttr &variadicityAttr) {
298  MLIRContext *ctx = p.getBuilder().getContext();
299 
300  // Parse the variadicity, if present
301  if (p.parseOptionalKeyword("single").succeeded()) {
302  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
303  } else if (p.parseOptionalKeyword("optional").succeeded()) {
304  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
305  } else if (p.parseOptionalKeyword("variadic").succeeded()) {
306  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
307  } else {
308  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
309  }
310 
311  // Parse the value
312  if (p.parseOperand(operand))
313  return failure();
314  return success();
315 }
316 
317 static ParseResult parseNamedValueListImpl(
319  ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
320  Builder &builder = p.getBuilder();
321  MLIRContext *ctx = builder.getContext();
322  SmallVector<Attribute> valueNames;
323  SmallVector<VariadicityAttr> variadicities;
324 
325  // Parse a single value with its variadicity
326  auto parseOne = [&] {
327  StringRef name;
329  VariadicityAttr variadicity;
330  if (p.parseKeyword(&name) || p.parseColon())
331  return failure();
332 
333  if (variadicityAttr) {
334  if (parseValueWithVariadicity(p, operand, variadicity))
335  return failure();
336  variadicities.push_back(variadicity);
337  } else {
338  if (p.parseOperand(operand))
339  return failure();
340  }
341 
342  valueNames.push_back(StringAttr::get(ctx, name));
343  operands.push_back(operand);
344  return success();
345  };
346 
348  return failure();
349  valueNamesAttr = ArrayAttr::get(ctx, valueNames);
350  if (variadicityAttr)
351  *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
352  return success();
353 }
354 
355 /// Parse a list of named values.
356 ///
357 /// values ::=
358 /// `(` (named-value (`,` named-value)*)? `)`
359 /// named-value := bare-id `:` ssa-value
360 static ParseResult
363  ArrayAttr &valueNamesAttr) {
364  return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr);
365 }
366 
367 /// Parse a list of named values with their variadicities first. By default, the
368 /// variadicity is single.
369 ///
370 /// values-with-variadicity ::=
371 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
372 /// value-with-variadicity
373 /// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value
376  ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
377  return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr);
378 }
379 
381  OperandRange operands,
382  ArrayAttr valueNamesAttr,
383  VariadicityArrayAttr variadicityAttr) {
384  p << "(";
385  interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
386  p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": ";
387  if (variadicityAttr) {
388  Variadicity variadicity = variadicityAttr[i].getValue();
389  if (variadicity != Variadicity::single) {
390  p << stringifyVariadicity(variadicity) << " ";
391  }
392  }
393  p << operands[i];
394  });
395  p << ")";
396 }
397 
398 /// Print a list of named values.
399 ///
400 /// values ::=
401 /// `(` (named-value (`,` named-value)*)? `)`
402 /// named-value := bare-id `:` ssa-value
404  OperandRange operands,
405  ArrayAttr valueNamesAttr) {
406  printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr);
407 }
408 
409 /// Print a list of named values with their variadicities first. By default, the
410 /// variadicity is single.
411 ///
412 /// values-with-variadicity ::=
413 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
414 /// value-with-variadicity ::=
415 /// bare-id `:` ("single" | "optional" | "variadic")? ssa-value
417  OpAsmPrinter &p, Operation *op, OperandRange operands,
418  ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
419  printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr);
420 }
421 
422 static ParseResult
425  ArrayAttr &attrNamesAttr) {
426  Builder &builder = p.getBuilder();
427  SmallVector<Attribute> attrNames;
428  if (succeeded(p.parseOptionalLBrace())) {
429  auto parseOperands = [&]() {
430  if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() ||
431  p.parseOperand(attrOperands.emplace_back()))
432  return failure();
433  return success();
434  };
435  if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
436  return failure();
437  }
438  attrNamesAttr = builder.getArrayAttr(attrNames);
439  return success();
440 }
441 
442 static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
443  OperandRange attrArgs, ArrayAttr attrNames) {
444  if (attrNames.empty())
445  return;
446  p << "{";
447  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
448  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
449  p << '}';
450 }
451 
452 LogicalResult RegionOp::verify() {
453  if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
454  if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
455  return emitOpError("the number of blocks is expected to be >= 1 but got ")
456  << number;
457  }
458  return success();
459 }
460 
461 LogicalResult RegionsOp::verify() {
462  return verifyNames(*this, "region", getNames(), getNumOperands());
463 }
464 
465 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
466 
467 #define GET_TYPEDEF_CLASSES
468 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
469 
470 #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
471 
472 #define GET_ATTRDEF_CLASSES
473 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
474 
475 #define GET_OP_CLASSES
476 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
static LogicalResult verifyNames(Operation *op, StringRef kindName, ArrayAttr names, size_t numOperands)
Definition: IRDL.cpp:171
static ParseResult parseValueWithVariadicity(OpAsmParser &p, OpAsmParser::UnresolvedOperand &operand, VariadicityAttr &variadicityAttr)
Parse a value with its variadicity first.
Definition: IRDL.cpp:295
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition: IRDL.cpp:442
static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc, const Twine &label)
Definition: IRDL.cpp:75
static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Finds whether the provided symbol is an IRDL type or attribute definition.
Definition: IRDL.cpp:256
static LogicalResult verifyOperandsResultsCommon(ValueListOp op, StringRef kindName)
Definition: IRDL.cpp:202
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region)
Parse a region, and add a single block if the region is empty.
Definition: IRDL.cpp:58
static ParseResult parseAttributesOp(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition: IRDL.cpp:423
static void printNamedValueList(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr)
Print a list of named values.
Definition: IRDL.cpp:403
static void printNamedValueListWithVariadicity(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Print a list of named values with their variadicities first.
Definition: IRDL.cpp:416
static ParseResult parseNamedValueList(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr)
Parse a list of named values.
Definition: IRDL.cpp:361
static ParseResult parseNamedValueListImpl(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr)
Definition: IRDL.cpp:317
static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Definition: IRDL.cpp:380
static ParseResult parseNamedValueListWithVariadicity(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr)
Parse a list of named values with their variadicities first.
Definition: IRDL.cpp:374
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, Region &region)
Definition: IRDL.cpp:70
union mlir::linalg::@1244::ArityGroupAndKind::Kind kind
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
Definition: Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:95
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
Definition: IRDLSymbols.cpp:28
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.