20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SetOperations.h"
22#include "llvm/ADT/StringExtras.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/Casting.h"
33#include "mlir/Dialect/IRDL/IR/IRDL.cpp.inc"
35#include "mlir/Dialect/IRDL/IR/IRDLDialect.cpp.inc"
37void IRDLDialect::initialize() {
40#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
43#define GET_TYPEDEF_LIST
44#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
47#define GET_ATTRDEF_LIST
48#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
60 if (regionParseRes.has_value() && failed(regionParseRes.value()))
78 return loc->
emitError(
"name of ") << label <<
" is empty";
80 bool allowUnderscore =
false;
81 for (
auto &elem : in) {
85 << label <<
" should not contain leading or double underscores";
90 <<
" must contain only lowercase letters, digits and "
93 if (llvm::isUpper(elem))
95 << label <<
" should not contain uppercase letters";
98 allowUnderscore = elem !=
'_';
104LogicalResult DialectOp::verify() {
/// Verifier for an IRDL operation definition.
///
/// Returns the result of isValidName applied to this op's symbol name.
/// "operation" is passed as the label argument and this operation as the
/// location argument of isValidName.
113LogicalResult OperationOp::verify() {
114 return isValidName(getSymName(), getOperation(),
"operation");
/// Verifier for an IRDL type definition.
///
/// If the symbol name begins with '!', that first character is dropped; the
/// remaining name is then checked with isValidName using the label "type".
117LogicalResult TypeOp::verify() {
118 auto symName = getSymName();
// NOTE(review): front() is called before any emptiness check visible in this
// block -- confirm the symbol name cannot be empty when the verifier runs.
119 if (symName.front() ==
'!')
120 symName = symName.substr(1);
121 return isValidName(symName, getOperation(),
"type");
/// Verifier for an IRDL attribute definition.
///
/// If the symbol name begins with '#', that first character is dropped; the
/// remaining name is then checked with isValidName using the label
/// "attribute".
124LogicalResult AttributeOp::verify() {
125 auto symName = getSymName();
// NOTE(review): front() is called before any emptiness check visible in this
// block -- confirm the symbol name cannot be empty when the verifier runs.
126 if (symName.front() ==
'#')
127 symName = symName.substr(1);
128 return isValidName(symName, getOperation(),
"attribute");
131LogicalResult OperationOp::verifyRegions() {
136 auto insertNames = [&](StringRef kind,
ArrayAttr names) {
137 llvm::SmallDenseSet<StringRef> nameSet;
138 nameSet.reserve(names.size());
139 for (
auto name : names)
140 nameSet.insert(llvm::cast<StringAttr>(name).getValue());
141 valueNames.emplace_back(kind, std::move(nameSet));
144 for (
Operation &op : getBody().getOps()) {
146 .Case([&](OperandsOp op) { insertNames(
"operands", op.getNames()); })
147 .Case([&](ResultsOp op) { insertNames(
"results", op.getNames()); })
148 .Case([&](RegionsOp op) { insertNames(
"regions", op.getNames()); });
154 for (
size_t i : llvm::seq(valueNames.size())) {
155 for (
size_t j : llvm::seq(i + 1, valueNames.size())) {
156 auto [
lhs, lhsSet] = valueNames[i];
157 auto &[
rhs, rhsSet] = valueNames[
j];
158 llvm::set_intersect(lhsSet, rhsSet);
161 << *lhsSet.begin() <<
"' for both its " <<
lhs <<
" and " <<
rhs;
170 if (numOperands != names.size())
172 <<
"the number of " << kindName
173 <<
"s and their names must be "
175 << numOperands <<
" and " << names.size() <<
" respectively";
178 for (
auto [i, name] : llvm::enumerate(names)) {
179 StringRef nameRef = llvm::cast<StringAttr>(name).getValue();
181 if (failed(
isValidName(nameRef, op, Twine(kindName) +
" #" + Twine(i))))
184 if (nameMap.contains(nameRef))
185 return op->
emitOpError() <<
"name of " << kindName <<
" #" << i
186 <<
" is a duplicate of the name of " << kindName
187 <<
" #" << nameMap[nameRef];
188 nameMap.insert({nameRef, i});
/// Verifier for the parameters op.
///
/// Forwards to verifyNames with the kind name "parameter", this op's name
/// list, and its operand count, and returns that result unchanged.
194LogicalResult ParametersOp::verify() {
195 return verifyNames(*
this,
"parameter", getNames(), getNumOperands());
198template <
typename ValueListOp>
200 StringRef kindName) {
201 size_t numVariadicities = op.getVariadicity().size();
202 size_t numOperands = op.getNumOperands();
204 if (numOperands != numVariadicities)
205 return op.emitOpError()
206 <<
"the number of " << kindName
207 <<
"s and their variadicities must be "
209 << numOperands <<
" and " << numVariadicities <<
" respectively";
211 return verifyNames(op, kindName, op.getNames(), numOperands);
214LogicalResult OperandsOp::verify() {
218LogicalResult ResultsOp::verify() {
222LogicalResult AttributesOp::verify() {
223 size_t namesSize = getAttributeValueNames().size();
224 size_t valuesSize = getAttributeValues().size();
226 if (namesSize != valuesSize)
228 <<
"the number of attribute names and their constraints must be "
230 << namesSize <<
" and " << valuesSize <<
" respectively";
235LogicalResult BaseOp::verify() {
236 std::optional<StringRef> baseName = getBaseName();
237 std::optional<SymbolRefAttr> baseRef = getBaseRef();
238 if (baseName.has_value() == baseRef.has_value())
239 return emitOpError() <<
"the base type or attribute should be specified by "
240 "either a name or a reference";
243 (baseName->empty() || ((*baseName)[0] !=
'!' && (*baseName)[0] !=
'#')))
244 return emitOpError() <<
"the base type or attribute name should start with "
254 Operation *source, SymbolRefAttr symbol) {
259 return source->
emitOpError() <<
"symbol '" << symbol <<
"' not found";
261 if (!isa<TypeOp, AttributeOp>(targetOp))
262 return source->
emitOpError() <<
"symbol '" << symbol
263 <<
"' does not refer to a type or attribute "
264 "definition (refers to '"
265 << targetOp->
getName() <<
"')";
271 std::optional<SymbolRefAttr> baseRef = getBaseRef();
280 std::optional<SymbolRefAttr> baseRef = getBaseType();
294 VariadicityAttr &variadicityAttr) {
299 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
301 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::optional);
303 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::variadic);
305 variadicityAttr = VariadicityAttr::get(ctx, Variadicity::single);
316 ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr) {
323 auto parseOne = [&] {
326 VariadicityAttr variadicity;
330 if (variadicityAttr) {
333 variadicities.push_back(variadicity);
339 valueNames.push_back(StringAttr::get(ctx, name));
340 operands.push_back(operand);
346 valueNamesAttr = ArrayAttr::get(ctx, valueNames);
348 *variadicityAttr = VariadicityArrayAttr::get(ctx, variadicities);
373 ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr) {
380 VariadicityArrayAttr variadicityAttr) {
382 interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](
int i) {
383 p << llvm::cast<StringAttr>(valueNamesAttr[i]).getValue() <<
": ";
384 if (variadicityAttr) {
385 Variadicity variadicity = variadicityAttr[i].getValue();
386 if (variadicity != Variadicity::single) {
387 p << stringifyVariadicity(variadicity) <<
" ";
415 ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr) {
426 auto parseOperands = [&]() {
441 if (attrNames.empty())
444 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
445 [&](
int i) { p << attrNames[i] <<
" = " << attrArgs[i]; });
449LogicalResult RegionOp::verify() {
450 if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr())
451 if (
int64_t number = numberOfBlocks.getInt(); number <= 0) {
452 return emitOpError(
"the number of blocks is expected to be >= 1 but got ")
/// Verifier for the regions op.
///
/// Forwards to verifyNames with the kind name "region", this op's name list,
/// and its operand count, and returns that result unchanged.
458LogicalResult RegionsOp::verify() {
459 return verifyNames(*
this,
"region", getNames(), getNumOperands());
462#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
464#define GET_TYPEDEF_CLASSES
465#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
467#include "mlir/Dialect/IRDL/IR/IRDLEnums.cpp.inc"
469#define GET_ATTRDEF_CLASSES
470#include "mlir/Dialect/IRDL/IR/IRDLAttributes.cpp.inc"
472#define GET_OP_CLASSES
473#include "mlir/Dialect/IRDL/IR/IRDLOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult verifyNames(Operation *op, StringRef kindName, ArrayAttr names, size_t numOperands)
static ParseResult parseValueWithVariadicity(OpAsmParser &p, OpAsmParser::UnresolvedOperand &operand, VariadicityAttr &variadicityAttr)
Parse a value with its variadicity first.
static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, OperandRange attrArgs, ArrayAttr attrNames)
static llvm::LogicalResult isValidName(llvm::StringRef in, mlir::Operation *loc, const Twine &label)
static LogicalResult checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Finds whether the provided symbol is an IRDL type or attribute definition.
static LogicalResult verifyOperandsResultsCommon(ValueListOp op, StringRef kindName)
static ParseResult parseSingleBlockRegion(OpAsmParser &p, Region &region)
Parse a region, and add a single block if the region is empty.
static ParseResult parseAttributesOp(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
static void printNamedValueList(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr)
Print a list of named values.
static void printNamedValueListWithVariadicity(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
Print a list of named values with their variadicities first.
static ParseResult parseNamedValueList(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr)
Parse a list of named values.
static ParseResult parseNamedValueListImpl(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr *variadicityAttr)
static void printNamedValueListImpl(OpAsmPrinter &p, Operation *op, OperandRange operands, ArrayAttr valueNamesAttr, VariadicityArrayAttr variadicityAttr)
static ParseResult parseNamedValueListWithVariadicity(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, ArrayAttr &valueNamesAttr, VariadicityArrayAttr &variadicityAttr)
Parse a list of named values with their variadicities first.
static void printSingleBlockRegion(OpAsmPrinter &p, Operation *op, Region &region)
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColon()=0
Parse a : token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
This class is a general helper class for creating context-global objects like types,...
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
static bool isValidNamespace(StringRef str)
Utility function that returns if the given string is a valid dialect namespace.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
BlockListType & getBlocks()
This class represents a collection of SymbolTables.
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
Include the generated interface declarations.
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This is the representation of an operand reference.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.