MLIR 22.0.0git
IRDL.cpp
Go to the documentation of this file.
1//===- IRDL.cpp - IRDL dialect ----------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Builders.h"
13#include "mlir/IR/Diagnostics.h"
18#include "mlir/IR/Operation.h"
19#include "mlir/Support/LLVM.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SetOperations.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/Casting.h"
25
26using namespace mlir;
27using namespace mlir::irdl;
28
29//===----------------------------------------------------------------------===//
30// IRDL dialect.
31//===----------------------------------------------------------------------===//
32
33#include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
34
35#include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
36
/// Registers the IRDL dialect's operations, types and attributes. Each class
/// list is expanded from a generated `.cpp.inc` file, selected by defining the
/// corresponding GET_*_LIST macro right before the include.
37void IRDLDialect::initialize() {
  // Operation classes listed in IRDLOps.cpp.inc.
38 addOperations<
39#define GET_OP_LIST
40#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
41 >();
  // Type classes listed in IRDLTypesGen.cpp.inc.
42 addTypes<
43#define GET_TYPEDEF_LIST
44#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
45 >();
  // Attribute classes listed in IRDLAttributes.cpp.inc.
46 addAttributes<
47#define GET_ATTRDEF_LIST
48#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
49 >();
50}
51
52//===----------------------------------------------------------------------===//
53// Parsing/Printing/Verifying
54//===----------------------------------------------------------------------===//
55
56/// Parse a region, and add a single block if the region is empty.
57/// If no region is parsed, create a new region with a single empty block.
58static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region) {
59 auto regionParseRes = p.parseOptionalRegion(region);
60 if (regionParseRes.has_value() && failed(regionParseRes.value()))
61 return failure();
62
63 // If the region is empty, add a single empty block.
64 if (region.empty())
65 region.push_back(new Block());
66
67 return success();
68}
69
71 Region &region) {
  // Omit the region entirely when its entry block is empty. When no region is
  // written, parseSingleBlockRegion recreates exactly such an empty block, so
  // both forms round-trip.
  // NOTE(review): front() requires the region to have at least one block;
  // the parser guarantees this, confirm it also holds for ops built in code.
72 if (!region.getBlocks().front().empty())
73 p.printRegion(region);
74}
75static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc,
76 const Twine &label) {
77 if (in.empty())
78 return loc->emitError("name of ") << label << " is empty";
79
80 bool allowUnderscore = false;
81 for (auto &elem : in) {
82 if (elem == '_') {
83 if (!allowUnderscore)
84 return loc->emitError("name of ")
85 << label << " should not contain leading or double underscores";
86 } else {
87 if (!isalnum(elem))
88 return loc->emitError("name of ")
89 << label
90 << " must contain only lowercase letters, digits and "
91 "underscores";
92
93 if (llvm::isUpper(elem))
94 return loc->emitError("name of ")
95 << label << " should not contain uppercase letters";
96 }
97
98 allowUnderscore = elem != '_';
99 }
100
101 return success();
102}
103
104LogicalResult DialectOp::verify() {
105 if (!Dialect::isValidNamespace(getName()))
106 return emitOpError("invalid dialect name");
107 if (failed(isValidName(getSymName(), getOperation(), "dialect")))
108 return failure();
109
110 return success();
111}
112
113LogicalResult OperationOp::verify() {
114 return isValidName(getSymName(), getOperation(), "operation");
115}
116
117LogicalResult TypeOp::verify() {
118 auto symName = getSymName();
119 if (symName.front() == '!')
120 symName = symName.substr(1);
121 return isValidName(symName, getOperation(), "type");
122}
123
124LogicalResult AttributeOp::verify() {
125 auto symName = getSymName();
126 if (symName.front() == '#')
127 symName = symName.substr(1);
128 return isValidName(symName, getOperation(), "attribute");
129}
130
/// Verifies that names are unique across value kinds: no operand, result or
/// region declared in the operation body may share a name with a value of a
/// different kind.
131LogicalResult OperationOp::verifyRegions() {
132 // One entry per operands/results/regions op in the body: the value kind and
133 // the set of names that op declares.
135
  // Collect the names declared by one op into a set tagged with its kind.
136 auto insertNames = [&](StringRef kind, ArrayAttr names) {
137 llvm::SmallDenseSet<StringRef> nameSet;
138 nameSet.reserve(names.size());
139 for (auto name : names)
140 nameSet.insert(llvm::cast<StringAttr>(name).getValue());
141 valueNames.emplace_back(kind, std::move(nameSet));
142 };
143
  // Only OperandsOp, ResultsOp and RegionsOp contribute names; other ops in
  // the body are ignored here.
144 for (Operation &op : getBody().getOps()) {
146 .Case<OperandsOp>(
147 [&](OperandsOp op) { insertNames("operands", op.getNames()); })
148 .Case<ResultsOp>(
149 [&](ResultsOp op) { insertNames("results", op.getNames()); })
150 .Case<RegionsOp>(
151 [&](RegionsOp op) { insertNames("regions", op.getNames()); });
152 }
153
154 // Verify that no two entries (operands, results, regions) share a name.
155 // The absence of duplicates within a single op is checked by that op's
156 // own verifier.
157 for (size_t i : llvm::seq(valueNames.size())) {
158 for (size_t j : llvm::seq(i + 1, valueNames.size())) {
      // lhsSet is deliberately a copy (no `&`): llvm::set_intersect shrinks its
      // first argument in place, and valueNames[i] must stay intact for the
      // remaining iterations over j.
159 auto [lhs, lhsSet] = valueNames[i];
160 auto &[rhs, rhsSet] = valueNames[j];
161 llvm::set_intersect(lhsSet, rhsSet);
      // Anything left is a name declared by both entries; report one of them.
162 if (!lhsSet.empty())
163 return emitOpError("contains a value named '")
164 << *lhsSet.begin() << "' for both its " << lhs << " and " << rhs;
165 }
166 }
167
168 return success();
169}
170
/// Verifies a list of value names attached to `op`.
///
/// Checks that there is exactly one name per value (`numOperands`), that each
/// name is a valid IRDL identifier, and that no name is repeated.
/// `kindName` is the singular noun used in diagnostics (e.g. "operand"); it is
/// pluralized by appending "s" in the count-mismatch message.
171static LogicalResult verifyNames(Operation *op, StringRef kindName,
172 ArrayAttr names, size_t numOperands) {
173 if (numOperands != names.size())
174 return op->emitOpError()
175 << "the number of " << kindName
176 << "s and their names must be "
177 "the same, but got "
178 << numOperands << " and " << names.size() << " respectively";
179
  // nameMap records the index at which each name was first seen, so that a
  // duplicate can point back to the earlier occurrence.
181 for (auto [i, name] : llvm::enumerate(names)) {
182 StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
183
184 if (failed(isValidName(nameRef, op, Twine(kindName) + " #" + Twine(i))))
185 return failure();
186
187 if (nameMap.contains(nameRef))
188 return op->emitOpError() << "name of " << kindName << " #" << i
189 << " is a duplicate of the name of " << kindName
190 << " #" << nameMap[nameRef];
191 nameMap.insert({nameRef, i});
192 }
193
194 return success();
195}
196
197LogicalResult ParametersOp::verify() {
198 return verifyNames(*this, "parameter", getNames(), getNumOperands());
199}
200
201template <typename ValueListOp>
202static LogicalResult verifyOperandsResultsCommon(ValueListOp op,
203 StringRef kindName) {
204 size_t numVariadicities = op.getVariadicity().size();
205 size_t numOperands = op.getNumOperands();
206
207 if (numOperands != numVariadicities)
208 return op.emitOpError()
209 << "the number of " << kindName
210 << "s and their variadicities must be "
211 "the same, but got "
212 << numOperands << " and " << numVariadicities << " respectively";
213
214 return verifyNames(op, kindName, op.getNames(), numOperands);
215}
216
217LogicalResult OperandsOp::verify() {
218 return verifyOperandsResultsCommon(*this, "operand");
219}
220
221LogicalResult ResultsOp::verify() {
222 return verifyOperandsResultsCommon(*this, "result");
223}
224
225LogicalResult AttributesOp::verify() {
226 size_t namesSize = getAttributeValueNames().size();
227 size_t valuesSize = getAttributeValues().size();
228
229 if (namesSize != valuesSize)
230 return emitOpError()
231 << "the number of attribute names and their constraints must be "
232 "the same but got "
233 << namesSize << " and " << valuesSize << " respectively";
234
235 return success();
236}
237
238LogicalResult BaseOp::verify() {
239 std::optional<StringRef> baseName = getBaseName();
240 std::optional<SymbolRefAttr> baseRef = getBaseRef();
241 if (baseName.has_value() == baseRef.has_value())
242 return emitOpError() << "the base type or attribute should be specified by "
243 "either a name or a reference";
244
245 if (baseName &&
246 (baseName->empty() || ((*baseName)[0] != '!' && (*baseName)[0] != '#')))
247 return emitOpError() << "the base type or attribute name should start with "
248 "'!' or '#'";
249
250 return success();
251}
252
253/// Checks that `symbol` names an IRDL type or attribute definition (a TypeOp
254/// or AttributeOp). The symbol is looked up via lookupSymbolNearDialect, i.e.
/// from the symbol table containing the dialect definition that encloses
/// `source`, so `source` must be nested within a DialectOp. On failure, an
/// error is emitted on `source`.
255static LogicalResult
257 Operation *source, SymbolRefAttr symbol) {
258 Operation *targetOp =
259 irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
260
  // The symbol does not resolve to anything.
261 if (!targetOp)
262 return source->emitOpError() << "symbol '" << symbol << "' not found";
263
  // The symbol resolves, but to an operation of the wrong kind; report which.
264 if (!isa<TypeOp, AttributeOp>(targetOp))
265 return source->emitOpError() << "symbol '" << symbol
266 << "' does not refer to a type or attribute "
267 "definition (refers to '"
268 << targetOp->getName() << "')";
269
270 return success();
271}
272
273LogicalResult BaseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
274 std::optional<SymbolRefAttr> baseRef = getBaseRef();
275 if (!baseRef)
276 return success();
277
278 return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
279}
280
281LogicalResult
282ParametricOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
283 std::optional<SymbolRefAttr> baseRef = getBaseType();
284 if (!baseRef)
285 return success();
286
287 return checkSymbolIsTypeOrAttribute(symbolTable, *this, *baseRef);
288}
289
290/// Parse an SSA value preceded by an optional variadicity keyword. When the
291/// keyword is omitted, the variadicity defaults to `single`.
292///
293/// value-with-variadicity ::= ("single" | "optional" | "variadic")? ssa-value
///
/// The variadicity is stored in `variadicityAttr` and the value in `operand`.
294static ParseResult
297 VariadicityAttr &variadicityAttr) {
298 MLIRContext *ctx = p.getBuilder().getContext();
299
300 // Parse the optional variadicity keyword. Writing `single` explicitly is
  // equivalent to omitting it.
301 if (p.parseOptionalKeyword("single").succeeded()) {
302 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
303 } else if (p.parseOptionalKeyword("optional").succeeded()) {
304 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
305 } else if (p.parseOptionalKeyword("variadic").succeeded()) {
306 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
307 } else {
    // No keyword: fall back to the default variadicity.
308 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
309 }
310
311 // Parse the SSA value itself.
312 if (p.parseOperand(operand))
313 return failure();
314 return success();
315}
316
/// Shared implementation of parseNamedValueList and
/// parseNamedValueListWithVariadicity.
///
/// Parses a list of `name: value` entries, appends each parsed value to
/// `operands`, and stores the names in `valueNamesAttr`. When `variadicityAttr`
/// is non-null, each value may be preceded by a variadicity keyword and the
/// collected variadicities are stored through it. When it is null, values are
/// parsed as plain operands.
317static ParseResult parseNamedValueListImpl(
319 ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
320 Builder &builder = p.getBuilder();
321 MLIRContext *ctx = builder.getContext();
322 SmallVector<Attribute> valueNames;
323 SmallVector<VariadicityAttr> variadicities;
324
325 // Parse one `name: [variadicity] value` entry. The variadicity keyword is
  // only accepted when variadicityAttr is non-null.
326 auto parseOne = [&] {
327 StringRef name;
329 VariadicityAttr variadicity;
330 if (p.parseKeyword(&name) || p.parseColon())
331 return failure();
332
333 if (variadicityAttr) {
334 if (parseValueWithVariadicity(p, operand, variadicity))
335 return failure();
336 variadicities.push_back(variadicity);
337 } else {
338 if (p.parseOperand(operand))
339 return failure();
340 }
341
342 valueNames.push_back(StringAttr::get(ctx, name));
343 operands.push_back(operand);
344 return success();
345 };
346
348 return failure();
  // The whole list parsed: publish the names and, if requested, the
  // variadicities.
349 valueNamesAttr = ArrayAttr::get(ctx, valueNames);
350 if (variadicityAttr)
351 *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
352 return success();
353}
354
355/// Parse a parenthesized list of named values, without variadicities.
356///
357/// values ::=
358/// `(` (named-value (`,` named-value)*)? `)`
359/// named-value ::= bare-id `:` ssa-value
///
/// The values are appended to `operands` and their names stored in
/// `valueNamesAttr`.
360static ParseResult
363 ArrayAttr &valueNamesAttr) {
364 return parseNamedValueListImpl(p, operands, valueNamesAttr, nullptr);
365}
366
367/// Parse a list of named values, each optionally preceded by its variadicity.
368/// An omitted variadicity defaults to `single`.
369///
370/// values-with-variadicity ::=
371/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
372/// value-with-variadicity
373/// ::= bare-id `:` ("single" | "optional" | "variadic")? ssa-value
376 ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
377 return parseNamedValueListImpl(p, operands, valueNamesAttr, &variadicityAttr);
378}
379
381 OperandRange operands,
382 ArrayAttr valueNamesAttr,
383 VariadicityArrayAttr variadicityAttr) {
  // Print `(name: [variadicity] value, ...)`. The `single` variadicity is the
  // parser's default and is therefore left implicit.
  // NOTE(review): valueNamesAttr (and variadicityAttr, when non-null) are
  // indexed by operand position; this assumes their sizes were checked against
  // the operand count by the op verifiers.
384 p << "(";
385 interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
386 p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() << ": ";
    // variadicityAttr is null when printing a list without variadicities.
387 if (variadicityAttr) {
388 Variadicity variadicity = variadicityAttr[i].getValue();
389 if (variadicity != Variadicity::single) {
390 p << stringifyVariadicity(variadicity) << " ";
391 }
392 }
393 p << operands[i];
394 });
395 p << ")";
396}
397
398/// Print a parenthesized list of named values, without variadicities.
399///
400/// values ::=
401/// `(` (named-value (`,` named-value)*)? `)`
402/// named-value ::= bare-id `:` ssa-value
404 OperandRange operands,
405 ArrayAttr valueNamesAttr) {
406 printNamedValueListImpl(p, op, operands, valueNamesAttr, nullptr);
407}
408
409/// Print a list of named values, each preceded by its variadicity unless it is
410/// the default `single`, which is omitted.
411///
412/// values-with-variadicity ::=
413/// `(` (value-with-variadicity (`,` value-with-variadicity)*)? `)`
414/// value-with-variadicity ::=
415/// bare-id `:` ("single" | "optional" | "variadic")? ssa-value
417 OpAsmPrinter &p, Operation *op, OperandRange operands,
418 ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
419 printNamedValueListImpl(p, op, operands, valueNamesAttr, variadicityAttr);
420}
421
/// Parse the optional attribute-constraint list:
///
///   `{` attr-name `=` ssa-value (`,` attr-name `=` ssa-value)* `}`
///
/// The names are collected into `attrNamesAttr`, which is an empty array when
/// the opening brace is absent. The constraint values are appended to
/// `attrOperands`.
422static ParseResult
425 ArrayAttr &attrNamesAttr) {
426 Builder &builder = p.getBuilder();
427 SmallVector<Attribute> attrNames;
428 if (succeeded(p.parseOptionalLBrace())) {
    // Each entry is `name = value`. Slots are appended first so the parser
    // writes directly into the new elements.
429 auto parseOperands = [&]() {
430 if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() ||
431 p.parseOperand(attrOperands.emplace_back()))
432 return failure();
433 return success();
434 };
    // NOTE(review): the delimiter-less parseCommaSeparatedList overload
    // presumably requires at least one entry, so `{}` would be rejected. This
    // is consistent with printAttributesOp printing nothing for an empty list.
435 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
436 return failure();
437 }
438 attrNamesAttr = builder.getArrayAttr(attrNames);
439 return success();
440}
441
442static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
443 OperandRange attrArgs, ArrayAttr attrNames) {
444 if (attrNames.empty())
445 return;
446 p << "{";
447 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
448 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
449 p << '}';
450}
451
452LogicalResult RegionOp::verify() {
453 if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
454 if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
455 return emitOpError("the number of blocks is expected to be >= 1 but got ")
456 << number;
457 }
458 return success();
459}
460
461LogicalResult RegionsOp::verify() {
462 return verifyNames(*this, "region", getNames(), getNumOperands());
463}
464
465#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
466
467#define GET_TYPEDEF_CLASSES
468#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
469
470#include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
471
472#define GET_ATTRDEF_CLASSES
473#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
474
475#define GET_OP_CLASSES
476#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult verifyNames(Operation *op, StringRef kindName, ArrayAttr names, size_t numOperands)
Definition IRDL.cpp:171
static ParseResult parseValueWithVariadicity(OpAsmParser &p, OpAsmParser::UnresolvedOperand &operand, VariadicityAttr &variadicityAttr)
Parse a value with its variadicity first.
Definition IRDL.cpp:295
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition IRDL.cpp:442
static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc, const Twine &label)
Definition IRDL.cpp:75
static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Finds whether the provided symbol is an IRDL type or attribute definition.
Definition IRDL.cpp:256
static LogicalResult verifyOperandsResultsCommon(ValueListOp op, StringRef kindName)
Definition IRDL.cpp:202
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region)
Parse a region, and add a single block if the region is empty.
Definition IRDL.cpp:58
static ParseResult parseAttributesOp(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition IRDL.cpp:423
static void printNamedValueList(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr)
Print a list of named values.
Definition IRDL.cpp:403
static void printNamedValueListWithVariadicity(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Print a list of named values with their variadicities first.
Definition IRDL.cpp:416
static ParseResult parseNamedValueList(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr)
Parse a list of named values.
Definition IRDL.cpp:361
static ParseResult parseNamedValueListImpl(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr)
Definition IRDL.cpp:317
static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Definition IRDL.cpp:380
static ParseResult parseNamedValueListWithVariadicity(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr)
Parse a list of named values with their variadicities first.
Definition IRDL.cpp:374
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, Region &region)
Definition IRDL.cpp:70
lhs
ArrayAttr()
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
Definition Dialect.cpp:95
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
void push_back(Block *block)
Definition Region.h:61
bool empty()
Definition Region.h:60
BlockListType & getBlocks()
Definition Region.h:45
This class represents a collection of SymbolTables.
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
This is the representation of an operand reference.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.