MLIR  21.0.0git
IRDL.cpp
Go to the documentation of this file.
1 //===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/IR/Metadata.h"
26 #include "llvm/Support/Casting.h"
27 
28 using namespace mlir;
29 using namespace mlir::irdl;
30 
31 //===----------------------------------------------------------------------===//
32 // IRDL dialect.
33 //===----------------------------------------------------------------------===//
34 
35 #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
36 
37 #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
38 
39 void IRDLDialect::initialize() {
40  addOperations<
41 #define GET_OP_LIST
42 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
43  >();
44  addTypes<
45 #define GET_TYPEDEF_LIST
46 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
47  >();
48  addAttributes<
49 #define GET_ATTRDEF_LIST
50 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
51  >();
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // Parsing/Printing/Verifying
56 //===----------------------------------------------------------------------===//
57 
58 /// Parse a region, and add a single block if the region is empty.
59 /// If no region is parsed, create a new region with a single empty block.
60 static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
61  auto regionParseRes = p.parseOptionalRegion(region);
62  if (regionParseRes.has_value() && failed(regionParseRes.value()))
63  return failure();
64 
65  // If the region is empty, add a single empty block.
66  if (region.empty())
67  region.push_back(new Block());
68 
69  return success();
70 }
71 
73  Region &region) {
74  if (!region.getBlocks().front().empty())
75  p.printRegion(region);
76 }
77 
78 LogicalResult DialectOp::verify() {
79  if (!Dialect::isValidNamespace(getName()))
80  return emitOpError("invalid dialect name");
81  return success();
82 }
83 
84 LogicalResult OperationOp::verifyRegions() {
85  // Stores pairs of value kinds and the list of names of values of this kind in
86  // the operation.
88 
89  auto insertNames = [&](StringRef kind, ArrayAttr names) {
90  llvm::SmallDenseSet<StringRef> nameSet;
91  nameSet.reserve(names.size());
92  for (auto name : names)
93  nameSet.insert(llvm::cast<StringAttr>(name).getValue());
94  valueNames.emplace_back(kind, std::move(nameSet));
95  };
96 
97  for (Operation &op : getBody().getOps()) {
99  .Case<OperandsOp>(
100  [&](OperandsOp op) { insertNames("operands", op.getNames()); })
101  .Case<ResultsOp>(
102  [&](ResultsOp op) { insertNames("results", op.getNames()); })
103  .Case<RegionsOp>(
104  [&](RegionsOp op) { insertNames("regions", op.getNames()); });
105  }
106 
107  // Verify that no two operand, result or region share the same name.
108  // The absence of duplicates within each value kind is checked by the
109  // associated operation's verifier.
110  for (size_t i : llvm::seq(valueNames.size())) {
111  for (size_t j : llvm::seq(i + 1, valueNames.size())) {
112  auto [lhs, lhsSet] = valueNames[i];
113  auto &[rhs, rhsSet] = valueNames[j];
114  llvm::set_intersect(lhsSet, rhsSet);
115  if (!lhsSet.empty())
116  return emitOpError("contains a value named '")
117  << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs;
118  }
119  }
120 
121  return success();
122 }
123 
124 static LogicalResult verifyNames(Operation *op, StringRef kindName,
125  ArrayAttr names, size_t numOperands) {
126  if (numOperands != names.size())
127  return op->emitOpError()
128  << "the number of " << kindName
129  << "s and their names must be "
130  "the same, but got "
131  << numOperands << " and " << names.size() << " respectively";
132 
134  for (auto [i, name] : llvm::enumerate(names)) {
135  StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
136  if (nameRef.empty())
137  return op->emitOpError()
138  << "name of " << kindName << " #" << i << " is empty";
139  if (!llvm::isAlpha(nameRef[0]) && nameRef[0] != '_')
140  return op->emitOpError()
141  << "name of " << kindName << " #" << i
142  << " must start with either a letter or an underscore";
143  if (llvm::any_of(nameRef,
144  [](char c) { return !llvm::isAlnum(c) && c != '_'; }))
145  return op->emitOpError()
146  << "name of " << kindName << " #" << i
147  << " must contain only letters, digits and underscores";
148  if (nameMap.contains(nameRef))
149  return op->emitOpError() << "name of " << kindName << " #" << i
150  << " is a duplicate of the name of " << kindName
151  << " #" << nameMap[nameRef];
152  nameMap.insert({nameRef, i});
153  }
154 
155  return success();
156 }
157 
158 LogicalResult ParametersOp::verify() {
159  return verifyNames(*this, "parameter", getNames(), getNumOperands());
160 }
161 
162 template <typename ValueListOp>
163 static LogicalResult verifyOperandsResultsCommon(ValueListOp op,
164  StringRef kindName) {
165  size_t numVariadicities = op.getVariadicity().size();
166  size_t numOperands = op.getNumOperands();
167 
168  if (numOperands != numVariadicities)
169  return op.emitOpError()
170  << "the number of " << kindName
171  << "s and their variadicities must be "
172  "the same, but got "
173  << numOperands << " and " << numVariadicities << " respectively";
174 
175  return verifyNames(op, kindName, op.getNames(), numOperands);
176 }
177 
178 LogicalResult OperandsOp::verify() {
179  return verifyOperandsResultsCommon(*this, "operand");
180 }
181 
182 LogicalResult ResultsOp::verify() {
183  return verifyOperandsResultsCommon(*this, "result");
184 }
185 
186 LogicalResult AttributesOp::verify() {
187  size_t namesSize = getAttributeValueNames().size();
188  size_t valuesSize = getAttributeValues().size();
189 
190  if (namesSize != valuesSize)
191  return emitOpError()
192  << "the number of attribute names and their constraints must be "
193  "the same but got "
194  << namesSize << " and " << valuesSize << " respectively";
195 
196  return success();
197 }
198 
199 LogicalResult BaseOp::verify() {
200  std::optional<StringRef> baseName = getBaseName();
201  std::optional<SymbolRefAttr> baseRef = getBaseRef();
202  if (baseName.has_value() == baseRef.has_value())
203  return emitOpError() << "the base type or attribute should be specified by "
204  "either a name or a reference";
205 
206  if (baseName &&
207  (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
208  return emitOpError() << "the base type or attribute name should start with "
209  "'!' or '#'";
210 
211  return success();
212 }
213 
214 /// Finds whether the provided symbol is an IRDL type or attribute definition.
215 /// The source operation must be within a DialectOp.
216 static LogicalResult
218  Operation *source, SymbolRefAttr symbol) {
219  Operation *targetOp =
220  irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
221 
222  if (!targetOp)
223  return source->emitOpError() << "symbol '" << symbol << "' not found";
224 
225  if (!isa<TypeOp, AttributeOp>(targetOp))
226  return source->emitOpError() << "symbol '" << symbol
227  << "' does not refer to a type or attribute "
228  "definition (refers to '"
229  << targetOp->getName() << "')";
230 
231  return success();
232 }
233 
234 LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
235  std::optional<SymbolRefAttr> baseRef = getBaseRef();
236  if (!baseRef)
237  return success();
238 
239  return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
240 }
241 
242 LogicalResult
243 ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
244  std::optional<SymbolRefAttr> baseRef = getBaseType();
245  if (!baseRef)
246  return success();
247 
248  return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
249 }
250 
251 /// Parse a value with its variadicity first. By default, the variadicity is
252 /// single.
253 ///
254 /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
255 static ParseResult
258  VariadicityAttr &variadicityAttr) {
259  MLIRContext *ctx = p.getBuilder().getContext();
260 
261  // Parse the variadicity, if present
262  if (p.parseOptionalKeyword("single").succeeded()) {
263  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
264  } else if (p.parseOptionalKeyword("optional").succeeded()) {
265  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
266  } else if (p.parseOptionalKeyword("variadic").succeeded()) {
267  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
268  } else {
269  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
270  }
271 
272  // Parse the value
273  if (p.parseOperand(operand))
274  return failure();
275  return success();
276 }
277 
278 static ParseResult parseNamedValueListImpl(
280  ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
281  Builder &builder = p.getBuilder();
282  MLIRContext *ctx = builder.getContext();
283  SmallVector<Attribute> valueNames;
284  SmallVector<VariadicityAttr> variadicities;
285 
286  // Parse a single value with its variadicity
287  auto parseOne = [&] {
288  StringRef name;
290  VariadicityAttr variadicity;
291  if (p.parseKeyword(&name) || p.parseColon())
292  return failure();
293 
294  if (variadicityAttr) {
295  if (parseValueWithVariadicity(p, operand, variadicity))
296  return failure();
297  variadicities.push_back(variadicity);
298  } else {
299  if (p.parseOperand(operand))
300  return failure();
301  }
302 
303  valueNames.push_back(StringAttr::get(ctx, name));
304  operands.push_back(operand);
305  return success();
306  };
307 
309  return failure();
310  valueNamesAttr = ArrayAttr::get(ctx, valueNames);
311  if (variadicityAttr)
312  *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
313  return success();
314 }
315 
316 /// Parse a list of named values.
317 ///
318 /// values ::=
319 /// `(` (named-value (`,` named-value)*)? `)`
320 /// named-value := bare-id `:` ssa-value
321 static ParseResult
324  ArrayAttr &valueNamesAttr) {
325  return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr);
326 }
327 
328 /// Parse a list of named values with their variadicities first. By default, the
329 /// variadicity is single.
330 ///
331 /// values-with-variadicity ::=
332 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
333 /// value-with-variadicity
334 /// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value
337  ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
338  return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr);
339 }
340 
342  OperandRange operands,
343  ArrayAttr valueNamesAttr,
344  VariadicityArrayAttr variadicityAttr) {
345  p << "(";
346  interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
347  p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": ";
348  if (variadicityAttr) {
349  Variadicity variadicity = variadicityAttr[i].getValue();
350  if (variadicity != Variadicity::single) {
351  p << stringifyVariadicity(variadicity) << " ";
352  }
353  }
354  p << operands[i];
355  });
356  p << ")";
357 }
358 
359 /// Print a list of named values.
360 ///
361 /// values ::=
362 /// `(` (named-value (`,` named-value)*)? `)`
363 /// named-value := bare-id `:` ssa-value
365  OperandRange operands,
366  ArrayAttr valueNamesAttr) {
367  printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr);
368 }
369 
370 /// Print a list of named values with their variadicities first. By default, the
371 /// variadicity is single.
372 ///
373 /// values-with-variadicity ::=
374 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
375 /// value-with-variadicity ::=
376 /// bare-id `:` ("single" | "optional" | "variadic")? ssa-value
378  OpAsmPrinter &p, Operation *op, OperandRange operands,
379  ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
380  printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr);
381 }
382 
383 static ParseResult
386  ArrayAttr &attrNamesAttr) {
387  Builder &builder = p.getBuilder();
388  SmallVector<Attribute> attrNames;
389  if (succeeded(p.parseOptionalLBrace())) {
390  auto parseOperands = [&]() {
391  if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() ||
392  p.parseOperand(attrOperands.emplace_back()))
393  return failure();
394  return success();
395  };
396  if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
397  return failure();
398  }
399  attrNamesAttr = builder.getArrayAttr(attrNames);
400  return success();
401 }
402 
403 static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
404  OperandRange attrArgs, ArrayAttr attrNames) {
405  if (attrNames.empty())
406  return;
407  p << "{";
408  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
409  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
410  p << '}';
411 }
412 
413 LogicalResult RegionOp::verify() {
414  if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
415  if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
416  return emitOpError("the number of blocks is expected to be >= 1 but got ")
417  << number;
418  }
419  return success();
420 }
421 
422 LogicalResult RegionsOp::verify() {
423  return verifyNames(*this, "region", getNames(), getNumOperands());
424 }
425 
426 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
427 
428 #define GET_TYPEDEF_CLASSES
429 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
430 
431 #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
432 
433 #define GET_ATTRDEF_CLASSES
434 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
435 
436 #define GET_OP_CLASSES
437 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
static LogicalResult verifyNames(Operation *op, StringRef kindName, ArrayAttr names, size_t numOperands)
Definition: IRDL.cpp:124
static ParseResult parseValueWithVariadicity(OpAsmParser &p, OpAsmParser::UnresolvedOperand &operand, VariadicityAttr &variadicityAttr)
Parse a value with its variadicity first.
Definition: IRDL.cpp:256
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition: IRDL.cpp:403
static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Finds whether the provided symbol is an IRDL type or attribute definition.
Definition: IRDL.cpp:217
static LogicalResult verifyOperandsResultsCommon(ValueListOp op, StringRef kindName)
Definition: IRDL.cpp:163
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region)
Parse a region, and add a single block if the region is empty.
Definition: IRDL.cpp:60
static ParseResult parseAttributesOp(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition: IRDL.cpp:384
static void printNamedValueList(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr)
Print a list of named values.
Definition: IRDL.cpp:364
static void printNamedValueListWithVariadicity(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Print a list of named values with their variadicities first.
Definition: IRDL.cpp:377
static ParseResult parseNamedValueList(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr)
Parse a list of named values.
Definition: IRDL.cpp:322
static ParseResult parseNamedValueListImpl(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr)
Definition: IRDL.cpp:278
static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Definition: IRDL.cpp:341
static ParseResult parseNamedValueListWithVariadicity(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr)
Parse a list of named values with their variadicities first.
Definition: IRDL.cpp:335
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, Region &region)
Definition: IRDL.cpp:72
union mlir::linalg::@1195::ArityGroupAndKind::Kind kind
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
Definition: Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:98
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
Definition: IRDLSymbols.cpp:28
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
This is the representation of an operand reference.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.