MLIR 23.0.0git
IRDL.cpp
Go to the documentation of this file.
1//===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Builders.h"
13#include "mlir/IR/Diagnostics.h"
18#include "mlir/IR/Operation.h"
19#include "mlir/Support/LLVM.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SetOperations.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/Casting.h"
25
26using namespace mlir;
27using namespace mlir::irdl;
28
29//===----------------------------------------------------------------------===//
30// IRDL dialect.
31//===----------------------------------------------------------------------===//
32
33#include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
34
35#include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
36
37void IRDLDialect::initialize() {
38 addOperations<
39#define GET_OP_LIST
40#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
41 >();
42 addTypes<
43#define GET_TYPEDEF_LIST
44#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
45 >();
46 addAttributes<
47#define GET_ATTRDEF_LIST
48#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
49 >();
50}
51
52//===----------------------------------------------------------------------===//
53// Parsing/Printing/Verifying
54//===----------------------------------------------------------------------===//
55
56/// Parse a region, and add a single block if the region is empty.
57/// If no region is parsed, create a new region with a single empty block.
58static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
59 auto regionParseRes = p.parseOptionalRegion(region);
60 if (regionParseRes.has_value() && failed(regionParseRes.value()))
61 return failure();
62
63 // If the region is empty, add a single empty block.
64 if (region.empty())
65 region.push_back(new Block());
66
67 return success();
68}
69
71 Region &region) {
72 if (!region.getBlocks().front().empty())
73 p.printRegion(region);
74}
75static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc,
76 const Twine &label) {
77 if (in.empty())
78 return loc->emitError("name of ") << label << " is empty";
79
80 bool allowUnderscore = false;
81 for (auto &elem : in) {
82 if (elem == '_') {
83 if (!allowUnderscore)
84 return loc->emitError("name of ")
85 << label << " should not contain leading or double underscores";
86 } else {
87 if (!isalnum(elem))
88 return loc->emitError("name of ")
89 << label
90 << " must contain only lowercase letters, digits and "
91 "underscores";
92
93 if (llvm::isUpper(elem))
94 return loc->emitError("name of ")
95 << label << " should not contain uppercase letters";
96 }
97
98 allowUnderscore = elem != '_';
99 }
100
101 return success();
102}
103
104LogicalResult DialectOp::verify() {
105 if (!Dialect::isValidNamespace(getName()))
106 return emitOpError("invalid dialect name");
107 if (failed(isValidName(getSymName(), getOperation(), "dialect")))
108 return failure();
109
110 return success();
111}
112
113LogicalResult OperationOp::verify() {
114 return isValidName(getSymName(), getOperation(), "operation");
115}
116
117LogicalResult TypeOp::verify() {
118 auto symName = getSymName();
119 if (symName.front() == '!')
120 symName = symName.substr(1);
121 return isValidName(symName, getOperation(), "type");
122}
123
124LogicalResult AttributeOp::verify() {
125 auto symName = getSymName();
126 if (symName.front() == '#')
127 symName = symName.substr(1);
128 return isValidName(symName, getOperation(), "attribute");
129}
130
131LogicalResult OperationOp::verifyRegions() {
132 // Stores pairs of value kinds and the list of names of values of this kind in
133 // the operation.
135
136 auto insertNames = [&](StringRef kind, ArrayAttr names) {
137 llvm::SmallDenseSet<StringRef> nameSet;
138 nameSet.reserve(names.size());
139 for (auto name : names)
140 nameSet.insert(llvm::cast<StringAttr>(name).getValue());
141 valueNames.emplace_back(kind, std::move(nameSet));
142 };
143
144 for (Operation &op : getBody().getOps()) {
146 .Case([&](OperandsOp op) { insertNames("operands", op.getNames()); })
147 .Case([&](ResultsOp op) { insertNames("results", op.getNames()); })
148 .Case([&](RegionsOp op) { insertNames("regions", op.getNames()); });
149 }
150
151 // Verify that no two operand, result or region share the same name.
152 // The absence of duplicates within each value kind is checked by the
153 // associated operation's verifier.
154 for (size_t i : llvm::seq(valueNames.size())) {
155 for (size_t j : llvm::seq(i + 1, valueNames.size())) {
156 auto [lhs, lhsSet] = valueNames[i];
157 auto &[rhs, rhsSet] = valueNames[j];
158 llvm::set_intersect(lhsSet, rhsSet);
159 if (!lhsSet.empty())
160 return emitOpError("contains a value named '")
161 << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs;
162 }
163 }
164
165 return success();
166}
167
168static LogicalResult verifyNames(Operation *op, StringRef kindName,
169 ArrayAttr names, size_t numOperands) {
170 if (numOperands != names.size())
171 return op->emitOpError()
172 << "the number of " << kindName
173 << "s and their names must be "
174 "the same, but got "
175 << numOperands << " and " << names.size() << " respectively";
176
178 for (auto [i, name] : llvm::enumerate(names)) {
179 StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
180
181 if (failed(isValidName(nameRef, op, Twine(kindName) + " #" + Twine(i))))
182 return failure();
183
184 if (nameMap.contains(nameRef))
185 return op->emitOpError() << "name of " << kindName << " #" << i
186 << " is a duplicate of the name of " << kindName
187 << " #" << nameMap[nameRef];
188 nameMap.insert({nameRef, i});
189 }
190
191 return success();
192}
193
194LogicalResult ParametersOp::verify() {
195 return verifyNames(*this, "parameter", getNames(), getNumOperands());
196}
197
198template <typename ValueListOp>
199static LogicalResult verifyOperandsResultsCommon(ValueListOp op,
200 StringRef kindName) {
201 size_t numVariadicities = op.getVariadicity().size();
202 size_t numOperands = op.getNumOperands();
203
204 if (numOperands != numVariadicities)
205 return op.emitOpError()
206 << "the number of " << kindName
207 << "s and their variadicities must be "
208 "the same, but got "
209 << numOperands << " and " << numVariadicities << " respectively";
210
211 return verifyNames(op, kindName, op.getNames(), numOperands);
212}
213
214LogicalResult OperandsOp::verify() {
215 return verifyOperandsResultsCommon(*this, "operand");
216}
217
218LogicalResult ResultsOp::verify() {
219 return verifyOperandsResultsCommon(*this, "result");
220}
221
222LogicalResult AttributesOp::verify() {
223 size_t namesSize = getAttributeValueNames().size();
224 size_t valuesSize = getAttributeValues().size();
225
226 if (namesSize != valuesSize)
227 return emitOpError()
228 << "the number of attribute names and their constraints must be "
229 "the same but got "
230 << namesSize << " and " << valuesSize << " respectively";
231
232 return success();
233}
234
235LogicalResult BaseOp::verify() {
236 std::optional<StringRef> baseName = getBaseName();
237 std::optional<SymbolRefAttr> baseRef = getBaseRef();
238 if (baseName.has_value() == baseRef.has_value())
239 return emitOpError() << "the base type or attribute should be specified by "
240 "either a name or a reference";
241
242 if (baseName &&
243 (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
244 return emitOpError() << "the base type or attribute name should start with "
245 "'!' or '#'";
246
247 return success();
248}
249
250/// Finds whether the provided symbol is an IRDL type or attribute definition.
251/// The source operation must be within a DialectOp.
252static LogicalResult
254 Operation *source, SymbolRefAttr symbol) {
255 Operation *targetOp =
256 irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
257
258 if (!targetOp)
259 return source->emitOpError() << "symbol '" << symbol << "' not found";
260
261 if (!isa<TypeOp, AttributeOp>(targetOp))
262 return source->emitOpError() << "symbol '" << symbol
263 << "' does not refer to a type or attribute "
264 "definition (refers to '"
265 << targetOp->getName() << "')";
266
267 return success();
268}
269
270LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
271 std::optional<SymbolRefAttr> baseRef = getBaseRef();
272 if (!baseRef)
273 return success();
274
275 return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
276}
277
278LogicalResult
279ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
280 std::optional<SymbolRefAttr> baseRef = getBaseType();
281 if (!baseRef)
282 return success();
283
284 return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
285}
286
287/// Parse a value with its variadicity first. By default, the variadicity is
288/// single.
289///
290/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
291static ParseResult
294 VariadicityAttr &variadicityAttr) {
295 MLIRContext *ctx = p.getBuilder().getContext();
296
297 // Parse the variadicity, if present
298 if (p.parseOptionalKeyword("single").succeeded()) {
299 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
300 } else if (p.parseOptionalKeyword("optional").succeeded()) {
301 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
302 } else if (p.parseOptionalKeyword("variadic").succeeded()) {
303 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
304 } else {
305 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
306 }
307
308 // Parse the value
309 if (p.parseOperand(operand))
310 return failure();
311 return success();
312}
313
314static ParseResult parseNamedValueListImpl(
316 ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
317 Builder &builder = p.getBuilder();
318 MLIRContext *ctx = builder.getContext();
319 SmallVector<Attribute> valueNames;
320 SmallVector<VariadicityAttr> variadicities;
321
322 // Parse a single value with its variadicity
323 auto parseOne = [&] {
324 StringRef name;
326 VariadicityAttr variadicity;
327 if (p.parseKeyword(&name) || p.parseColon())
328 return failure();
329
330 if (variadicityAttr) {
331 if (parseValueWithVariadicity(p, operand, variadicity))
332 return failure();
333 variadicities.push_back(variadicity);
334 } else {
335 if (p.parseOperand(operand))
336 return failure();
337 }
338
339 valueNames.push_back(StringAttr::get(ctx, name));
340 operands.push_back(operand);
341 return success();
342 };
343
345 return failure();
346 valueNamesAttr = ArrayAttr::get(ctx, valueNames);
347 if (variadicityAttr)
348 *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
349 return success();
350}
351
352/// Parse a list of named values.
353///
354/// values ::=
355/// `(` (named-value (`,` named-value)*)? `)`
356/// named-value := bare-id `:` ssa-value
357static ParseResult
360 ArrayAttr &valueNamesAttr) {
361 return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr);
362}
363
364/// Parse a list of named values with their variadicities first. By default, the
365/// variadicity is single.
366///
367/// values-with-variadicity ::=
368/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
369/// value-with-variadicity
370/// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value
373 ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
374 return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr);
375}
376
378 OperandRange operands,
379 ArrayAttr valueNamesAttr,
380 VariadicityArrayAttr variadicityAttr) {
381 p << "(";
382 interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
383 p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": ";
384 if (variadicityAttr) {
385 Variadicity variadicity = variadicityAttr[i].getValue();
386 if (variadicity != Variadicity::single) {
387 p << stringifyVariadicity(variadicity) << " ";
388 }
389 }
390 p << operands[i];
391 });
392 p << ")";
393}
394
395/// Print a list of named values.
396///
397/// values ::=
398/// `(` (named-value (`,` named-value)*)? `)`
399/// named-value := bare-id `:` ssa-value
401 OperandRange operands,
402 ArrayAttr valueNamesAttr) {
403 printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr);
404}
405
406/// Print a list of named values with their variadicities first. By default, the
407/// variadicity is single.
408///
409/// values-with-variadicity ::=
410/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
411/// value-with-variadicity ::=
412/// bare-id `:` ("single" | "optional" | "variadic")? ssa-value
414 OpAsmPrinter &p, Operation *op, OperandRange operands,
415 ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
416 printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr);
417}
418
419static ParseResult
422 ArrayAttr &attrNamesAttr) {
423 Builder &builder = p.getBuilder();
424 SmallVector<Attribute> attrNames;
425 if (succeeded(p.parseOptionalLBrace())) {
426 auto parseOperands = [&]() {
427 if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() ||
428 p.parseOperand(attrOperands.emplace_back()))
429 return failure();
430 return success();
431 };
432 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
433 return failure();
434 }
435 attrNamesAttr = builder.getArrayAttr(attrNames);
436 return success();
437}
438
439static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
440 OperandRange attrArgs, ArrayAttr attrNames) {
441 if (attrNames.empty())
442 return;
443 p << "{";
444 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
445 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
446 p << '}';
447}
448
449LogicalResult RegionOp::verify() {
450 if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
451 if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
452 return emitOpError("the number of blocks is expected to be >= 1 but got ")
453 << number;
454 }
455 return success();
456}
457
458LogicalResult RegionsOp::verify() {
459 return verifyNames(*this, "region", getNames(), getNumOperands());
460}
461
462#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
463
464#define GET_TYPEDEF_CLASSES
465#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
466
467#include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
468
469#define GET_ATTRDEF_CLASSES
470#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
471
472#define GET_OP_CLASSES
473#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult verifyNames(Operation *op, StringRef kindName, ArrayAttr names, size_t numOperands)
Definition IRDL.cpp:168
static ParseResult parseValueWithVariadicity(OpAsmParser &p, OpAsmParser::UnresolvedOperand &operand, VariadicityAttr &variadicityAttr)
Parse a value with its variadicity first.
Definition IRDL.cpp:292
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition IRDL.cpp:439
static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc, const Twine &label)
Definition IRDL.cpp:75
static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Finds whether the provided symbol is an IRDL type or attribute definition.
Definition IRDL.cpp:253
static LogicalResult verifyOperandsResultsCommon(ValueListOp op, StringRef kindName)
Definition IRDL.cpp:199
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region)
Parse a region, and add a single block if the region is empty.
Definition IRDL.cpp:58
static ParseResult parseAttributesOp(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition IRDL.cpp:420
static void printNamedValueList(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr)
Print a list of named values.
Definition IRDL.cpp:400
static void printNamedValueListWithVariadicity(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Print a list of named values with their variadicities first.
Definition IRDL.cpp:413
static ParseResult parseNamedValueList(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr)
Parse a list of named values.
Definition IRDL.cpp:358
static ParseResult parseNamedValueListImpl(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr)
Definition IRDL.cpp:314
static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Definition IRDL.cpp:377
static ParseResult parseNamedValueListWithVariadicity(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr)
Parse a list of named values with their variadicities first.
Definition IRDL.cpp:371
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, Region &region)
Definition IRDL.cpp:70
lhs
ArrayAttr()
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition Dialect.cpp:95
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
bool empty()
Definition Region.h:60
BlockListType & getBlocks()
Definition Region.h:45
This class represents a collection of SymbolTables.
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
This is the representation of an operand reference.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.