MLIR  19.0.0git
IRDL.cpp
Go to the documentation of this file.
1 //===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/Diagnostics.h"
16 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/IR/Metadata.h"
24 #include "llvm/Support/Casting.h"
25 
26 using namespace mlir;
27 using namespace mlir::irdl;
28 
29 //===----------------------------------------------------------------------===//
30 // IRDL dialect.
31 //===----------------------------------------------------------------------===//
32 
33 #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
34 
35 #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
36 
/// Registers the dialect's operations, types and attributes. Each class list
/// is spliced into the template argument list by defining the matching
/// GET_*_LIST macro and including the corresponding `.cpp.inc` file.
void IRDLDialect::initialize() {
  // Operation classes listed in IRDLOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
      >();
  // Type classes listed in IRDLTypesGen.cpp.inc.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
      >();
  // Attribute classes listed in IRDLAttributes.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
      >();
}
51 
52 //===----------------------------------------------------------------------===//
53 // Parsing/Printing
54 //===----------------------------------------------------------------------===//
55 
56 /// Parse a region, and add a single block if the region is empty.
57 /// If no region is parsed, create a new region with a single empty block.
59  auto regionParseRes = p.parseOptionalRegion(region);
60  if (regionParseRes.has_value() && failed(regionParseRes.value()))
61  return failure();
62 
63  // If the region is empty, add a single empty block.
64  if (region.empty())
65  region.push_back(new Block());
66 
67  return success();
68 }
69 
71  Region &region) {
72  if (!region.getBlocks().front().empty())
73  p.printRegion(region);
74 }
75 
77  if (!Dialect::isValidNamespace(getName()))
78  return emitOpError("invalid dialect name");
79  return success();
80 }
81 
83  size_t numVariadicities = getVariadicity().size();
84  size_t numOperands = getNumOperands();
85 
86  if (numOperands != numVariadicities)
87  return emitOpError()
88  << "the number of operands and their variadicities must be "
89  "the same, but got "
90  << numOperands << " and " << numVariadicities << " respectively";
91 
92  return success();
93 }
94 
96  size_t numVariadicities = getVariadicity().size();
97  size_t numOperands = this->getNumOperands();
98 
99  if (numOperands != numVariadicities)
100  return emitOpError()
101  << "the number of operands and their variadicities must be "
102  "the same, but got "
103  << numOperands << " and " << numVariadicities << " respectively";
104 
105  return success();
106 }
107 
109  size_t namesSize = getAttributeValueNames().size();
110  size_t valuesSize = getAttributeValues().size();
111 
112  if (namesSize != valuesSize)
113  return emitOpError()
114  << "the number of attribute names and their constraints must be "
115  "the same but got "
116  << namesSize << " and " << valuesSize << " respectively";
117 
118  return success();
119 }
120 
122  std::optional<StringRef> baseName = getBaseName();
123  std::optional<SymbolRefAttr> baseRef = getBaseRef();
124  if (baseName.has_value() == baseRef.has_value())
125  return emitOpError() << "the base type or attribute should be specified by "
126  "either a name or a reference";
127 
128  if (baseName &&
129  (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
130  return emitOpError() << "the base type or attribute name should start with "
131  "'!' or '#'";
132 
133  return success();
134 }
135 
136 /// Finds whether the provided symbol is an IRDL type or attribute definition.
137 /// The source operation must be within a DialectOp.
138 static LogicalResult
140  Operation *source, SymbolRefAttr symbol) {
141  Operation *targetOp =
142  irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
143 
144  if (!targetOp)
145  return source->emitOpError() << "symbol '" << symbol << "' not found";
146 
147  if (!isa<TypeOp, AttributeOp>(targetOp))
148  return source->emitOpError() << "symbol '" << symbol
149  << "' does not refer to a type or attribute "
150  "definition (refers to '"
151  << targetOp->getName() << "')";
152 
153  return success();
154 }
155 
156 LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
157  std::optional<SymbolRefAttr> baseRef = getBaseRef();
158  if (!baseRef)
159  return success();
160 
161  return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
162 }
163 
165 ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
166  std::optional<SymbolRefAttr> baseRef = getBaseType();
167  if (!baseRef)
168  return success();
169 
170  return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
171 }
172 
173 /// Parse a value with its variadicity first. By default, the variadicity is
174 /// single.
175 ///
176 /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
177 static ParseResult
180  VariadicityAttr &variadicityAttr) {
181  MLIRContext *ctx = p.getBuilder().getContext();
182 
183  // Parse the variadicity, if present
184  if (p.parseOptionalKeyword("single").succeeded()) {
185  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
186  } else if (p.parseOptionalKeyword("optional").succeeded()) {
187  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
188  } else if (p.parseOptionalKeyword("variadic").succeeded()) {
189  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
190  } else {
191  variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
192  }
193 
194  // Parse the value
195  if (p.parseOperand(operand))
196  return failure();
197  return success();
198 }
199 
200 /// Parse a list of values with their variadicities first. By default, the
201 /// variadicity is single.
202 ///
203 /// values-with-variadicity ::=
204 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
205 /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
208  VariadicityArrayAttr &variadicityAttr) {
209  Builder &builder = p.getBuilder();
210  MLIRContext *ctx = builder.getContext();
211  SmallVector<VariadicityAttr> variadicities;
212 
213  // Parse a single value with its variadicity
214  auto parseOne = [&] {
216  VariadicityAttr variadicity;
217  if (parseValueWithVariadicity(p, operand, variadicity))
218  return failure();
219  operands.push_back(operand);
220  variadicities.push_back(variadicity);
221  return success();
222  };
223 
225  return failure();
226  variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
227  return success();
228 }
229 
230 /// Print a list of values with their variadicities first. By default, the
231 /// variadicity is single.
232 ///
233 /// values-with-variadicity ::=
234 /// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
235 /// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
237  OperandRange operands,
238  VariadicityArrayAttr variadicityAttr) {
239  p << "(";
240  interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
241  Variadicity variadicity = variadicityAttr[i].getValue();
242  if (variadicity != Variadicity::single) {
243  p << stringifyVariadicity(variadicity) << " ";
244  }
245  p << operands[i];
246  });
247  p << ")";
248 }
249 
250 static ParseResult
253  ArrayAttr &attrNamesAttr) {
254  Builder &builder = p.getBuilder();
255  SmallVector<Attribute> attrNames;
256  if (succeeded(p.parseOptionalLBrace())) {
257  auto parseOperands = [&]() {
258  if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() ||
259  p.parseOperand(attrOperands.emplace_back()))
260  return failure();
261  return success();
262  };
263  if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
264  return failure();
265  }
266  attrNamesAttr = builder.getArrayAttr(attrNames);
267  return success();
268 }
269 
270 static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
271  OperandRange attrArgs, ArrayAttr attrNames) {
272  if (attrNames.empty())
273  return;
274  p << "{";
275  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
276  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
277  p << '}';
278 }
279 
281  if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
282  if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
283  return emitOpError("the number of blocks is expected to be >= 1 but got ")
284  << number;
285  }
286  return success();
287 }
288 
289 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
290 
291 #define GET_TYPEDEF_CLASSES
292 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
293 
294 #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
295 
296 #define GET_ATTRDEF_CLASSES
297 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
298 
299 #define GET_OP_CLASSES
300 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
static ParseResult parseValueWithVariadicity(OpAsmParser &p, OpAsmParser::UnresolvedOperand &operand, VariadicityAttr &variadicityAttr)
Parse a value with its variadicity first.
Definition: IRDL.cpp:178
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition: IRDL.cpp:270
static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Finds whether the provided symbol is an IRDL type or attribute definition.
Definition: IRDL.cpp:139
static ParseResult parseValuesWithVariadicity(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, VariadicityArrayAttr &variadicityAttr)
Parse a list of values with their variadicities first.
Definition: IRDL.cpp:206
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region)
Parse a region, and add a single block if the region is empty.
Definition: IRDL.cpp:58
static ParseResult parseAttributesOp(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition: IRDL.cpp:251
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, Region &region)
Definition: IRDL.cpp:70
static void printValuesWithVariadicity(OpAsmPrinter &p, Operation *op, OperandRange operands, VariadicityArrayAttr variadicityAttr)
Print a list of values with their variadicities first.
Definition: IRDL.cpp:236
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
Definition: Block.h:31
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition: Dialect.cpp:92
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
bool empty()
Definition: Region.h:60
BlockListType & getBlocks()
Definition: Region.h:45
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
Definition: IRDLSymbols.cpp:28
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
This is the representation of an operand reference.