MLIR  19.0.0git
IRDLLoading.cpp
Go to the documentation of this file.
1 //===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Manages the loading of MLIR objects from IRDL operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinOps.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/Support/SMLoc.h"
25 #include <numeric>
26 
27 using namespace mlir;
28 using namespace mlir::irdl;
29 
30 /// Verify that the given list of parameters satisfy the given constraints.
31 /// This encodes the logic of the verification method for attributes and types
32 /// defined with IRDL.
33 static LogicalResult
35  ArrayRef<Attribute> params,
36  ArrayRef<std::unique_ptr<Constraint>> constraints,
37  ArrayRef<size_t> paramConstraints) {
38  if (params.size() != paramConstraints.size()) {
39  emitError() << "expected " << paramConstraints.size()
40  << " type arguments, but had " << params.size();
41  return failure();
42  }
43 
44  ConstraintVerifier verifier(constraints);
45 
46  // Check that each parameter satisfies its constraint.
47  for (auto [i, param] : enumerate(params))
48  if (failed(verifier.verify(emitError, param, paramConstraints[i])))
49  return failure();
50 
51  return success();
52 }
53 
54 /// Get the operand segment sizes from the attribute dictionary.
56  StringRef attrName, unsigned numElements,
57  ArrayRef<Variadicity> variadicities,
58  SmallVectorImpl<int> &segmentSizes) {
59  // Get the segment sizes attribute, and check that it is of the right type.
60  Attribute segmentSizesAttr = op->getAttr(attrName);
61  if (!segmentSizesAttr) {
62  return op->emitError() << "'" << attrName
63  << "' attribute is expected but not provided";
64  }
65 
66  auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr);
67  if (!denseSegmentSizes) {
68  return op->emitError() << "'" << attrName
69  << "' attribute is expected to be a dense i32 array";
70  }
71 
72  if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {
73  return op->emitError() << "'" << attrName << "' attribute for specifying "
74  << elemName << " segments must have "
75  << variadicities.size() << " elements, but got "
76  << denseSegmentSizes.size();
77  }
78 
79  // Check that the segment sizes are corresponding to the given variadicities,
80  for (auto [i, segmentSize, variadicity] :
81  enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {
82  if (segmentSize < 0)
83  return op->emitError()
84  << "'" << attrName << "' attribute for specifying " << elemName
85  << " segments must have non-negative values";
86  if (variadicity == Variadicity::single && segmentSize != 1)
87  return op->emitError() << "element " << i << " in '" << attrName
88  << "' attribute must be equal to 1";
89 
90  if (variadicity == Variadicity::optional && segmentSize > 1)
91  return op->emitError() << "element " << i << " in '" << attrName
92  << "' attribute must be equal to 0 or 1";
93 
94  segmentSizes.push_back(segmentSize);
95  }
96 
97  // Check that the sum of the segment sizes is equal to the number of elements.
98  int32_t sum = 0;
99  for (int32_t segmentSize : denseSegmentSizes.asArrayRef())
100  sum += segmentSize;
101  if (sum != static_cast<int32_t>(numElements))
102  return op->emitError() << "sum of elements in '" << attrName
103  << "' attribute must be equal to the number of "
104  << elemName << "s";
105 
106  return success();
107 }
108 
109 /// Compute the segment sizes of the given element (operands, results).
110 /// If the operation has more than two non-single elements (optional or
111 /// variadic), then get the segment sizes from the attribute dictionary.
112 /// Otherwise, compute the segment sizes from the number of elements.
113 /// `elemName` should be either `"operand"` or `"result"`.
114 LogicalResult getSegmentSizes(Operation *op, StringRef elemName,
115  StringRef attrName, unsigned numElements,
116  ArrayRef<Variadicity> variadicities,
117  SmallVectorImpl<int> &segmentSizes) {
118  // If we have more than one non-single variadicity, we need to get the
119  // segment sizes from the attribute dictionary.
120  int numberNonSingle = count_if(
121  variadicities, [](Variadicity v) { return v != Variadicity::single; });
122  if (numberNonSingle > 1)
123  return getSegmentSizesFromAttr(op, elemName, attrName, numElements,
124  variadicities, segmentSizes);
125 
126  // If we only have single variadicities, the segments sizes are all 1.
127  if (numberNonSingle == 0) {
128  if (numElements != variadicities.size()) {
129  return op->emitError() << "op expects exactly " << variadicities.size()
130  << " " << elemName << "s, but got " << numElements;
131  }
132  for (size_t i = 0, e = variadicities.size(); i < e; ++i)
133  segmentSizes.push_back(1);
134  return success();
135  }
136 
137  assert(numberNonSingle == 1);
138 
139  // There is exactly one non-single element, so we can
140  // compute its size and check that it is valid.
141  int nonSingleSegmentSize = static_cast<int>(numElements) -
142  static_cast<int>(variadicities.size()) + 1;
143 
144  if (nonSingleSegmentSize < 0) {
145  return op->emitError() << "op expects at least " << variadicities.size() - 1
146  << " " << elemName << "s, but got " << numElements;
147  }
148 
149  // Add the segment sizes.
150  for (Variadicity variadicity : variadicities) {
151  if (variadicity == Variadicity::single) {
152  segmentSizes.push_back(1);
153  continue;
154  }
155 
156  // If we have an optional element, we should check that it represents
157  // zero or one elements.
158  if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)
159  return op->emitError() << "op expects at most " << variadicities.size()
160  << " " << elemName << "s, but got " << numElements;
161 
162  segmentSizes.push_back(nonSingleSegmentSize);
163  }
164 
165  return success();
166 }
167 
168 /// Compute the segment sizes of the given operands.
169 /// If the operation has more than two non-single operands (optional or
170 /// variadic), then get the segment sizes from the attribute dictionary.
171 /// Otherwise, compute the segment sizes from the number of operands.
173  ArrayRef<Variadicity> variadicities,
174  SmallVectorImpl<int> &segmentSizes) {
175  return getSegmentSizes(op, "operand", "operand_segment_sizes",
176  op->getNumOperands(), variadicities, segmentSizes);
177 }
178 
179 /// Compute the segment sizes of the given results.
180 /// If the operation has more than two non-single results (optional or
181 /// variadic), then get the segment sizes from the attribute dictionary.
182 /// Otherwise, compute the segment sizes from the number of results.
184  ArrayRef<Variadicity> variadicities,
185  SmallVectorImpl<int> &segmentSizes) {
186  return getSegmentSizes(op, "result", "result_segment_sizes",
187  op->getNumResults(), variadicities, segmentSizes);
188 }
189 
190 /// Verify that the given operation satisfies the given constraints.
191 /// This encodes the logic of the verification method for operations defined
192 /// with IRDL.
194  Operation *op, ConstraintVerifier &verifier,
195  ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity,
196  ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity,
197  const DenseMap<StringAttr, size_t> &attributeConstrs) {
198  // Get the segment sizes for the operands.
199  // This will check that the number of operands is correct.
200  SmallVector<int> operandSegmentSizes;
201  if (failed(
202  getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes)))
203  return failure();
204 
205  // Get the segment sizes for the results.
206  // This will check that the number of results is correct.
207  SmallVector<int> resultSegmentSizes;
208  if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes)))
209  return failure();
210 
211  auto emitError = [op] { return op->emitError(); };
212 
213  /// Сheck that we have all needed attributes passed
214  /// and they satisfy the constraints.
215  DictionaryAttr actualAttrs = op->getAttrDictionary();
216 
217  for (auto [name, constraint] : attributeConstrs) {
218  /// First, check if the attribute actually passed.
219  std::optional<NamedAttribute> actual = actualAttrs.getNamed(name);
220  if (!actual.has_value())
221  return op->emitOpError()
222  << "attribute " << name << " is expected but not provided";
223 
224  /// Then, check if the attribute value satisfies the constraint.
225  if (failed(verifier.verify({emitError}, actual->getValue(), constraint)))
226  return failure();
227  }
228 
229  // Check that all operands satisfy the constraints
230  int operandIdx = 0;
231  for (auto [defIndex, segmentSize] : enumerate(operandSegmentSizes)) {
232  for (int i = 0; i < segmentSize; i++) {
233  if (failed(verifier.verify(
234  {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]),
235  operandConstrs[defIndex])))
236  return failure();
237  ++operandIdx;
238  }
239  }
240 
241  // Check that all results satisfy the constraints
242  int resultIdx = 0;
243  for (auto [defIndex, segmentSize] : enumerate(resultSegmentSizes)) {
244  for (int i = 0; i < segmentSize; i++) {
245  if (failed(verifier.verify({emitError},
246  TypeAttr::get(op->getResultTypes()[resultIdx]),
247  resultConstrs[defIndex])))
248  return failure();
249  ++resultIdx;
250  }
251  }
252 
253  return success();
254 }
255 
257  Operation *op, ConstraintVerifier &verifier,
258  ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) {
259  if (op->getNumRegions() != regionsConstraints.size()) {
260  return op->emitOpError()
261  << "unexpected number of regions: expected "
262  << regionsConstraints.size() << " but got " << op->getNumRegions();
263  }
264 
265  for (auto [constraint, region] :
266  llvm::zip(regionsConstraints, op->getRegions()))
267  if (failed(constraint->verify(region, verifier)))
268  return failure();
269 
270  return success();
271 }
272 
273 /// Define and load an operation represented by an `irdl.operation`
274 /// operation.
///
/// Collects the constraint and region-constraint operations in the body of
/// `op`, builds a verifier for each, maps the operand, result and attribute
/// declarations to constraint slots, and registers a dynamic operation named
/// `op.getName()` in `dialect`.
///
/// Returns `WalkResult::advance()` once the operation is registered and
/// `WalkResult::interrupt()` when a constraint verifier cannot be built. The
/// `return op.emitError() << ...` paths rely on the implicit conversion of a
/// failed diagnostic to `WalkResult`, which presumably also interrupts the
/// walk (confirm against `WalkResult`'s constructors).
276  OperationOp op, ExtensibleDialect *dialect,
277  DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
278  DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
279  // Resolve SSA values to verifier constraint slots
// The index of a value in `constrToValue` is its constraint slot: verifiers
// are built in this order below, and operand/result/attribute declarations
// are mapped to slots by searching this vector. Region constraints are kept
// in `regionToValue` in body order; the i-th one checks the i-th region.
280  SmallVector<Value> constrToValue;
281  SmallVector<Value> regionToValue;
// NOTE(review): the loop variable `op` shadows the `OperationOp op` parameter.
// `op->getRegion(0)` in the range expression still refers to the parameter
// because it is evaluated before the loop variable is declared, but renaming
// the loop variable would make this clearer.
282  for (Operation &op : op->getRegion(0).getOps()) {
283  if (isa<VerifyConstraintInterface>(op)) {
284  if (op.getNumResults() != 1)
285  return op.emitError()
286  << "IRDL constraint operations must have exactly one result";
287  constrToValue.push_back(op.getResult(0));
288  }
289  if (isa<VerifyRegionInterface>(op)) {
290  if (op.getNumResults() != 1)
291  return op.emitError()
292  << "IRDL constraint operations must have exactly one result";
293  regionToValue.push_back(op.getResult(0));
294  }
295  }
296 
297  // Build the verifiers for each constraint slot
299  for (Value v : constrToValue) {
300  VerifyConstraintInterface op =
301  cast<VerifyConstraintInterface>(v.getDefiningOp());
302  std::unique_ptr<Constraint> verifier =
303  op.getVerifier(constrToValue, types, attrs);
304  if (!verifier)
305  return WalkResult::interrupt();
306  constraints.push_back(std::move(verifier));
307  }
308 
309  // Build region constraints
// NOTE(review): unlike the constraint verifiers above, a null region verifier
// is not rejected here. If `getVerifier` can return null, irdlRegionVerifier
// would later dereference it (`constraint->verify(...)`).
311  for (Value v : regionToValue) {
312  VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp());
313  std::unique_ptr<RegionConstraint> verifier =
314  op.getVerifier(constrToValue, types, attrs);
315  regionConstraints.push_back(std::move(verifier));
316  }
317 
318  SmallVector<size_t> operandConstraints;
319  SmallVector<Variadicity> operandVariadicity;
320 
321  // Gather which constraint slots correspond to operand constraints
// NOTE(review): here and for results/attributes below, a value that is not
// found in `constrToValue` is silently skipped, which would leave the
// constraint list shorter than the variadicity list.
322  auto operandsOp = op.getOp<OperandsOp>();
323  if (operandsOp.has_value()) {
324  operandConstraints.reserve(operandsOp->getArgs().size());
325  for (Value operand : operandsOp->getArgs()) {
326  for (auto [i, constr] : enumerate(constrToValue)) {
327  if (constr == operand) {
328  operandConstraints.push_back(i);
329  break;
330  }
331  }
332  }
333 
334  // Gather the variadicities of each operand
335  for (VariadicityAttr attr : operandsOp->getVariadicity())
336  operandVariadicity.push_back(attr.getValue());
337  }
338 
339  SmallVector<size_t> resultConstraints;
340  SmallVector<Variadicity> resultVariadicity;
341 
342  // Gather which constraint slots correspond to result constraints
343  auto resultsOp = op.getOp<ResultsOp>();
344  if (resultsOp.has_value()) {
345  resultConstraints.reserve(resultsOp->getArgs().size());
346  for (Value result : resultsOp->getArgs()) {
347  for (auto [i, constr] : enumerate(constrToValue)) {
348  if (constr == result) {
349  resultConstraints.push_back(i);
350  break;
351  }
352  }
353  }
354 
355  // Gather the variadicities of each result
356  for (Attribute attr : resultsOp->getVariadicity())
357  resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
358  }
359 
360  // Gather which constraint slots correspond to attributes constraints
// Keyed by attribute name; each value is the constraint slot the attribute
// value must satisfy.
361  DenseMap<StringAttr, size_t> attributesContraints;
362  auto attributesOp = op.getOp<AttributesOp>();
363  if (attributesOp.has_value()) {
364  const Operation::operand_range values = attributesOp->getAttributeValues();
365  const ArrayAttr names = attributesOp->getAttributeValueNames();
366 
367  for (const auto &[name, value] : llvm::zip(names, values)) {
368  for (auto [i, constr] : enumerate(constrToValue)) {
369  if (constr == value) {
370  attributesContraints[cast<StringAttr>(name)] = i;
371  break;
372  }
373  }
374  }
375  }
376 
377  // IRDL does not support defining custom parsers or printers.
// The parser always fails and the printer always emits the generic form.
378  auto parser = [](OpAsmParser &parser, OperationState &result) {
379  return failure();
380  };
381  auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
382  printer.printGenericOp(op);
383  };
384 
// The verifier owns all constraint data. A fresh ConstraintVerifier is
// created for every verified operation and shared by the operand/result/
// attribute checks and the region checks. Both checks always run and their
// results are combined, so diagnostics from both may be reported for a single
// invalid operation.
385  auto verifier =
386  [constraints{std::move(constraints)},
387  regionConstraints{std::move(regionConstraints)},
388  operandConstraints{std::move(operandConstraints)},
389  operandVariadicity{std::move(operandVariadicity)},
390  resultConstraints{std::move(resultConstraints)},
391  resultVariadicity{std::move(resultVariadicity)},
392  attributesContraints{std::move(attributesContraints)}](Operation *op) {
393  ConstraintVerifier verifier(constraints);
394  const LogicalResult opVerifierResult = irdlOpVerifier(
395  op, verifier, operandConstraints, operandVariadicity,
396  resultConstraints, resultVariadicity, attributesContraints);
397  const LogicalResult opRegionVerifierResult =
398  irdlRegionVerifier(op, verifier, regionConstraints);
399  return LogicalResult::success(opVerifierResult.succeeded() &&
400  opRegionVerifierResult.succeeded());
401  };
402 
403  // IRDL supports only checking number of blocks and argument constraints.
404  // It is done in the main verifier to reuse `ConstraintVerifier` context
405  auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
406 
407  auto opDef = DynamicOpDefinition::get(
408  op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
409  std::move(parser), std::move(printer));
410  dialect->registerDynamicOp(std::move(opDef));
411 
412  return WalkResult::advance();
413 }
414 
415 /// Get the verifier of a type or attribute definition.
416 /// Return nullptr if the definition is invalid.
///
/// `attrOrTypeDef` must be an `AttributeOp` or a `TypeOp`. The returned
/// verifier checks the parameters of an instance against the constraints
/// referenced by the definition's `ParametersOp`. If there is no
/// `ParametersOp`, the verifier expects zero parameters. The "nullptr" result
/// is an empty `VerifierFn` (`return {}`), returned when a constraint verifier
/// cannot be built.
/// NOTE(review): `dialect` is not used in the visible body.
418  Operation *attrOrTypeDef, ExtensibleDialect *dialect,
419  DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
420  DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
421  assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
422  "Expected an attribute or type definition");
423 
424  // Resolve SSA values to verifier constraint slots
425  SmallVector<Value> constrToValue;
426  for (Operation &op : attrOrTypeDef->getRegion(0).getOps()) {
427  if (isa<VerifyConstraintInterface>(op)) {
// NOTE(review): this is only an assert, whereas loadOperation reports the
// same condition as a diagnostic. In builds without assertions, a constraint
// op with no results would reach `getResult(0)`.
428  assert(op.getNumResults() == 1 &&
429  "IRDL constraint operations must have exactly one result");
430  constrToValue.push_back(op.getResult(0));
431  }
432  }
433 
434  // Build the verifiers for each constraint slot
436  for (Value v : constrToValue) {
437  VerifyConstraintInterface op =
438  cast<VerifyConstraintInterface>(v.getDefiningOp());
439  std::unique_ptr<Constraint> verifier =
440  op.getVerifier(constrToValue, types, attrs);
441  if (!verifier)
442  return {};
443  constraints.push_back(std::move(verifier));
444  }
445 
446  // Get the parameter definitions.
447  std::optional<ParametersOp> params;
448  if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
449  params = attr.getOp<ParametersOp>();
450  else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef))
451  params = type.getOp<ParametersOp>();
452 
453  // Gather which constraint slots correspond to parameter constraints
// A parameter value that is not found in `constrToValue` is silently skipped.
454  SmallVector<size_t> paramConstraints;
455  if (params.has_value()) {
456  paramConstraints.reserve(params->getArgs().size());
457  for (Value param : params->getArgs()) {
458  for (auto [i, constr] : enumerate(constrToValue)) {
459  if (constr == param) {
460  paramConstraints.push_back(i);
461  break;
462  }
463  }
464  }
465  }
466 
// The closure owns the constraint data so it remains valid after this
// function returns.
467  auto verifier = [paramConstraints{std::move(paramConstraints)},
468  constraints{std::move(constraints)}](
470  ArrayRef<Attribute> params) {
471  return irdlAttrOrTypeVerifier(emitError, params, constraints,
472  paramConstraints);
473  };
474 
475  // While the `std::move` is not required, not adding it triggers a bug in
476  // clang-10.
477  return std::move(verifier);
478 }
479 
480 /// Get the possible bases of a constraint. Return `true` if all bases can
481 /// potentially be matched.
482 /// A base is a type or an attribute definition. For instance, the base of
483 /// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`.
484 /// This function returns the following information through arguments:
485 /// - `paramIds`: the set of type or attribute IDs that are used as bases.
486 /// - `paramIrdlOps`: the set of IRDL operations that are used as bases.
487 /// - `isIds`: the set of type or attribute IDs that are used in `irdl.is`
488 /// constraints.
489 static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
490  SmallPtrSet<Operation *, 4> &paramIrdlOps,
491  SmallPtrSet<TypeID, 4> &isIds) {
492  // For `irdl.any_of`, we get the bases from all its arguments.
493  if (auto anyOf = dyn_cast<AnyOfOp>(op)) {
494  bool hasAny = false;
495  for (Value arg : anyOf.getArgs())
496  hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
497  return hasAny;
498  }
499 
500  // For `irdl.all_of`, we get the bases from the first argument.
501  // This is restrictive, but we can relax it later if needed.
502  if (auto allOf = dyn_cast<AllOfOp>(op))
503  return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
504  isIds);
505 
506  // For `irdl.parametric`, we get directly the base from the operation.
507  if (auto params = dyn_cast<ParametricOp>(op)) {
508  SymbolRefAttr symRef = params.getBaseType();
509  Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
510  assert(defOp && "symbol reference should refer to an existing operation");
511  paramIrdlOps.insert(defOp);
512  return false;
513  }
514 
515  // For `irdl.is`, we get the base TypeID directly.
516  if (auto is = dyn_cast<IsOp>(op)) {
517  Attribute expected = is.getExpected();
518  isIds.insert(expected.getTypeID());
519  return false;
520  }
521 
522  // For `irdl.any`, we return `false` since we can match any type or attribute
523  // base.
524  if (auto isA = dyn_cast<AnyOp>(op))
525  return true;
526 
527  llvm_unreachable("unknown IRDL constraint");
528 }
529 
530 /// Check that an any_of is in the subset IRDL can handle.
531 /// IRDL uses a greedy algorithm to match constraints. This means that if we
532 /// encounter an `any_of` with multiple constraints, we will match the first
533 /// constraint that is satisfied. Thus, the order of constraints matter in
534 /// `any_of` with our current algorithm.
535 /// In order to make the order of constraints irrelevant, we require that
536 /// all `any_of` constraint parameters are disjoint. For this, we check that
537 /// the base parameters are all disjoints between `parametric` operations, and
538 /// that they are disjoint between `parametric` and `is` operations.
539 /// This restriction will be relaxed in the future, when we will change our
540 /// algorithm to be non-greedy.
541 static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) {
542  SmallPtrSet<TypeID, 4> paramIds;
543  SmallPtrSet<Operation *, 4> paramIrdlOps;
545 
546  for (Value arg : anyOf.getArgs()) {
547  Operation *argOp = arg.getDefiningOp();
548  SmallPtrSet<TypeID, 4> argParamIds;
549  SmallPtrSet<Operation *, 4> argParamIrdlOps;
550  SmallPtrSet<TypeID, 4> argIsIds;
551 
552  // Get the bases of this argument. If it can match any type or attribute,
553  // then our `any_of` should not be allowed.
554  if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
555  return failure();
556 
557  // We check that the base parameters are all disjoints between `parametric`
558  // operations, and that they are disjoint between `parametric` and `is`
559  // operations.
560  for (TypeID id : argParamIds) {
561  if (isIds.count(id))
562  return failure();
563  bool inserted = paramIds.insert(id).second;
564  if (!inserted)
565  return failure();
566  }
567 
568  // We check that the base parameters are all disjoints with `irdl.is`
569  // operations.
570  for (TypeID id : isIds) {
571  if (paramIds.count(id))
572  return failure();
573  isIds.insert(id);
574  }
575 
576  // We check that all `parametric` operations are disjoint. We do not
577  // need to check that they are disjoint with `is` operations, since
578  // `is` operations cannot refer to attributes defined with `irdl.parametric`
579  // operations.
580  for (Operation *op : argParamIrdlOps) {
581  bool inserted = paramIrdlOps.insert(op).second;
582  if (!inserted)
583  return failure();
584  }
585  }
586 
587  return success();
588 }
589 
590 /// Load all dialects in the given module, without loading any operation, type
591 /// or attribute definitions.
///
/// Each `DialectOp` nested in `op` is mapped to the dynamic dialect of the same
/// name, obtained via `MLIRContext::getOrLoadDynamicDialect` with a no-op
/// initialization callback (that API gets or creates the dialect).
594  op.walk([&](DialectOp dialectOp) {
595  MLIRContext *ctx = dialectOp.getContext();
596  StringRef dialectName = dialectOp.getName();
597 
598  DynamicDialect *dialect = ctx->getOrLoadDynamicDialect(
599  dialectName, [](DynamicDialect *dialect) {});
600 
601  dialects.insert({dialectOp, dialect});
602  });
603  return dialects;
604 }
605 
606 /// Preallocate type definition objects with placeholder verifiers.
607 /// This in particular allocates a TypeID for each type definition.
///
/// The placeholder verifier accepts every parameter list (it just returns
/// success); `loadDialects` replaces it later through `setVerifyFn`.
/// NOTE(review): `dialects[...]` would default-insert a null dialect for a
/// parent that is not in the map.
612  op.walk([&](TypeOp typeOp) {
613  ExtensibleDialect *dialect = dialects[typeOp.getParentOp()];
614  auto typeDef = DynamicTypeDefinition::get(
615  typeOp.getName(), dialect,
617  return success();
618  });
619  typeDefs.try_emplace(typeOp, std::move(typeDef));
620  });
621  return typeDefs;
622 }
623 
624 /// Preallocate attribute definition objects with placeholder verifiers.
625 /// This in particular allocates a TypeID for each attribute definition.
///
/// The placeholder verifier accepts every parameter list (it just returns
/// success); `loadDialects` replaces it later through `setVerifyFn`.
/// NOTE(review): `dialects[...]` would default-insert a null dialect for a
/// parent that is not in the map.
630  op.walk([&](AttributeOp attrOp) {
631  ExtensibleDialect *dialect = dialects[attrOp.getParentOp()];
632  auto attrDef = DynamicAttrDefinition::get(
633  attrOp.getName(), dialect,
635  return success();
636  });
637  attrDefs.try_emplace(attrOp, std::move(attrDef));
638  });
639  return attrDefs;
640 }
641 
// Loading runs in phases: validate `any_of` constraints, preallocate dialects
// and type/attribute definitions, install the real type and attribute
// verifiers, define and register operations, and finally register the types
// and attributes with their dialects.
643  // First, check that all any_of constraints are in a correct form.
644  // This is to ensure we can do the verification correctly.
// A failed `checkCorrectAnyOf` converts to an interrupting WalkResult.
// NOTE(review): the diagnostic below is attached to `op`, not to the
// offending any_of.
645  WalkResult anyOfCorrects = op.walk(
646  [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); });
647  if (anyOfCorrects.wasInterrupted())
648  return op.emitError("any_of constraints are not in the correct form");
649 
650  // Preallocate all dialects, and type and attribute definitions.
651  // In particular, this allocates TypeIDs so types and attributes can have
652  // verifiers that refer to each other.
655  preallocateTypeDefs(op, dialects);
657  preallocateAttrDefs(op, dialects);
658 
659  // Set the verifier for types.
660  WalkResult res = op.walk([&](TypeOp typeOp) {
662  typeOp, dialects[typeOp.getParentOp()], types, attrs);
663  if (!verifier)
664  return WalkResult::interrupt();
665  types[typeOp]->setVerifyFn(std::move(verifier));
666  return WalkResult::advance();
667  });
668  if (res.wasInterrupted())
669  return failure();
670 
671  // Set the verifier for attributes.
672  res = op.walk([&](AttributeOp attrOp) {
674  attrOp, dialects[attrOp.getParentOp()], types, attrs);
675  if (!verifier)
676  return WalkResult::interrupt();
677  attrs[attrOp]->setVerifyFn(std::move(verifier));
678  return WalkResult::advance();
679  });
680  if (res.wasInterrupted())
681  return failure();
682 
683  // Define and load all operations.
// NOTE(review): operations are registered inside loadOperation as the walk
// proceeds. Nothing here unregisters them if a later OperationOp interrupts
// the walk.
684  res = op.walk([&](OperationOp opOp) {
685  return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
686  });
687  if (res.wasInterrupted())
688  return failure();
689 
690  // Load all types in their dialects.
// NOTE(review): DenseMap iteration order is unspecified, so types (and the
// attributes below) are not registered in their source order. Ownership of
// each definition is moved into its dialect, leaving null entries in the map.
691  for (auto &pair : types) {
692  ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
693  dialect->registerDynamicType(std::move(pair.second));
694  }
695 
696  // Load all attributes in their dialects.
697  for (auto &pair : attrs) {
698  ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
699  dialect->registerDynamicAttr(std::move(pair.second));
700  }
701 
702  return success();
703 }
static bool getBases(Operation *op, SmallPtrSet< TypeID, 4 > &paramIds, SmallPtrSet< Operation *, 4 > &paramIrdlOps, SmallPtrSet< TypeID, 4 > &isIds)
Get the possible bases of a constraint.
static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf)
Check that an any_of is in the subset IRDL can handle.
static DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > preallocateAttrDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate attribute definitions objects with empty verifiers.
LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Get the operand segment sizes from the attribute dictionary.
Definition: IRDLLoading.cpp:55
static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(Operation *attrOrTypeDef, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Get the verifier of a type or attribute definition.
static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Define and load an operation represented by a irdl.operation operation.
static LogicalResult irdlOpVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< size_t > operandConstrs, ArrayRef< Variadicity > operandVariadicity, ArrayRef< size_t > resultConstrs, ArrayRef< Variadicity > resultVariadicity, const DenseMap< StringAttr, size_t > &attributeConstrs)
Verify that the given operation satisfies the given constraints.
LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given operands.
static DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > preallocateTypeDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate type definitions objects with empty verifiers.
LogicalResult getSegmentSizes(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given element (operands, results).
static LogicalResult irdlAttrOrTypeVerifier(function_ref< InFlightDiagnostic()> emitError, ArrayRef< Attribute > params, ArrayRef< std::unique_ptr< Constraint >> constraints, ArrayRef< size_t > paramConstraints)
Verify that the given list of parameters satisfy the given constraints.
Definition: IRDLLoading.cpp:34
static LogicalResult irdlRegionVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< std::unique_ptr< RegionConstraint >> regionsConstraints)
static DenseMap< DialectOp, ExtensibleDialect * > loadEmptyDialects(ModuleOp op)
Load all dialects in the given module, without loading any operation, type or attribute definitions.
LogicalResult getResultSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given results.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
Definition: Attributes.h:65
static std::unique_ptr< DynamicAttrDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new attribute definition at runtime.
llvm::unique_function< LogicalResult(function_ref< InFlightDiagnostic()>, ArrayRef< Attribute >) const > VerifierFn
A dialect that can be defined at runtime.
static std::unique_ptr< DynamicOpDefinition > get(StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn)
Create a new op at runtime.
static std::unique_ptr< DynamicTypeDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new dynamic type definition.
A dialect that can be extended with new operations/types/attributes at runtime.
void registerDynamicOp(std::unique_ptr< DynamicOpDefinition > &&type)
Add a new operation defined at runtime to the dialect.
void registerDynamicType(std::unique_ptr< DynamicTypeDefinition > &&type)
Add a new type defined at runtime to the dialect.
void registerDynamicAttr(std::unique_ptr< DynamicAttrDefinition > &&attr)
Add a new attribute defined at runtime to the dialect.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)
Get (or create) a dynamic dialect for the given name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
iterator_range< OpIterator > getOps()
Definition: Region.h:172
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
Provides context to the verification of constraints.
Definition: IRDLVerifiers.h:38
LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Attribute attr, unsigned variable)
Check that a constraint is satisfied by an attribute.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult loadDialects(ModuleOp op)
Load all the dialects defined in the module.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
static LogicalResult success(bool isSuccess=true)
If isSuccess is true a success result is generated, otherwise a 'failure' result is generated.
Definition: LogicalResult.h:30
This represents an operation in an abstracted form, suitable for use with the builder APIs.