20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/IR/Metadata.h"
26 #include "llvm/Support/Casting.h"
35 #include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
37 #include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
39 void IRDLDialect::initialize() {
42 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
45 #define GET_TYPEDEF_LIST
46 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
49 #define GET_ATTRDEF_LIST
50 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
62 if (regionParseRes.has_value() && failed(regionParseRes.value()))
80 return loc->
emitError(
"name of ") << label <<
" is empty";
82 bool allowUnderscore =
false;
83 for (
auto &elem : in) {
87 << label <<
" should not contain leading or double underscores";
92 <<
" must contain only lowercase letters, digits and "
95 if (llvm::isUpper(elem))
97 << label <<
" should not contain uppercase letters";
100 allowUnderscore = elem !=
'_';
108 return emitOpError(
"invalid dialect name");
109 if (failed(
isValidName(getSymName(), getOperation(),
"dialect")))
116 return isValidName(getSymName(), getOperation(),
"operation");
120 auto symName = getSymName();
121 if (symName.front() ==
'!')
122 symName = symName.substr(1);
123 return isValidName(symName, getOperation(),
"type");
127 auto symName = getSymName();
128 if (symName.front() ==
'#')
129 symName = symName.substr(1);
130 return isValidName(symName, getOperation(),
"attribute");
133 LogicalResult OperationOp::verifyRegions() {
138 auto insertNames = [&](StringRef
kind, ArrayAttr names) {
139 llvm::SmallDenseSet<StringRef> nameSet;
140 nameSet.reserve(names.size());
141 for (
auto name : names)
142 nameSet.insert(llvm::cast<StringAttr>(name).getValue());
143 valueNames.emplace_back(
kind, std::move(nameSet));
146 for (
Operation &op : getBody().getOps()) {
149 [&](OperandsOp op) { insertNames(
"operands", op.getNames()); })
151 [&](ResultsOp op) { insertNames(
"results", op.getNames()); })
153 [&](RegionsOp op) { insertNames(
"regions", op.getNames()); });
159 for (
size_t i : llvm::seq(valueNames.size())) {
160 for (
size_t j : llvm::seq(i + 1, valueNames.size())) {
161 auto [lhs, lhsSet] = valueNames[i];
162 auto &[rhs, rhsSet] = valueNames[
j];
163 llvm::set_intersect(lhsSet, rhsSet);
165 return emitOpError(
"contains a value named '")
166 << *lhsSet.begin() <<
"' for both its " << lhs <<
" and " << rhs;
174 ArrayAttr names,
size_t numOperands) {
175 if (numOperands != names.size())
177 <<
"the number of " << kindName
178 <<
"s and their names must be "
180 << numOperands <<
" and " << names.size() <<
" respectively";
184 StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
186 if (failed(
isValidName(nameRef, op, Twine(kindName) +
" #" + Twine(i))))
189 if (nameMap.contains(nameRef))
190 return op->
emitOpError() <<
"name of " << kindName <<
" #" << i
191 <<
" is a duplicate of the name of " << kindName
192 <<
" #" << nameMap[nameRef];
193 nameMap.insert({nameRef, i});
200 return verifyNames(*
this,
"parameter", getNames(), getNumOperands());
203 template <
typename ValueListOp>
205 StringRef kindName) {
206 size_t numVariadicities = op.getVariadicity().size();
207 size_t numOperands = op.getNumOperands();
209 if (numOperands != numVariadicities)
210 return op.emitOpError()
211 <<
"the number of " << kindName
212 <<
"s and their variadicities must be "
214 << numOperands <<
" and " << numVariadicities <<
" respectively";
216 return verifyNames(op, kindName, op.getNames(), numOperands);
228 size_t namesSize = getAttributeValueNames().size();
229 size_t valuesSize = getAttributeValues().size();
231 if (namesSize != valuesSize)
233 <<
"the number of attribute names and their constraints must be "
235 << namesSize <<
" and " << valuesSize <<
" respectively";
241 std::optional<StringRef> baseName = getBaseName();
242 std::optional<SymbolRefAttr> baseRef = getBaseRef();
243 if (baseName.has_value() == baseRef.has_value())
244 return emitOpError() <<
"the base type or attribute should be specified by "
245 "either a name or a reference";
248 (baseName->empty() || ((*baseName)[0] !=
'!' && (*baseName)[0] !=
'#')))
249 return emitOpError() <<
"the base type or attribute name should start with "
259 Operation *source, SymbolRefAttr symbol) {
264 return source->
emitOpError() <<
"symbol '" << symbol <<
"' not found";
266 if (!isa<TypeOp, AttributeOp>(targetOp))
267 return source->
emitOpError() <<
"symbol '" << symbol
268 <<
"' does not refer to a type or attribute "
269 "definition (refers to '"
270 << targetOp->
getName() <<
"')";
276 std::optional<SymbolRefAttr> baseRef = getBaseRef();
285 std::optional<SymbolRefAttr> baseRef = getBaseType();
299 VariadicityAttr &variadicityAttr) {
321 ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
328 auto parseOne = [&] {
331 VariadicityAttr variadicity;
335 if (variadicityAttr) {
338 variadicities.push_back(variadicity);
345 operands.push_back(operand);
365 ArrayAttr &valueNamesAttr) {
378 ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
384 ArrayAttr valueNamesAttr,
385 VariadicityArrayAttr variadicityAttr) {
387 interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](
int i) {
388 p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() <<
": ";
389 if (variadicityAttr) {
390 Variadicity variadicity = variadicityAttr[i].getValue();
391 if (variadicity != Variadicity::single) {
392 p << stringifyVariadicity(variadicity) <<
" ";
407 ArrayAttr valueNamesAttr) {
420 ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
427 ArrayAttr &attrNamesAttr) {
431 auto parseOperands = [&]() {
446 if (attrNames.empty())
449 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
450 [&](
int i) { p << attrNames[i] <<
" = " << attrArgs[i]; });
455 if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
456 if (int64_t number = numberOfBlocks.getInt(); number <= 0) {
457 return emitOpError(
"the number of blocks is expected to be >= 1 but got ")
464 return verifyNames(*
this,
"region", getNames(), getNumOperands());
467 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
469 #define GET_TYPEDEF_CLASSES
470 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
472 #include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
474 #define GET_ATTRDEF_CLASSES
475 #include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
477 #define GET_OP_CLASSES
478 #include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
static LogicalResult verifyNames(Operation *op, StringRef kindName, ArrayAttr names, size_t numOperands)
static ParseResult parseValueWithVariadicity(OpAsmParser &p, OpAsmParser::UnresolvedOperand &operand, VariadicityAttr &variadicityAttr)
Parse a value with its variadicity first.
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, OperandRange attrArgs, ArrayAttr attrNames)
static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc, const Twine &label)
static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Finds whether the provided symbol is an IRDL type or attribute definition.
static LogicalResult verifyOperandsResultsCommon(ValueListOp op, StringRef kindName)
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region)
Parse a region, and add a single block if the region is empty.
static ParseResult parseAttributesOp(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
static void printNamedValueList(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr)
Print a list of named values.
static void printNamedValueListWithVariadicity(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Print a list of named values with their variadicities first.
static ParseResult parseNamedValueList(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr)
Parse a list of named values.
static ParseResult parseNamedValueListImpl(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr)
static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
static ParseResult parseNamedValueListWithVariadicity(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr)
Parse a list of named values with their variadicities first.
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, Region &region)
union mlir::linalg::@1205::ArityGroupAndKind::Kind kind
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Block represents an ordered list of Operations.
This class is a general helper class for creating context-global objects like types,...
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
BlockListType & getBlocks()
This class represents a collection of SymbolTables.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.