22 #include "llvm/ADT/STLExtras.h"
33 ArrayRef<std::unique_ptr<Constraint>> constraints,
35 if (params.size() != paramConstraints.size()) {
36 emitError() <<
"expected " << paramConstraints.size()
37 <<
" type arguments, but had " << params.size();
53 StringRef attrName,
unsigned numElements,
58 if (!segmentSizesAttr) {
60 <<
"' attribute is expected but not provided";
63 auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr);
64 if (!denseSegmentSizes) {
66 <<
"' attribute is expected to be a dense i32 array";
69 if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {
70 return op->
emitError() <<
"'" << attrName <<
"' attribute for specifying "
71 << elemName <<
" segments must have "
72 << variadicities.size() <<
" elements, but got "
73 << denseSegmentSizes.size();
77 for (
auto [i, segmentSize, variadicity] :
78 enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {
81 <<
"'" << attrName <<
"' attribute for specifying " << elemName
82 <<
" segments must have non-negative values";
83 if (variadicity == Variadicity::single && segmentSize != 1)
84 return op->
emitError() <<
"element " << i <<
" in '" << attrName
85 <<
"' attribute must be equal to 1";
87 if (variadicity == Variadicity::optional && segmentSize > 1)
88 return op->
emitError() <<
"element " << i <<
" in '" << attrName
89 <<
"' attribute must be equal to 0 or 1";
91 segmentSizes.push_back(segmentSize);
96 for (int32_t segmentSize : denseSegmentSizes.asArrayRef())
98 if (sum !=
static_cast<int32_t
>(numElements))
99 return op->
emitError() <<
"sum of elements in '" << attrName
100 <<
"' attribute must be equal to the number of "
112 StringRef attrName,
unsigned numElements,
117 int numberNonSingle = count_if(
118 variadicities, [](Variadicity v) {
return v != Variadicity::single; });
119 if (numberNonSingle > 1)
121 variadicities, segmentSizes);
124 if (numberNonSingle == 0) {
125 if (numElements != variadicities.size()) {
126 return op->
emitError() <<
"op expects exactly " << variadicities.size()
127 <<
" " << elemName <<
"s, but got " << numElements;
129 for (
size_t i = 0, e = variadicities.size(); i < e; ++i)
130 segmentSizes.push_back(1);
134 assert(numberNonSingle == 1);
138 int nonSingleSegmentSize =
static_cast<int>(numElements) -
139 static_cast<int>(variadicities.size()) + 1;
141 if (nonSingleSegmentSize < 0) {
142 return op->
emitError() <<
"op expects at least " << variadicities.size() - 1
143 <<
" " << elemName <<
"s, but got " << numElements;
147 for (Variadicity variadicity : variadicities) {
148 if (variadicity == Variadicity::single) {
149 segmentSizes.push_back(1);
155 if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)
156 return op->
emitError() <<
"op expects at most " << variadicities.size()
157 <<
" " << elemName <<
"s, but got " << numElements;
159 segmentSizes.push_back(nonSingleSegmentSize);
214 for (
auto [name, constraint] : attributeConstrs) {
216 std::optional<NamedAttribute> actual = actualAttrs.getNamed(name);
217 if (!actual.has_value())
219 <<
"attribute " << name <<
" is expected but not provided";
222 if (
failed(verifier.
verify({emitError}, actual->getValue(), constraint)))
228 for (
auto [defIndex, segmentSize] :
enumerate(operandSegmentSizes)) {
229 for (
int i = 0; i < segmentSize; i++) {
232 operandConstrs[defIndex])))
240 for (
auto [defIndex, segmentSize] :
enumerate(resultSegmentSizes)) {
241 for (
int i = 0; i < segmentSize; i++) {
244 resultConstrs[defIndex])))
255 ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) {
258 <<
"unexpected number of regions: expected "
259 << regionsConstraints.size() <<
" but got " << op->
getNumRegions();
262 for (
auto [constraint, region] :
263 llvm::zip(regionsConstraints, op->
getRegions()))
264 if (
failed(constraint->verify(region, verifier)))
270 llvm::unique_function<LogicalResult(
Operation *)
const>
273 const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
274 const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
280 if (isa<VerifyConstraintInterface>(op)) {
281 if (op.getNumResults() != 1) {
283 <<
"IRDL constraint operations must have exactly one result";
286 constrToValue.push_back(op.getResult(0));
288 if (isa<VerifyRegionInterface>(op)) {
289 if (op.getNumResults() != 1) {
291 <<
"IRDL constraint operations must have exactly one result";
294 regionToValue.push_back(op.getResult(0));
300 for (
Value v : constrToValue) {
301 VerifyConstraintInterface op =
302 cast<VerifyConstraintInterface>(v.getDefiningOp());
303 std::unique_ptr<Constraint> verifier =
304 op.getVerifier(constrToValue, types, attrs);
307 constraints.push_back(std::move(verifier));
312 for (
Value v : regionToValue) {
313 VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp());
314 std::unique_ptr<RegionConstraint> verifier =
315 op.getVerifier(constrToValue, types, attrs);
316 regionConstraints.push_back(std::move(verifier));
323 auto operandsOp = op.getOp<OperandsOp>();
324 if (operandsOp.has_value()) {
325 operandConstraints.reserve(operandsOp->getArgs().size());
326 for (
Value operand : operandsOp->getArgs()) {
327 for (
auto [i, constr] :
enumerate(constrToValue)) {
328 if (constr == operand) {
329 operandConstraints.push_back(i);
336 for (VariadicityAttr attr : operandsOp->getVariadicity())
337 operandVariadicity.push_back(attr.getValue());
344 auto resultsOp = op.getOp<ResultsOp>();
345 if (resultsOp.has_value()) {
346 resultConstraints.reserve(resultsOp->getArgs().size());
347 for (
Value result : resultsOp->getArgs()) {
348 for (
auto [i, constr] :
enumerate(constrToValue)) {
349 if (constr == result) {
350 resultConstraints.push_back(i);
357 for (
Attribute attr : resultsOp->getVariadicity())
358 resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
363 auto attributesOp = op.getOp<AttributesOp>();
364 if (attributesOp.has_value()) {
366 const ArrayAttr names = attributesOp->getAttributeValueNames();
368 for (
const auto &[name, value] : llvm::zip(names, values)) {
369 for (
auto [i, constr] :
enumerate(constrToValue)) {
370 if (constr == value) {
371 attributeConstraints[cast<StringAttr>(name)] = i;
379 [constraints{std::move(constraints)},
380 regionConstraints{std::move(regionConstraints)},
381 operandConstraints{std::move(operandConstraints)},
382 operandVariadicity{std::move(operandVariadicity)},
383 resultConstraints{std::move(resultConstraints)},
384 resultVariadicity{std::move(resultVariadicity)},
385 attributeConstraints{std::move(attributeConstraints)}](
Operation *op) {
388 op, verifier, operandConstraints, operandVariadicity,
389 resultConstraints, resultVariadicity, attributeConstraints);
390 const LogicalResult opRegionVerifierResult =
392 return LogicalResult::success(opVerifierResult.succeeded() &&
393 opRegionVerifierResult.succeeded());
401 const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
402 const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
410 printer.printGenericOp(op);
419 auto regionVerifier = [](
Operation *op) {
return LogicalResult::success(); };
422 op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
423 std::move(parser), std::move(printer));
433 DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
434 DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
435 assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
436 "Expected an attribute or type definition");
441 if (isa<VerifyConstraintInterface>(op)) {
442 assert(op.getNumResults() == 1 &&
443 "IRDL constraint operations must have exactly one result");
444 constrToValue.push_back(op.getResult(0));
450 for (
Value v : constrToValue) {
451 VerifyConstraintInterface op =
452 cast<VerifyConstraintInterface>(v.getDefiningOp());
453 std::unique_ptr<Constraint> verifier =
454 op.getVerifier(constrToValue, types, attrs);
457 constraints.push_back(std::move(verifier));
461 std::optional<ParametersOp> params;
462 if (
auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
463 params = attr.getOp<ParametersOp>();
464 else if (
auto type = dyn_cast<TypeOp>(attrOrTypeDef))
465 params = type.getOp<ParametersOp>();
469 if (params.has_value()) {
470 paramConstraints.reserve(params->getArgs().size());
471 for (
Value param : params->getArgs()) {
472 for (
auto [i, constr] :
enumerate(constrToValue)) {
473 if (constr == param) {
474 paramConstraints.push_back(i);
481 auto verifier = [paramConstraints{std::move(paramConstraints)},
482 constraints{std::move(constraints)}](
491 return std::move(verifier);
507 if (
auto anyOf = dyn_cast<AnyOfOp>(op)) {
510 hasAny &=
getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
516 if (
auto allOf = dyn_cast<AllOfOp>(op))
517 return getBases(
allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
521 if (
auto params = dyn_cast<ParametricOp>(op)) {
522 SymbolRefAttr symRef = params.getBaseType();
524 assert(defOp &&
"symbol reference should refer to an existing operation");
525 paramIrdlOps.insert(defOp);
530 if (
auto is = dyn_cast<IsOp>(op)) {
538 if (
auto isA = dyn_cast<AnyOp>(op))
541 llvm_unreachable(
"unknown IRDL constraint");
568 if (
getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
574 for (
TypeID id : argParamIds) {
577 bool inserted = paramIds.insert(
id).second;
585 if (paramIds.count(
id))
595 bool inserted = paramIrdlOps.insert(op).second;
608 op.walk([&](DialectOp dialectOp) {
610 StringRef dialectName = dialectOp.getName();
615 dialects.insert({dialectOp, dialect});
626 op.walk([&](TypeOp typeOp) {
629 typeOp.getName(), dialect,
633 typeDefs.try_emplace(typeOp, std::move(typeDef));
644 op.walk([&](AttributeOp attrOp) {
647 attrOp.getName(), dialect,
651 attrDefs.try_emplace(attrOp, std::move(attrDef));
662 return op.emitError(
"any_of constraints are not in the correct form");
676 typeOp, dialects[typeOp.getParentOp()], types, attrs);
679 types[typeOp]->setVerifyFn(std::move(verifier));
686 res = op.walk([&](AttributeOp attrOp) {
688 attrOp, dialects[attrOp.getParentOp()], types, attrs);
691 attrs[attrOp]->setVerifyFn(std::move(verifier));
698 res = op.walk([&](OperationOp opOp) {
699 return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
705 for (
auto &pair : types) {
711 for (
auto &pair : attrs) {
static bool getBases(Operation *op, SmallPtrSet< TypeID, 4 > &paramIds, SmallPtrSet< Operation *, 4 > &paramIrdlOps, SmallPtrSet< TypeID, 4 > &isIds)
Get the possible bases of a constraint.
static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf)
Check that an any_of is in the subset IRDL can handle.
static DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > preallocateAttrDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate attribute definitions objects with empty verifiers.
LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Get the segment sizes of the given elements (operands or results) from the attribute dictionary.
static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(Operation *attrOrTypeDef, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Get the verifier of a type or attribute definition.
static LogicalResult irdlOpVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< size_t > operandConstrs, ArrayRef< Variadicity > operandVariadicity, ArrayRef< size_t > resultConstrs, ArrayRef< Variadicity > resultVariadicity, const DenseMap< StringAttr, size_t > &attributeConstrs)
Verify that the given operation satisfies the given constraints.
LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given operands.
static DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > preallocateTypeDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate type definitions objects with empty verifiers.
LogicalResult getSegmentSizes(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given element (operands, results).
static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect, const DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, const DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Define and load an operation represented by an irdl.operation operation.
static LogicalResult irdlAttrOrTypeVerifier(function_ref< InFlightDiagnostic()> emitError, ArrayRef< Attribute > params, ArrayRef< std::unique_ptr< Constraint >> constraints, ArrayRef< size_t > paramConstraints)
Verify that the given list of parameters satisfy the given constraints.
static LogicalResult irdlRegionVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< std::unique_ptr< RegionConstraint >> regionsConstraints)
static DenseMap< DialectOp, ExtensibleDialect * > loadEmptyDialects(ModuleOp op)
Load all dialects in the given module, without loading any operation, type or attribute definitions.
LogicalResult getResultSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given results.
Attributes are known-constant values of operations.
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
static std::unique_ptr< DynamicAttrDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new attribute definition at runtime.
llvm::unique_function< LogicalResult(function_ref< InFlightDiagnostic()>, ArrayRef< Attribute >) const > VerifierFn
A dialect that can be defined at runtime.
static std::unique_ptr< DynamicOpDefinition > get(StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn)
Create a new op at runtime.
static std::unique_ptr< DynamicTypeDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new dynamic type definition.
A dialect that can be extended with new operations/types/attributes at runtime.
void registerDynamicOp(std::unique_ptr< DynamicOpDefinition > &&type)
Add a new operation defined at runtime to the dialect.
void registerDynamicType(std::unique_ptr< DynamicTypeDefinition > &&type)
Add a new type defined at runtime to the dialect.
void registerDynamicAttr(std::unique_ptr< DynamicAttrDefinition > &&attr)
Add a new attribute defined at runtime to the dialect.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)
Get (or create) a dynamic dialect for the given name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
iterator_range< OpIterator > getOps()
This class provides an efficient unique identifier for a specific C++ type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Provides context to the verification of constraints.
LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Attribute attr, unsigned variable)
Check that a constraint is satisfied by an attribute.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::LogicalResult loadDialects(ModuleOp op)
Load all the dialects defined in the module.
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
llvm::unique_function< LogicalResult(Operation *) const > createVerifier(OperationOp operation, const DenseMap< irdl::TypeOp, std::unique_ptr< DynamicTypeDefinition >> &typeDefs, const DenseMap< irdl::AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrDefs)
Generate an op verifier function from the given IRDL operation definition.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.