MLIR  22.0.0git
IRDLLoading.cpp
Go to the documentation of this file.
1 //===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Manages the loading of MLIR objects from IRDL operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinOps.h"
22 #include "llvm/ADT/STLExtras.h"
23 
24 using namespace mlir;
25 using namespace mlir::irdl;
26 
27 /// Verify that the given list of parameters satisfy the given constraints.
28 /// This encodes the logic of the verification method for attributes and types
29 /// defined with IRDL.
30 static LogicalResult
32  ArrayRef<Attribute> params,
33  ArrayRef<std::unique_ptr<Constraint>> constraints,
34  ArrayRef<size_t> paramConstraints) {
35  if (params.size() != paramConstraints.size()) {
36  emitError() << "expected " << paramConstraints.size()
37  << " type arguments, but had " << params.size();
38  return failure();
39  }
40 
41  ConstraintVerifier verifier(constraints);
42 
43  // Check that each parameter satisfies its constraint.
44  for (auto [i, param] : enumerate(params))
45  if (failed(verifier.verify(emitError, param, paramConstraints[i])))
46  return failure();
47 
48  return success();
49 }
50 
51 /// Get the operand segment sizes from the attribute dictionary.
52 LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName,
53  StringRef attrName, unsigned numElements,
54  ArrayRef<Variadicity> variadicities,
55  SmallVectorImpl<int> &segmentSizes) {
56  // Get the segment sizes attribute, and check that it is of the right type.
57  Attribute segmentSizesAttr = op->getAttr(attrName);
58  if (!segmentSizesAttr) {
59  return op->emitError() << "'" << attrName
60  << "' attribute is expected but not provided";
61  }
62 
63  auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr);
64  if (!denseSegmentSizes) {
65  return op->emitError() << "'" << attrName
66  << "' attribute is expected to be a dense i32 array";
67  }
68 
69  if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {
70  return op->emitError() << "'" << attrName << "' attribute for specifying "
71  << elemName << " segments must have "
72  << variadicities.size() << " elements, but got "
73  << denseSegmentSizes.size();
74  }
75 
76  // Check that the segment sizes are corresponding to the given variadicities,
77  for (auto [i, segmentSize, variadicity] :
78  enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {
79  if (segmentSize < 0)
80  return op->emitError()
81  << "'" << attrName << "' attribute for specifying " << elemName
82  << " segments must have non-negative values";
83  if (variadicity == Variadicity::single && segmentSize != 1)
84  return op->emitError() << "element " << i << " in '" << attrName
85  << "' attribute must be equal to 1";
86 
87  if (variadicity == Variadicity::optional && segmentSize > 1)
88  return op->emitError() << "element " << i << " in '" << attrName
89  << "' attribute must be equal to 0 or 1";
90 
91  segmentSizes.push_back(segmentSize);
92  }
93 
94  // Check that the sum of the segment sizes is equal to the number of elements.
95  int32_t sum = 0;
96  for (int32_t segmentSize : denseSegmentSizes.asArrayRef())
97  sum += segmentSize;
98  if (sum != static_cast<int32_t>(numElements))
99  return op->emitError() << "sum of elements in '" << attrName
100  << "' attribute must be equal to the number of "
101  << elemName << "s";
102 
103  return success();
104 }
105 
106 /// Compute the segment sizes of the given element (operands, results).
107 /// If the operation has more than two non-single elements (optional or
108 /// variadic), then get the segment sizes from the attribute dictionary.
109 /// Otherwise, compute the segment sizes from the number of elements.
110 /// `elemName` should be either `"operand"` or `"result"`.
111 LogicalResult getSegmentSizes(Operation *op, StringRef elemName,
112  StringRef attrName, unsigned numElements,
113  ArrayRef<Variadicity> variadicities,
114  SmallVectorImpl<int> &segmentSizes) {
115  // If we have more than one non-single variadicity, we need to get the
116  // segment sizes from the attribute dictionary.
117  int numberNonSingle = count_if(
118  variadicities, [](Variadicity v) { return v != Variadicity::single; });
119  if (numberNonSingle > 1)
120  return getSegmentSizesFromAttr(op, elemName, attrName, numElements,
121  variadicities, segmentSizes);
122 
123  // If we only have single variadicities, the segments sizes are all 1.
124  if (numberNonSingle == 0) {
125  if (numElements != variadicities.size()) {
126  return op->emitError() << "op expects exactly " << variadicities.size()
127  << " " << elemName << "s, but got " << numElements;
128  }
129  for (size_t i = 0, e = variadicities.size(); i < e; ++i)
130  segmentSizes.push_back(1);
131  return success();
132  }
133 
134  assert(numberNonSingle == 1);
135 
136  // There is exactly one non-single element, so we can
137  // compute its size and check that it is valid.
138  int nonSingleSegmentSize = static_cast<int>(numElements) -
139  static_cast<int>(variadicities.size()) + 1;
140 
141  if (nonSingleSegmentSize < 0) {
142  return op->emitError() << "op expects at least " << variadicities.size() - 1
143  << " " << elemName << "s, but got " << numElements;
144  }
145 
146  // Add the segment sizes.
147  for (Variadicity variadicity : variadicities) {
148  if (variadicity == Variadicity::single) {
149  segmentSizes.push_back(1);
150  continue;
151  }
152 
153  // If we have an optional element, we should check that it represents
154  // zero or one elements.
155  if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)
156  return op->emitError() << "op expects at most " << variadicities.size()
157  << " " << elemName << "s, but got " << numElements;
158 
159  segmentSizes.push_back(nonSingleSegmentSize);
160  }
161 
162  return success();
163 }
164 
165 /// Compute the segment sizes of the given operands.
166 /// If the operation has more than two non-single operands (optional or
167 /// variadic), then get the segment sizes from the attribute dictionary.
168 /// Otherwise, compute the segment sizes from the number of operands.
170  ArrayRef<Variadicity> variadicities,
171  SmallVectorImpl<int> &segmentSizes) {
172  return getSegmentSizes(op, "operand", "operand_segment_sizes",
173  op->getNumOperands(), variadicities, segmentSizes);
174 }
175 
176 /// Compute the segment sizes of the given results.
177 /// If the operation has more than two non-single results (optional or
178 /// variadic), then get the segment sizes from the attribute dictionary.
179 /// Otherwise, compute the segment sizes from the number of results.
181  ArrayRef<Variadicity> variadicities,
182  SmallVectorImpl<int> &segmentSizes) {
183  return getSegmentSizes(op, "result", "result_segment_sizes",
184  op->getNumResults(), variadicities, segmentSizes);
185 }
186 
187 /// Verify that the given operation satisfies the given constraints.
188 /// This encodes the logic of the verification method for operations defined
189 /// with IRDL.
190 static LogicalResult irdlOpVerifier(
191  Operation *op, ConstraintVerifier &verifier,
192  ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity,
193  ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity,
194  const DenseMap<StringAttr, size_t> &attributeConstrs) {
195  // Get the segment sizes for the operands.
196  // This will check that the number of operands is correct.
197  SmallVector<int> operandSegmentSizes;
198  if (failed(
199  getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes)))
200  return failure();
201 
202  // Get the segment sizes for the results.
203  // This will check that the number of results is correct.
204  SmallVector<int> resultSegmentSizes;
205  if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes)))
206  return failure();
207 
208  auto emitError = [op] { return op->emitError(); };
209 
210  /// Сheck that we have all needed attributes passed
211  /// and they satisfy the constraints.
212  DictionaryAttr actualAttrs = op->getAttrDictionary();
213 
214  for (auto [name, constraint] : attributeConstrs) {
215  /// First, check if the attribute actually passed.
216  std::optional<NamedAttribute> actual = actualAttrs.getNamed(name);
217  if (!actual.has_value())
218  return op->emitOpError()
219  << "attribute " << name << " is expected but not provided";
220 
221  /// Then, check if the attribute value satisfies the constraint.
222  if (failed(verifier.verify({emitError}, actual->getValue(), constraint)))
223  return failure();
224  }
225 
226  // Check that all operands satisfy the constraints
227  int operandIdx = 0;
228  for (auto [defIndex, segmentSize] : enumerate(operandSegmentSizes)) {
229  for (int i = 0; i < segmentSize; i++) {
230  if (failed(verifier.verify(
231  {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]),
232  operandConstrs[defIndex])))
233  return failure();
234  ++operandIdx;
235  }
236  }
237 
238  // Check that all results satisfy the constraints
239  int resultIdx = 0;
240  for (auto [defIndex, segmentSize] : enumerate(resultSegmentSizes)) {
241  for (int i = 0; i < segmentSize; i++) {
242  if (failed(verifier.verify({emitError},
243  TypeAttr::get(op->getResultTypes()[resultIdx]),
244  resultConstrs[defIndex])))
245  return failure();
246  ++resultIdx;
247  }
248  }
249 
250  return success();
251 }
252 
253 static LogicalResult irdlRegionVerifier(
254  Operation *op, ConstraintVerifier &verifier,
255  ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) {
256  if (op->getNumRegions() != regionsConstraints.size()) {
257  return op->emitOpError()
258  << "unexpected number of regions: expected "
259  << regionsConstraints.size() << " but got " << op->getNumRegions();
260  }
261 
262  for (auto [constraint, region] :
263  llvm::zip(regionsConstraints, op->getRegions()))
264  if (failed(constraint->verify(region, verifier)))
265  return failure();
266 
267  return success();
268 }
269 
270 llvm::unique_function<LogicalResult(Operation *) const>
272  OperationOp op,
273  const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
274  const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
275  &attrs) {
276  // Resolve SSA values to verifier constraint slots
277  SmallVector<Value> constrToValue;
278  SmallVector<Value> regionToValue;
279  for (Operation &op : op->getRegion(0).getOps()) {
280  if (isa<VerifyConstraintInterface>(op)) {
281  if (op.getNumResults() != 1) {
282  op.emitError()
283  << "IRDL constraint operations must have exactly one result";
284  return nullptr;
285  }
286  constrToValue.push_back(op.getResult(0));
287  }
288  if (isa<VerifyRegionInterface>(op)) {
289  if (op.getNumResults() != 1) {
290  op.emitError()
291  << "IRDL constraint operations must have exactly one result";
292  return nullptr;
293  }
294  regionToValue.push_back(op.getResult(0));
295  }
296  }
297 
298  // Build the verifiers for each constraint slot
300  for (Value v : constrToValue) {
301  VerifyConstraintInterface op =
302  cast<VerifyConstraintInterface>(v.getDefiningOp());
303  std::unique_ptr<Constraint> verifier =
304  op.getVerifier(constrToValue, types, attrs);
305  if (!verifier)
306  return nullptr;
307  constraints.push_back(std::move(verifier));
308  }
309 
310  // Build region constraints
312  for (Value v : regionToValue) {
313  VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp());
314  std::unique_ptr<RegionConstraint> verifier =
315  op.getVerifier(constrToValue, types, attrs);
316  regionConstraints.push_back(std::move(verifier));
317  }
318 
319  SmallVector<size_t> operandConstraints;
320  SmallVector<Variadicity> operandVariadicity;
321 
322  // Gather which constraint slots correspond to operand constraints
323  auto operandsOp = op.getOp<OperandsOp>();
324  if (operandsOp.has_value()) {
325  operandConstraints.reserve(operandsOp->getArgs().size());
326  for (Value operand : operandsOp->getArgs()) {
327  for (auto [i, constr] : enumerate(constrToValue)) {
328  if (constr == operand) {
329  operandConstraints.push_back(i);
330  break;
331  }
332  }
333  }
334 
335  // Gather the variadicities of each operand
336  for (VariadicityAttr attr : operandsOp->getVariadicity())
337  operandVariadicity.push_back(attr.getValue());
338  }
339 
340  SmallVector<size_t> resultConstraints;
341  SmallVector<Variadicity> resultVariadicity;
342 
343  // Gather which constraint slots correspond to result constraints
344  auto resultsOp = op.getOp<ResultsOp>();
345  if (resultsOp.has_value()) {
346  resultConstraints.reserve(resultsOp->getArgs().size());
347  for (Value result : resultsOp->getArgs()) {
348  for (auto [i, constr] : enumerate(constrToValue)) {
349  if (constr == result) {
350  resultConstraints.push_back(i);
351  break;
352  }
353  }
354  }
355 
356  // Gather the variadicities of each result
357  for (Attribute attr : resultsOp->getVariadicity())
358  resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
359  }
360 
361  // Gather which constraint slots correspond to attributes constraints
362  DenseMap<StringAttr, size_t> attributeConstraints;
363  auto attributesOp = op.getOp<AttributesOp>();
364  if (attributesOp.has_value()) {
365  const Operation::operand_range values = attributesOp->getAttributeValues();
366  const ArrayAttr names = attributesOp->getAttributeValueNames();
367 
368  for (const auto &[name, value] : llvm::zip(names, values)) {
369  for (auto [i, constr] : enumerate(constrToValue)) {
370  if (constr == value) {
371  attributeConstraints[cast<StringAttr>(name)] = i;
372  break;
373  }
374  }
375  }
376  }
377 
378  return
379  [constraints{std::move(constraints)},
380  regionConstraints{std::move(regionConstraints)},
381  operandConstraints{std::move(operandConstraints)},
382  operandVariadicity{std::move(operandVariadicity)},
383  resultConstraints{std::move(resultConstraints)},
384  resultVariadicity{std::move(resultVariadicity)},
385  attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
386  ConstraintVerifier verifier(constraints);
387  const LogicalResult opVerifierResult = irdlOpVerifier(
388  op, verifier, operandConstraints, operandVariadicity,
389  resultConstraints, resultVariadicity, attributeConstraints);
390  const LogicalResult opRegionVerifierResult =
391  irdlRegionVerifier(op, verifier, regionConstraints);
392  return LogicalResult::success(opVerifierResult.succeeded() &&
393  opRegionVerifierResult.succeeded());
394  };
395 }
396 
397 /// Define and load an operation represented by a `irdl.operation`
398 /// operation.
400  OperationOp op, ExtensibleDialect *dialect,
401  const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
402  const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
403  &attrs) {
404 
405  // IRDL does not support defining custom parsers or printers.
406  auto parser = [](OpAsmParser &parser, OperationState &result) {
407  return failure();
408  };
409  auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
410  printer.printGenericOp(op);
411  };
412 
413  auto verifier = createVerifier(op, types, attrs);
414  if (!verifier)
415  return WalkResult::interrupt();
416 
417  // IRDL supports only checking number of blocks and argument constraints
418  // It is done in the main verifier to reuse `ConstraintVerifier` context
419  auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
420 
421  auto opDef = DynamicOpDefinition::get(
422  op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
423  std::move(parser), std::move(printer));
424  dialect->registerDynamicOp(std::move(opDef));
425 
426  return WalkResult::advance();
427 }
428 
429 /// Get the verifier of a type or attribute definition.
430 /// Return nullptr if the definition is invalid.
432  Operation *attrOrTypeDef, ExtensibleDialect *dialect,
433  DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
434  DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
435  assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
436  "Expected an attribute or type definition");
437 
438  // Resolve SSA values to verifier constraint slots
439  SmallVector<Value> constrToValue;
440  for (Operation &op : attrOrTypeDef->getRegion(0).getOps()) {
441  if (isa<VerifyConstraintInterface>(op)) {
442  assert(op.getNumResults() == 1 &&
443  "IRDL constraint operations must have exactly one result");
444  constrToValue.push_back(op.getResult(0));
445  }
446  }
447 
448  // Build the verifiers for each constraint slot
450  for (Value v : constrToValue) {
451  VerifyConstraintInterface op =
452  cast<VerifyConstraintInterface>(v.getDefiningOp());
453  std::unique_ptr<Constraint> verifier =
454  op.getVerifier(constrToValue, types, attrs);
455  if (!verifier)
456  return {};
457  constraints.push_back(std::move(verifier));
458  }
459 
460  // Get the parameter definitions.
461  std::optional<ParametersOp> params;
462  if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
463  params = attr.getOp<ParametersOp>();
464  else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef))
465  params = type.getOp<ParametersOp>();
466 
467  // Gather which constraint slots correspond to parameter constraints
468  SmallVector<size_t> paramConstraints;
469  if (params.has_value()) {
470  paramConstraints.reserve(params->getArgs().size());
471  for (Value param : params->getArgs()) {
472  for (auto [i, constr] : enumerate(constrToValue)) {
473  if (constr == param) {
474  paramConstraints.push_back(i);
475  break;
476  }
477  }
478  }
479  }
480 
481  auto verifier = [paramConstraints{std::move(paramConstraints)},
482  constraints{std::move(constraints)}](
484  ArrayRef<Attribute> params) {
485  return irdlAttrOrTypeVerifier(emitError, params, constraints,
486  paramConstraints);
487  };
488 
489  // While the `std::move` is not required, not adding it triggers a bug in
490  // clang-10.
491  return std::move(verifier);
492 }
493 
494 /// Get the possible bases of a constraint. Return `true` if all bases can
495 /// potentially be matched.
496 /// A base is a type or an attribute definition. For instance, the base of
497 /// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`.
498 /// This function returns the following information through arguments:
499 /// - `paramIds`: the set of type or attribute IDs that are used as bases.
500 /// - `paramIrdlOps`: the set of IRDL operations that are used as bases.
501 /// - `isIds`: the set of type or attribute IDs that are used in `irdl.is`
502 /// constraints.
503 static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
504  SmallPtrSet<Operation *, 4> &paramIrdlOps,
505  SmallPtrSet<TypeID, 4> &isIds) {
506  // For `irdl.any_of`, we get the bases from all its arguments.
507  if (auto anyOf = dyn_cast<AnyOfOp>(op)) {
508  bool hasAny = false;
509  for (Value arg : anyOf.getArgs())
510  hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
511  return hasAny;
512  }
513 
514  // For `irdl.all_of`, we get the bases from the first argument.
515  // This is restrictive, but we can relax it later if needed.
516  if (auto allOf = dyn_cast<AllOfOp>(op))
517  return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
518  isIds);
519 
520  // For `irdl.parametric`, we get directly the base from the operation.
521  if (auto params = dyn_cast<ParametricOp>(op)) {
522  SymbolRefAttr symRef = params.getBaseType();
523  Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
524  assert(defOp && "symbol reference should refer to an existing operation");
525  paramIrdlOps.insert(defOp);
526  return false;
527  }
528 
529  // For `irdl.is`, we get the base TypeID directly.
530  if (auto is = dyn_cast<IsOp>(op)) {
531  Attribute expected = is.getExpected();
532  isIds.insert(expected.getTypeID());
533  return false;
534  }
535 
536  // For `irdl.any`, we return `false` since we can match any type or attribute
537  // base.
538  if (auto isA = dyn_cast<AnyOp>(op))
539  return true;
540 
541  llvm_unreachable("unknown IRDL constraint");
542 }
543 
544 /// Check that an any_of is in the subset IRDL can handle.
545 /// IRDL uses a greedy algorithm to match constraints. This means that if we
546 /// encounter an `any_of` with multiple constraints, we will match the first
547 /// constraint that is satisfied. Thus, the order of constraints matter in
548 /// `any_of` with our current algorithm.
549 /// In order to make the order of constraints irrelevant, we require that
550 /// all `any_of` constraint parameters are disjoint. For this, we check that
551 /// the base parameters are all disjoints between `parametric` operations, and
552 /// that they are disjoint between `parametric` and `is` operations.
553 /// This restriction will be relaxed in the future, when we will change our
554 /// algorithm to be non-greedy.
555 static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) {
556  SmallPtrSet<TypeID, 4> paramIds;
557  SmallPtrSet<Operation *, 4> paramIrdlOps;
559 
560  for (Value arg : anyOf.getArgs()) {
561  Operation *argOp = arg.getDefiningOp();
562  SmallPtrSet<TypeID, 4> argParamIds;
563  SmallPtrSet<Operation *, 4> argParamIrdlOps;
564  SmallPtrSet<TypeID, 4> argIsIds;
565 
566  // Get the bases of this argument. If it can match any type or attribute,
567  // then our `any_of` should not be allowed.
568  if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
569  return failure();
570 
571  // We check that the base parameters are all disjoints between `parametric`
572  // operations, and that they are disjoint between `parametric` and `is`
573  // operations.
574  for (TypeID id : argParamIds) {
575  if (isIds.count(id))
576  return failure();
577  bool inserted = paramIds.insert(id).second;
578  if (!inserted)
579  return failure();
580  }
581 
582  // We check that the base parameters are all disjoints with `irdl.is`
583  // operations.
584  for (TypeID id : isIds) {
585  if (paramIds.count(id))
586  return failure();
587  isIds.insert(id);
588  }
589 
590  // We check that all `parametric` operations are disjoint. We do not
591  // need to check that they are disjoint with `is` operations, since
592  // `is` operations cannot refer to attributes defined with `irdl.parametric`
593  // operations.
594  for (Operation *op : argParamIrdlOps) {
595  bool inserted = paramIrdlOps.insert(op).second;
596  if (!inserted)
597  return failure();
598  }
599  }
600 
601  return success();
602 }
603 
604 /// Load all dialects in the given module, without loading any operation, type
605 /// or attribute definitions.
608  op.walk([&](DialectOp dialectOp) {
609  MLIRContext *ctx = dialectOp.getContext();
610  StringRef dialectName = dialectOp.getName();
611 
612  DynamicDialect *dialect = ctx->getOrLoadDynamicDialect(
613  dialectName, [](DynamicDialect *dialect) {});
614 
615  dialects.insert({dialectOp, dialect});
616  });
617  return dialects;
618 }
619 
620 /// Preallocate type definitions objects with empty verifiers.
621 /// This in particular allocates a TypeID for each type definition.
626  op.walk([&](TypeOp typeOp) {
627  ExtensibleDialect *dialect = dialects[typeOp.getParentOp()];
628  auto typeDef = DynamicTypeDefinition::get(
629  typeOp.getName(), dialect,
631  return success();
632  });
633  typeDefs.try_emplace(typeOp, std::move(typeDef));
634  });
635  return typeDefs;
636 }
637 
638 /// Preallocate attribute definitions objects with empty verifiers.
639 /// This in particular allocates a TypeID for each attribute definition.
644  op.walk([&](AttributeOp attrOp) {
645  ExtensibleDialect *dialect = dialects[attrOp.getParentOp()];
646  auto attrDef = DynamicAttrDefinition::get(
647  attrOp.getName(), dialect,
649  return success();
650  });
651  attrDefs.try_emplace(attrOp, std::move(attrDef));
652  });
653  return attrDefs;
654 }
655 
656 LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
657  // First, check that all any_of constraints are in a correct form.
658  // This is to ensure we can do the verification correctly.
659  WalkResult anyOfCorrects = op.walk(
660  [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); });
661  if (anyOfCorrects.wasInterrupted())
662  return op.emitError("any_of constraints are not in the correct form");
663 
664  // Preallocate all dialects, and type and attribute definitions.
665  // In particular, this allocates TypeIDs so type and attributes can have
666  // verifiers that refer to each other.
669  preallocateTypeDefs(op, dialects);
671  preallocateAttrDefs(op, dialects);
672 
673  // Set the verifier for types.
674  WalkResult res = op.walk([&](TypeOp typeOp) {
676  typeOp, dialects[typeOp.getParentOp()], types, attrs);
677  if (!verifier)
678  return WalkResult::interrupt();
679  types[typeOp]->setVerifyFn(std::move(verifier));
680  return WalkResult::advance();
681  });
682  if (res.wasInterrupted())
683  return failure();
684 
685  // Set the verifier for attributes.
686  res = op.walk([&](AttributeOp attrOp) {
688  attrOp, dialects[attrOp.getParentOp()], types, attrs);
689  if (!verifier)
690  return WalkResult::interrupt();
691  attrs[attrOp]->setVerifyFn(std::move(verifier));
692  return WalkResult::advance();
693  });
694  if (res.wasInterrupted())
695  return failure();
696 
697  // Define and load all operations.
698  res = op.walk([&](OperationOp opOp) {
699  return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
700  });
701  if (res.wasInterrupted())
702  return failure();
703 
704  // Load all types in their dialects.
705  for (auto &pair : types) {
706  ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
707  dialect->registerDynamicType(std::move(pair.second));
708  }
709 
710  // Load all attributes in their dialects.
711  for (auto &pair : attrs) {
712  ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
713  dialect->registerDynamicAttr(std::move(pair.second));
714  }
715 
716  return success();
717 }
static bool getBases(Operation *op, SmallPtrSet< TypeID, 4 > &paramIds, SmallPtrSet< Operation *, 4 > &paramIrdlOps, SmallPtrSet< TypeID, 4 > &isIds)
Get the possible bases of a constraint.
static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf)
Check that an any_of is in the subset IRDL can handle.
static DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > preallocateAttrDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate attribute definitions objects with empty verifiers.
LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Get the operand segment sizes from the attribute dictionary.
Definition: IRDLLoading.cpp:52
static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(Operation *attrOrTypeDef, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Get the verifier of a type or attribute definition.
static LogicalResult irdlOpVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< size_t > operandConstrs, ArrayRef< Variadicity > operandVariadicity, ArrayRef< size_t > resultConstrs, ArrayRef< Variadicity > resultVariadicity, const DenseMap< StringAttr, size_t > &attributeConstrs)
Verify that the given operation satisfies the given constraints.
LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given operands.
static DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > preallocateTypeDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate type definitions objects with empty verifiers.
LogicalResult getSegmentSizes(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given element (operands, results).
static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect, const DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, const DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Define and load an operation represented by a irdl.operation operation.
static LogicalResult irdlAttrOrTypeVerifier(function_ref< InFlightDiagnostic()> emitError, ArrayRef< Attribute > params, ArrayRef< std::unique_ptr< Constraint >> constraints, ArrayRef< size_t > paramConstraints)
Verify that the given list of parameters satisfy the given constraints.
Definition: IRDLLoading.cpp:31
static LogicalResult irdlRegionVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< std::unique_ptr< RegionConstraint >> regionsConstraints)
static DenseMap< DialectOp, ExtensibleDialect * > loadEmptyDialects(ModuleOp op)
Load all dialects in the given module, without loading any operation, type or attribute definitions.
LogicalResult getResultSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given results.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
Definition: Attributes.h:52
static std::unique_ptr< DynamicAttrDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new attribute definition at runtime.
llvm::unique_function< LogicalResult(function_ref< InFlightDiagnostic()>, ArrayRef< Attribute >) const > VerifierFn
A dialect that can be defined at runtime.
static std::unique_ptr< DynamicOpDefinition > get(StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn)
Create a new op at runtime.
static std::unique_ptr< DynamicTypeDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new dynamic type definition.
A dialect that can be extended with new operations/types/attributes at runtime.
void registerDynamicOp(std::unique_ptr< DynamicOpDefinition > &&type)
Add a new operation defined at runtime to the dialect.
void registerDynamicType(std::unique_ptr< DynamicTypeDefinition > &&type)
Add a new type defined at runtime to the dialect.
void registerDynamicAttr(std::unique_ptr< DynamicAttrDefinition > &&attr)
Add a new attribute defined at runtime to the dialect.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)
Get (or create) a dynamic dialect for the given name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:295
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
iterator_range< OpIterator > getOps()
Definition: Region.h:172
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
Provides context to the verification of constraints.
Definition: IRDLVerifiers.h:40
LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Attribute attr, unsigned variable)
Check that a constraint is satisfied by an attribute.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::LogicalResult loadDialects(ModuleOp op)
Load all the dialects defined in the module.
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
Definition: IRDLSymbols.cpp:28
llvm::unique_function< LogicalResult(Operation *) const > createVerifier(OperationOp operation, const DenseMap< irdl::TypeOp, std::unique_ptr< DynamicTypeDefinition >> &typeDefs, const DenseMap< irdl::AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrDefs)
Generate an op verifier function from the given IRDL operation definition.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.