22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/Support/SMLoc.h"
36 ArrayRef<std::unique_ptr<Constraint>> constraints,
38 if (params.size() != paramConstraints.size()) {
39 emitError() <<
"expected " << paramConstraints.size()
40 <<
" type arguments, but had " << params.size();
56 StringRef attrName,
unsigned numElements,
61 if (!segmentSizesAttr) {
63 <<
"' attribute is expected but not provided";
66 auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr);
67 if (!denseSegmentSizes) {
69 <<
"' attribute is expected to be a dense i32 array";
72 if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {
73 return op->
emitError() <<
"'" << attrName <<
"' attribute for specifying "
74 << elemName <<
" segments must have "
75 << variadicities.size() <<
" elements, but got "
76 << denseSegmentSizes.size();
80 for (
auto [i, segmentSize, variadicity] :
81 enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {
84 <<
"'" << attrName <<
"' attribute for specifying " << elemName
85 <<
" segments must have non-negative values";
86 if (variadicity == Variadicity::single && segmentSize != 1)
87 return op->
emitError() <<
"element " << i <<
" in '" << attrName
88 <<
"' attribute must be equal to 1";
90 if (variadicity == Variadicity::optional && segmentSize > 1)
91 return op->
emitError() <<
"element " << i <<
" in '" << attrName
92 <<
"' attribute must be equal to 0 or 1";
94 segmentSizes.push_back(segmentSize);
99 for (int32_t segmentSize : denseSegmentSizes.asArrayRef())
101 if (sum !=
static_cast<int32_t
>(numElements))
102 return op->
emitError() <<
"sum of elements in '" << attrName
103 <<
"' attribute must be equal to the number of "
115 StringRef attrName,
unsigned numElements,
120 int numberNonSingle = count_if(
121 variadicities, [](Variadicity v) {
return v != Variadicity::single; });
122 if (numberNonSingle > 1)
124 variadicities, segmentSizes);
127 if (numberNonSingle == 0) {
128 if (numElements != variadicities.size()) {
129 return op->
emitError() <<
"op expects exactly " << variadicities.size()
130 <<
" " << elemName <<
"s, but got " << numElements;
132 for (
size_t i = 0, e = variadicities.size(); i < e; ++i)
133 segmentSizes.push_back(1);
137 assert(numberNonSingle == 1);
141 int nonSingleSegmentSize =
static_cast<int>(numElements) -
142 static_cast<int>(variadicities.size()) + 1;
144 if (nonSingleSegmentSize < 0) {
145 return op->
emitError() <<
"op expects at least " << variadicities.size() - 1
146 <<
" " << elemName <<
"s, but got " << numElements;
150 for (Variadicity variadicity : variadicities) {
151 if (variadicity == Variadicity::single) {
152 segmentSizes.push_back(1);
158 if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)
159 return op->
emitError() <<
"op expects at most " << variadicities.size()
160 <<
" " << elemName <<
"s, but got " << numElements;
162 segmentSizes.push_back(nonSingleSegmentSize);
217 for (
auto [name, constraint] : attributeConstrs) {
219 std::optional<NamedAttribute> actual = actualAttrs.getNamed(name);
220 if (!actual.has_value())
222 <<
"attribute " << name <<
" is expected but not provided";
225 if (
failed(verifier.
verify({emitError}, actual->getValue(), constraint)))
231 for (
auto [defIndex, segmentSize] :
enumerate(operandSegmentSizes)) {
232 for (
int i = 0; i < segmentSize; i++) {
235 operandConstrs[defIndex])))
243 for (
auto [defIndex, segmentSize] :
enumerate(resultSegmentSizes)) {
244 for (
int i = 0; i < segmentSize; i++) {
247 resultConstrs[defIndex])))
258 ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) {
261 <<
"unexpected number of regions: expected "
262 << regionsConstraints.size() <<
" but got " << op->
getNumRegions();
265 for (
auto [constraint, region] :
266 llvm::zip(regionsConstraints, op->
getRegions()))
267 if (
failed(constraint->verify(region, verifier)))
277 DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
278 DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
283 if (isa<VerifyConstraintInterface>(op)) {
286 <<
"IRDL constraint operations must have exactly one result";
287 constrToValue.push_back(op.
getResult(0));
289 if (isa<VerifyRegionInterface>(op)) {
292 <<
"IRDL constraint operations must have exactly one result";
293 regionToValue.push_back(op.
getResult(0));
299 for (
Value v : constrToValue) {
300 VerifyConstraintInterface op =
301 cast<VerifyConstraintInterface>(v.getDefiningOp());
302 std::unique_ptr<Constraint> verifier =
303 op.getVerifier(constrToValue, types, attrs);
306 constraints.push_back(std::move(verifier));
311 for (
Value v : regionToValue) {
312 VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp());
313 std::unique_ptr<RegionConstraint> verifier =
314 op.getVerifier(constrToValue, types, attrs);
315 regionConstraints.push_back(std::move(verifier));
322 auto operandsOp = op.getOp<OperandsOp>();
323 if (operandsOp.has_value()) {
324 operandConstraints.reserve(operandsOp->getArgs().size());
325 for (
Value operand : operandsOp->getArgs()) {
326 for (
auto [i, constr] :
enumerate(constrToValue)) {
327 if (constr == operand) {
328 operandConstraints.push_back(i);
335 for (VariadicityAttr attr : operandsOp->getVariadicity())
336 operandVariadicity.push_back(attr.getValue());
343 auto resultsOp = op.getOp<ResultsOp>();
344 if (resultsOp.has_value()) {
345 resultConstraints.reserve(resultsOp->getArgs().size());
346 for (
Value result : resultsOp->getArgs()) {
347 for (
auto [i, constr] :
enumerate(constrToValue)) {
348 if (constr == result) {
349 resultConstraints.push_back(i);
356 for (
Attribute attr : resultsOp->getVariadicity())
357 resultVariadicity.push_back(attr.cast<VariadicityAttr>().getValue());
362 auto attributesOp = op.getOp<AttributesOp>();
363 if (attributesOp.has_value()) {
365 const ArrayAttr names = attributesOp->getAttributeValueNames();
367 for (
const auto &[name, value] : llvm::zip(names, values)) {
368 for (
auto [i, constr] :
enumerate(constrToValue)) {
369 if (constr == value) {
370 attributesContraints[name.cast<StringAttr>()] = i;
382 printer.printGenericOp(op);
386 [constraints{std::move(constraints)},
387 regionConstraints{std::move(regionConstraints)},
388 operandConstraints{std::move(operandConstraints)},
389 operandVariadicity{std::move(operandVariadicity)},
390 resultConstraints{std::move(resultConstraints)},
391 resultVariadicity{std::move(resultVariadicity)},
392 attributesContraints{std::move(attributesContraints)}](
Operation *op) {
395 op, verifier, operandConstraints, operandVariadicity,
396 resultConstraints, resultVariadicity, attributesContraints);
408 op.
getName(), dialect, std::move(verifier), std::move(regionVerifier),
409 std::move(parser), std::move(printer));
419 DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
420 DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
421 assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
422 "Expected an attribute or type definition");
427 if (isa<VerifyConstraintInterface>(op)) {
429 "IRDL constraint operations must have exactly one result");
430 constrToValue.push_back(op.
getResult(0));
436 for (
Value v : constrToValue) {
437 VerifyConstraintInterface op =
438 cast<VerifyConstraintInterface>(v.getDefiningOp());
439 std::unique_ptr<Constraint> verifier =
440 op.getVerifier(constrToValue, types, attrs);
443 constraints.push_back(std::move(verifier));
447 std::optional<ParametersOp> params;
448 if (
auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
449 params = attr.getOp<ParametersOp>();
450 else if (
auto type = dyn_cast<TypeOp>(attrOrTypeDef))
451 params = type.getOp<ParametersOp>();
455 if (params.has_value()) {
456 paramConstraints.reserve(params->getArgs().size());
457 for (
Value param : params->getArgs()) {
458 for (
auto [i, constr] :
enumerate(constrToValue)) {
459 if (constr == param) {
460 paramConstraints.push_back(i);
467 auto verifier = [paramConstraints{std::move(paramConstraints)},
468 constraints{std::move(constraints)}](
477 return std::move(verifier);
493 if (
auto anyOf = dyn_cast<AnyOfOp>(op)) {
494 bool has_any =
false;
495 for (
Value arg : anyOf.getArgs())
496 has_any &=
getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
502 if (
auto allOf = dyn_cast<AllOfOp>(op))
503 return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
507 if (
auto params = dyn_cast<ParametricOp>(op)) {
508 SymbolRefAttr symRef = params.getBaseType();
510 assert(defOp &&
"symbol reference should refer to an existing operation");
511 paramIrdlOps.insert(defOp);
516 if (
auto is = dyn_cast<IsOp>(op)) {
524 if (
auto isA = dyn_cast<AnyOp>(op))
527 llvm_unreachable(
"unknown IRDL constraint");
546 for (
Value arg : anyOf.getArgs()) {
554 if (
getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
560 for (
TypeID id : argParamIds) {
563 bool inserted = paramIds.insert(
id).second;
571 if (paramIds.count(
id))
581 bool inserted = paramIrdlOps.insert(op).second;
594 op.
walk([&](DialectOp dialectOp) {
596 StringRef dialectName = dialectOp.getName();
601 dialects.insert({dialectOp, dialect});
612 op.
walk([&](TypeOp typeOp) {
615 typeOp.getName(), dialect,
619 typeDefs.try_emplace(typeOp, std::move(typeDef));
630 op.
walk([&](AttributeOp attrOp) {
633 attrOp.getName(), dialect,
637 attrDefs.try_emplace(attrOp, std::move(attrDef));
648 return op.
emitError(
"any_of constraints are not in the correct form");
662 typeOp, dialects[typeOp.getParentOp()], types, attrs);
665 types[typeOp]->setVerifyFn(std::move(verifier));
672 res = op.
walk([&](AttributeOp attrOp) {
674 attrOp, dialects[attrOp.getParentOp()], types, attrs);
677 attrs[attrOp]->setVerifyFn(std::move(verifier));
684 res = op.
walk([&](OperationOp opOp) {
685 return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
691 for (
auto &pair : types) {
697 for (
auto &pair : attrs) {
static bool getBases(Operation *op, SmallPtrSet< TypeID, 4 > ¶mIds, SmallPtrSet< Operation *, 4 > ¶mIrdlOps, SmallPtrSet< TypeID, 4 > &isIds)
Get the possible bases of a constraint.
static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf)
Check that an any_of is in the subset IRDL can handle.
static DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > preallocateAttrDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate attribute definitions objects with empty verifiers.
LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Get the operand segment sizes from the attribute dictionary.
static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(Operation *attrOrTypeDef, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Get the verifier of a type or attribute definition.
static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Define and load an operation represented by a irdl.operation operation.
static LogicalResult irdlOpVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< size_t > operandConstrs, ArrayRef< Variadicity > operandVariadicity, ArrayRef< size_t > resultConstrs, ArrayRef< Variadicity > resultVariadicity, const DenseMap< StringAttr, size_t > &attributeConstrs)
Verify that the given operation satisfies the given constraints.
LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given operands.
static DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > preallocateTypeDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate type definitions objects with empty verifiers.
LogicalResult getSegmentSizes(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given element (operands, results).
static LogicalResult irdlAttrOrTypeVerifier(function_ref< InFlightDiagnostic()> emitError, ArrayRef< Attribute > params, ArrayRef< std::unique_ptr< Constraint >> constraints, ArrayRef< size_t > paramConstraints)
Verify that the given list of parameters satisfy the given constraints.
static LogicalResult irdlRegionVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< std::unique_ptr< RegionConstraint >> regionsConstraints)
static DenseMap< DialectOp, ExtensibleDialect * > loadEmptyDialects(ModuleOp op)
Load all dialects in the given module, without loading any operation, type or attribute definitions.
LogicalResult getResultSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given results.
Attributes are known-constant values of operations.
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
static std::unique_ptr< DynamicAttrDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new attribute definition at runtime.
llvm::unique_function< LogicalResult(function_ref< InFlightDiagnostic()>, ArrayRef< Attribute >) const > VerifierFn
A dialect that can be defined at runtime.
static std::unique_ptr< DynamicOpDefinition > get(StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn)
Create a new op at runtime.
static std::unique_ptr< DynamicTypeDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new dynamic type definition.
A dialect that can be extended with new operations/types/attributes at runtime.
void registerDynamicOp(std::unique_ptr< DynamicOpDefinition > &&type)
Add a new operation defined at runtime to the dialect.
void registerDynamicType(std::unique_ptr< DynamicTypeDefinition > &&type)
Add a new type defined at runtime to the dialect.
void registerDynamicAttr(std::unique_ptr< DynamicAttrDefinition > &&attr)
Add a new attribute defined at runtime to the dialect.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)
Get (or create) a dynamic dialect for the given name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
unsigned getNumRegions()
Returns the number of regions held by this operation.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
iterator_range< OpIterator > getOps()
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an efficient unique identifier for a specific C++ type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Provides context to the verification of constraints.
LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Attribute attr, unsigned variable)
Check that a constraint is satisfied by an attribute.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult loadDialects(ModuleOp op)
Load all the dialects defined in the module.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
static LogicalResult success(bool isSuccess=true)
If isSuccess is true a success result is generated, otherwise a 'failure' result is generated.
This represents an operation in an abstracted form, suitable for use with the builder APIs.