MLIR  20.0.0git
IRDLLoading.cpp
Go to the documentation of this file.
1 //===- IRDLLoading.cpp - IRDL dialect loading --------------------- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Manages the loading of MLIR objects from IRDL operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinOps.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/Support/SMLoc.h"
25 #include <numeric>
26 
27 using namespace mlir;
28 using namespace mlir::irdl;
29 
30 /// Verify that the given list of parameters satisfy the given constraints.
31 /// This encodes the logic of the verification method for attributes and types
32 /// defined with IRDL.
33 static LogicalResult
35  ArrayRef<Attribute> params,
36  ArrayRef<std::unique_ptr<Constraint>> constraints,
37  ArrayRef<size_t> paramConstraints) {
38  if (params.size() != paramConstraints.size()) {
39  emitError() << "expected " << paramConstraints.size()
40  << " type arguments, but had " << params.size();
41  return failure();
42  }
43 
44  ConstraintVerifier verifier(constraints);
45 
46  // Check that each parameter satisfies its constraint.
47  for (auto [i, param] : enumerate(params))
48  if (failed(verifier.verify(emitError, param, paramConstraints[i])))
49  return failure();
50 
51  return success();
52 }
53 
54 /// Get the operand segment sizes from the attribute dictionary.
55 LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName,
56  StringRef attrName, unsigned numElements,
57  ArrayRef<Variadicity> variadicities,
58  SmallVectorImpl<int> &segmentSizes) {
59  // Get the segment sizes attribute, and check that it is of the right type.
60  Attribute segmentSizesAttr = op->getAttr(attrName);
61  if (!segmentSizesAttr) {
62  return op->emitError() << "'" << attrName
63  << "' attribute is expected but not provided";
64  }
65 
66  auto denseSegmentSizes = dyn_cast<DenseI32ArrayAttr>(segmentSizesAttr);
67  if (!denseSegmentSizes) {
68  return op->emitError() << "'" << attrName
69  << "' attribute is expected to be a dense i32 array";
70  }
71 
72  if (denseSegmentSizes.size() != (int64_t)variadicities.size()) {
73  return op->emitError() << "'" << attrName << "' attribute for specifying "
74  << elemName << " segments must have "
75  << variadicities.size() << " elements, but got "
76  << denseSegmentSizes.size();
77  }
78 
79  // Check that the segment sizes are corresponding to the given variadicities,
80  for (auto [i, segmentSize, variadicity] :
81  enumerate(denseSegmentSizes.asArrayRef(), variadicities)) {
82  if (segmentSize < 0)
83  return op->emitError()
84  << "'" << attrName << "' attribute for specifying " << elemName
85  << " segments must have non-negative values";
86  if (variadicity == Variadicity::single && segmentSize != 1)
87  return op->emitError() << "element " << i << " in '" << attrName
88  << "' attribute must be equal to 1";
89 
90  if (variadicity == Variadicity::optional && segmentSize > 1)
91  return op->emitError() << "element " << i << " in '" << attrName
92  << "' attribute must be equal to 0 or 1";
93 
94  segmentSizes.push_back(segmentSize);
95  }
96 
97  // Check that the sum of the segment sizes is equal to the number of elements.
98  int32_t sum = 0;
99  for (int32_t segmentSize : denseSegmentSizes.asArrayRef())
100  sum += segmentSize;
101  if (sum != static_cast<int32_t>(numElements))
102  return op->emitError() << "sum of elements in '" << attrName
103  << "' attribute must be equal to the number of "
104  << elemName << "s";
105 
106  return success();
107 }
108 
109 /// Compute the segment sizes of the given element (operands, results).
110 /// If the operation has more than two non-single elements (optional or
111 /// variadic), then get the segment sizes from the attribute dictionary.
112 /// Otherwise, compute the segment sizes from the number of elements.
113 /// `elemName` should be either `"operand"` or `"result"`.
114 LogicalResult getSegmentSizes(Operation *op, StringRef elemName,
115  StringRef attrName, unsigned numElements,
116  ArrayRef<Variadicity> variadicities,
117  SmallVectorImpl<int> &segmentSizes) {
118  // If we have more than one non-single variadicity, we need to get the
119  // segment sizes from the attribute dictionary.
120  int numberNonSingle = count_if(
121  variadicities, [](Variadicity v) { return v != Variadicity::single; });
122  if (numberNonSingle > 1)
123  return getSegmentSizesFromAttr(op, elemName, attrName, numElements,
124  variadicities, segmentSizes);
125 
126  // If we only have single variadicities, the segments sizes are all 1.
127  if (numberNonSingle == 0) {
128  if (numElements != variadicities.size()) {
129  return op->emitError() << "op expects exactly " << variadicities.size()
130  << " " << elemName << "s, but got " << numElements;
131  }
132  for (size_t i = 0, e = variadicities.size(); i < e; ++i)
133  segmentSizes.push_back(1);
134  return success();
135  }
136 
137  assert(numberNonSingle == 1);
138 
139  // There is exactly one non-single element, so we can
140  // compute its size and check that it is valid.
141  int nonSingleSegmentSize = static_cast<int>(numElements) -
142  static_cast<int>(variadicities.size()) + 1;
143 
144  if (nonSingleSegmentSize < 0) {
145  return op->emitError() << "op expects at least " << variadicities.size() - 1
146  << " " << elemName << "s, but got " << numElements;
147  }
148 
149  // Add the segment sizes.
150  for (Variadicity variadicity : variadicities) {
151  if (variadicity == Variadicity::single) {
152  segmentSizes.push_back(1);
153  continue;
154  }
155 
156  // If we have an optional element, we should check that it represents
157  // zero or one elements.
158  if (nonSingleSegmentSize > 1 && variadicity == Variadicity::optional)
159  return op->emitError() << "op expects at most " << variadicities.size()
160  << " " << elemName << "s, but got " << numElements;
161 
162  segmentSizes.push_back(nonSingleSegmentSize);
163  }
164 
165  return success();
166 }
167 
168 /// Compute the segment sizes of the given operands.
169 /// If the operation has more than two non-single operands (optional or
170 /// variadic), then get the segment sizes from the attribute dictionary.
171 /// Otherwise, compute the segment sizes from the number of operands.
173  ArrayRef<Variadicity> variadicities,
174  SmallVectorImpl<int> &segmentSizes) {
175  return getSegmentSizes(op, "operand", "operand_segment_sizes",
176  op->getNumOperands(), variadicities, segmentSizes);
177 }
178 
179 /// Compute the segment sizes of the given results.
180 /// If the operation has more than two non-single results (optional or
181 /// variadic), then get the segment sizes from the attribute dictionary.
182 /// Otherwise, compute the segment sizes from the number of results.
184  ArrayRef<Variadicity> variadicities,
185  SmallVectorImpl<int> &segmentSizes) {
186  return getSegmentSizes(op, "result", "result_segment_sizes",
187  op->getNumResults(), variadicities, segmentSizes);
188 }
189 
190 /// Verify that the given operation satisfies the given constraints.
191 /// This encodes the logic of the verification method for operations defined
192 /// with IRDL.
193 static LogicalResult irdlOpVerifier(
194  Operation *op, ConstraintVerifier &verifier,
195  ArrayRef<size_t> operandConstrs, ArrayRef<Variadicity> operandVariadicity,
196  ArrayRef<size_t> resultConstrs, ArrayRef<Variadicity> resultVariadicity,
197  const DenseMap<StringAttr, size_t> &attributeConstrs) {
198  // Get the segment sizes for the operands.
199  // This will check that the number of operands is correct.
200  SmallVector<int> operandSegmentSizes;
201  if (failed(
202  getOperandSegmentSizes(op, operandVariadicity, operandSegmentSizes)))
203  return failure();
204 
205  // Get the segment sizes for the results.
206  // This will check that the number of results is correct.
207  SmallVector<int> resultSegmentSizes;
208  if (failed(getResultSegmentSizes(op, resultVariadicity, resultSegmentSizes)))
209  return failure();
210 
211  auto emitError = [op] { return op->emitError(); };
212 
213  /// Сheck that we have all needed attributes passed
214  /// and they satisfy the constraints.
215  DictionaryAttr actualAttrs = op->getAttrDictionary();
216 
217  for (auto [name, constraint] : attributeConstrs) {
218  /// First, check if the attribute actually passed.
219  std::optional<NamedAttribute> actual = actualAttrs.getNamed(name);
220  if (!actual.has_value())
221  return op->emitOpError()
222  << "attribute " << name << " is expected but not provided";
223 
224  /// Then, check if the attribute value satisfies the constraint.
225  if (failed(verifier.verify({emitError}, actual->getValue(), constraint)))
226  return failure();
227  }
228 
229  // Check that all operands satisfy the constraints
230  int operandIdx = 0;
231  for (auto [defIndex, segmentSize] : enumerate(operandSegmentSizes)) {
232  for (int i = 0; i < segmentSize; i++) {
233  if (failed(verifier.verify(
234  {emitError}, TypeAttr::get(op->getOperandTypes()[operandIdx]),
235  operandConstrs[defIndex])))
236  return failure();
237  ++operandIdx;
238  }
239  }
240 
241  // Check that all results satisfy the constraints
242  int resultIdx = 0;
243  for (auto [defIndex, segmentSize] : enumerate(resultSegmentSizes)) {
244  for (int i = 0; i < segmentSize; i++) {
245  if (failed(verifier.verify({emitError},
246  TypeAttr::get(op->getResultTypes()[resultIdx]),
247  resultConstrs[defIndex])))
248  return failure();
249  ++resultIdx;
250  }
251  }
252 
253  return success();
254 }
255 
256 static LogicalResult irdlRegionVerifier(
257  Operation *op, ConstraintVerifier &verifier,
258  ArrayRef<std::unique_ptr<RegionConstraint>> regionsConstraints) {
259  if (op->getNumRegions() != regionsConstraints.size()) {
260  return op->emitOpError()
261  << "unexpected number of regions: expected "
262  << regionsConstraints.size() << " but got " << op->getNumRegions();
263  }
264 
265  for (auto [constraint, region] :
266  llvm::zip(regionsConstraints, op->getRegions()))
267  if (failed(constraint->verify(region, verifier)))
268  return failure();
269 
270  return success();
271 }
272 
273 llvm::unique_function<LogicalResult(Operation *) const>
275  OperationOp op,
276  const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
277  const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
278  &attrs) {
279  // Resolve SSA values to verifier constraint slots
280  SmallVector<Value> constrToValue;
281  SmallVector<Value> regionToValue;
282  for (Operation &op : op->getRegion(0).getOps()) {
283  if (isa<VerifyConstraintInterface>(op)) {
284  if (op.getNumResults() != 1) {
285  op.emitError()
286  << "IRDL constraint operations must have exactly one result";
287  return nullptr;
288  }
289  constrToValue.push_back(op.getResult(0));
290  }
291  if (isa<VerifyRegionInterface>(op)) {
292  if (op.getNumResults() != 1) {
293  op.emitError()
294  << "IRDL constraint operations must have exactly one result";
295  return nullptr;
296  }
297  regionToValue.push_back(op.getResult(0));
298  }
299  }
300 
301  // Build the verifiers for each constraint slot
303  for (Value v : constrToValue) {
304  VerifyConstraintInterface op =
305  cast<VerifyConstraintInterface>(v.getDefiningOp());
306  std::unique_ptr<Constraint> verifier =
307  op.getVerifier(constrToValue, types, attrs);
308  if (!verifier)
309  return nullptr;
310  constraints.push_back(std::move(verifier));
311  }
312 
313  // Build region constraints
315  for (Value v : regionToValue) {
316  VerifyRegionInterface op = cast<VerifyRegionInterface>(v.getDefiningOp());
317  std::unique_ptr<RegionConstraint> verifier =
318  op.getVerifier(constrToValue, types, attrs);
319  regionConstraints.push_back(std::move(verifier));
320  }
321 
322  SmallVector<size_t> operandConstraints;
323  SmallVector<Variadicity> operandVariadicity;
324 
325  // Gather which constraint slots correspond to operand constraints
326  auto operandsOp = op.getOp<OperandsOp>();
327  if (operandsOp.has_value()) {
328  operandConstraints.reserve(operandsOp->getArgs().size());
329  for (Value operand : operandsOp->getArgs()) {
330  for (auto [i, constr] : enumerate(constrToValue)) {
331  if (constr == operand) {
332  operandConstraints.push_back(i);
333  break;
334  }
335  }
336  }
337 
338  // Gather the variadicities of each operand
339  for (VariadicityAttr attr : operandsOp->getVariadicity())
340  operandVariadicity.push_back(attr.getValue());
341  }
342 
343  SmallVector<size_t> resultConstraints;
344  SmallVector<Variadicity> resultVariadicity;
345 
346  // Gather which constraint slots correspond to result constraints
347  auto resultsOp = op.getOp<ResultsOp>();
348  if (resultsOp.has_value()) {
349  resultConstraints.reserve(resultsOp->getArgs().size());
350  for (Value result : resultsOp->getArgs()) {
351  for (auto [i, constr] : enumerate(constrToValue)) {
352  if (constr == result) {
353  resultConstraints.push_back(i);
354  break;
355  }
356  }
357  }
358 
359  // Gather the variadicities of each result
360  for (Attribute attr : resultsOp->getVariadicity())
361  resultVariadicity.push_back(cast<VariadicityAttr>(attr).getValue());
362  }
363 
364  // Gather which constraint slots correspond to attributes constraints
365  DenseMap<StringAttr, size_t> attributeConstraints;
366  auto attributesOp = op.getOp<AttributesOp>();
367  if (attributesOp.has_value()) {
368  const Operation::operand_range values = attributesOp->getAttributeValues();
369  const ArrayAttr names = attributesOp->getAttributeValueNames();
370 
371  for (const auto &[name, value] : llvm::zip(names, values)) {
372  for (auto [i, constr] : enumerate(constrToValue)) {
373  if (constr == value) {
374  attributeConstraints[cast<StringAttr>(name)] = i;
375  break;
376  }
377  }
378  }
379  }
380 
381  return
382  [constraints{std::move(constraints)},
383  regionConstraints{std::move(regionConstraints)},
384  operandConstraints{std::move(operandConstraints)},
385  operandVariadicity{std::move(operandVariadicity)},
386  resultConstraints{std::move(resultConstraints)},
387  resultVariadicity{std::move(resultVariadicity)},
388  attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
389  ConstraintVerifier verifier(constraints);
390  const LogicalResult opVerifierResult = irdlOpVerifier(
391  op, verifier, operandConstraints, operandVariadicity,
392  resultConstraints, resultVariadicity, attributeConstraints);
393  const LogicalResult opRegionVerifierResult =
394  irdlRegionVerifier(op, verifier, regionConstraints);
395  return LogicalResult::success(opVerifierResult.succeeded() &&
396  opRegionVerifierResult.succeeded());
397  };
398 }
399 
400 /// Define and load an operation represented by a `irdl.operation`
401 /// operation.
403  OperationOp op, ExtensibleDialect *dialect,
404  const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
405  const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
406  &attrs) {
407 
408  // IRDL does not support defining custom parsers or printers.
409  auto parser = [](OpAsmParser &parser, OperationState &result) {
410  return failure();
411  };
412  auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
413  printer.printGenericOp(op);
414  };
415 
416  auto verifier = createVerifier(op, types, attrs);
417  if (!verifier)
418  return WalkResult::interrupt();
419 
420  // IRDL supports only checking number of blocks and argument constraints
421  // It is done in the main verifier to reuse `ConstraintVerifier` context
422  auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
423 
424  auto opDef = DynamicOpDefinition::get(
425  op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
426  std::move(parser), std::move(printer));
427  dialect->registerDynamicOp(std::move(opDef));
428 
429  return WalkResult::advance();
430 }
431 
432 /// Get the verifier of a type or attribute definition.
433 /// Return nullptr if the definition is invalid.
435  Operation *attrOrTypeDef, ExtensibleDialect *dialect,
436  DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
437  DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
438  assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
439  "Expected an attribute or type definition");
440 
441  // Resolve SSA values to verifier constraint slots
442  SmallVector<Value> constrToValue;
443  for (Operation &op : attrOrTypeDef->getRegion(0).getOps()) {
444  if (isa<VerifyConstraintInterface>(op)) {
445  assert(op.getNumResults() == 1 &&
446  "IRDL constraint operations must have exactly one result");
447  constrToValue.push_back(op.getResult(0));
448  }
449  }
450 
451  // Build the verifiers for each constraint slot
453  for (Value v : constrToValue) {
454  VerifyConstraintInterface op =
455  cast<VerifyConstraintInterface>(v.getDefiningOp());
456  std::unique_ptr<Constraint> verifier =
457  op.getVerifier(constrToValue, types, attrs);
458  if (!verifier)
459  return {};
460  constraints.push_back(std::move(verifier));
461  }
462 
463  // Get the parameter definitions.
464  std::optional<ParametersOp> params;
465  if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
466  params = attr.getOp<ParametersOp>();
467  else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef))
468  params = type.getOp<ParametersOp>();
469 
470  // Gather which constraint slots correspond to parameter constraints
471  SmallVector<size_t> paramConstraints;
472  if (params.has_value()) {
473  paramConstraints.reserve(params->getArgs().size());
474  for (Value param : params->getArgs()) {
475  for (auto [i, constr] : enumerate(constrToValue)) {
476  if (constr == param) {
477  paramConstraints.push_back(i);
478  break;
479  }
480  }
481  }
482  }
483 
484  auto verifier = [paramConstraints{std::move(paramConstraints)},
485  constraints{std::move(constraints)}](
487  ArrayRef<Attribute> params) {
488  return irdlAttrOrTypeVerifier(emitError, params, constraints,
489  paramConstraints);
490  };
491 
492  // While the `std::move` is not required, not adding it triggers a bug in
493  // clang-10.
494  return std::move(verifier);
495 }
496 
497 /// Get the possible bases of a constraint. Return `true` if all bases can
498 /// potentially be matched.
499 /// A base is a type or an attribute definition. For instance, the base of
500 /// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`.
501 /// This function returns the following information through arguments:
502 /// - `paramIds`: the set of type or attribute IDs that are used as bases.
503 /// - `paramIrdlOps`: the set of IRDL operations that are used as bases.
504 /// - `isIds`: the set of type or attribute IDs that are used in `irdl.is`
505 /// constraints.
506 static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
507  SmallPtrSet<Operation *, 4> &paramIrdlOps,
508  SmallPtrSet<TypeID, 4> &isIds) {
509  // For `irdl.any_of`, we get the bases from all its arguments.
510  if (auto anyOf = dyn_cast<AnyOfOp>(op)) {
511  bool hasAny = false;
512  for (Value arg : anyOf.getArgs())
513  hasAny &= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
514  return hasAny;
515  }
516 
517  // For `irdl.all_of`, we get the bases from the first argument.
518  // This is restrictive, but we can relax it later if needed.
519  if (auto allOf = dyn_cast<AllOfOp>(op))
520  return getBases(allOf.getArgs()[0].getDefiningOp(), paramIds, paramIrdlOps,
521  isIds);
522 
523  // For `irdl.parametric`, we get directly the base from the operation.
524  if (auto params = dyn_cast<ParametricOp>(op)) {
525  SymbolRefAttr symRef = params.getBaseType();
526  Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
527  assert(defOp && "symbol reference should refer to an existing operation");
528  paramIrdlOps.insert(defOp);
529  return false;
530  }
531 
532  // For `irdl.is`, we get the base TypeID directly.
533  if (auto is = dyn_cast<IsOp>(op)) {
534  Attribute expected = is.getExpected();
535  isIds.insert(expected.getTypeID());
536  return false;
537  }
538 
539  // For `irdl.any`, we return `false` since we can match any type or attribute
540  // base.
541  if (auto isA = dyn_cast<AnyOp>(op))
542  return true;
543 
544  llvm_unreachable("unknown IRDL constraint");
545 }
546 
547 /// Check that an any_of is in the subset IRDL can handle.
548 /// IRDL uses a greedy algorithm to match constraints. This means that if we
549 /// encounter an `any_of` with multiple constraints, we will match the first
550 /// constraint that is satisfied. Thus, the order of constraints matter in
551 /// `any_of` with our current algorithm.
552 /// In order to make the order of constraints irrelevant, we require that
553 /// all `any_of` constraint parameters are disjoint. For this, we check that
554 /// the base parameters are all disjoints between `parametric` operations, and
555 /// that they are disjoint between `parametric` and `is` operations.
556 /// This restriction will be relaxed in the future, when we will change our
557 /// algorithm to be non-greedy.
558 static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf) {
559  SmallPtrSet<TypeID, 4> paramIds;
560  SmallPtrSet<Operation *, 4> paramIrdlOps;
562 
563  for (Value arg : anyOf.getArgs()) {
564  Operation *argOp = arg.getDefiningOp();
565  SmallPtrSet<TypeID, 4> argParamIds;
566  SmallPtrSet<Operation *, 4> argParamIrdlOps;
567  SmallPtrSet<TypeID, 4> argIsIds;
568 
569  // Get the bases of this argument. If it can match any type or attribute,
570  // then our `any_of` should not be allowed.
571  if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
572  return failure();
573 
574  // We check that the base parameters are all disjoints between `parametric`
575  // operations, and that they are disjoint between `parametric` and `is`
576  // operations.
577  for (TypeID id : argParamIds) {
578  if (isIds.count(id))
579  return failure();
580  bool inserted = paramIds.insert(id).second;
581  if (!inserted)
582  return failure();
583  }
584 
585  // We check that the base parameters are all disjoints with `irdl.is`
586  // operations.
587  for (TypeID id : isIds) {
588  if (paramIds.count(id))
589  return failure();
590  isIds.insert(id);
591  }
592 
593  // We check that all `parametric` operations are disjoint. We do not
594  // need to check that they are disjoint with `is` operations, since
595  // `is` operations cannot refer to attributes defined with `irdl.parametric`
596  // operations.
597  for (Operation *op : argParamIrdlOps) {
598  bool inserted = paramIrdlOps.insert(op).second;
599  if (!inserted)
600  return failure();
601  }
602  }
603 
604  return success();
605 }
606 
607 /// Load all dialects in the given module, without loading any operation, type
608 /// or attribute definitions.
611  op.walk([&](DialectOp dialectOp) {
612  MLIRContext *ctx = dialectOp.getContext();
613  StringRef dialectName = dialectOp.getName();
614 
615  DynamicDialect *dialect = ctx->getOrLoadDynamicDialect(
616  dialectName, [](DynamicDialect *dialect) {});
617 
618  dialects.insert({dialectOp, dialect});
619  });
620  return dialects;
621 }
622 
623 /// Preallocate type definitions objects with empty verifiers.
624 /// This in particular allocates a TypeID for each type definition.
629  op.walk([&](TypeOp typeOp) {
630  ExtensibleDialect *dialect = dialects[typeOp.getParentOp()];
631  auto typeDef = DynamicTypeDefinition::get(
632  typeOp.getName(), dialect,
634  return success();
635  });
636  typeDefs.try_emplace(typeOp, std::move(typeDef));
637  });
638  return typeDefs;
639 }
640 
641 /// Preallocate attribute definitions objects with empty verifiers.
642 /// This in particular allocates a TypeID for each attribute definition.
647  op.walk([&](AttributeOp attrOp) {
648  ExtensibleDialect *dialect = dialects[attrOp.getParentOp()];
649  auto attrDef = DynamicAttrDefinition::get(
650  attrOp.getName(), dialect,
652  return success();
653  });
654  attrDefs.try_emplace(attrOp, std::move(attrDef));
655  });
656  return attrDefs;
657 }
658 
659 LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
660  // First, check that all any_of constraints are in a correct form.
661  // This is to ensure we can do the verification correctly.
662  WalkResult anyOfCorrects = op.walk(
663  [](AnyOfOp anyOf) { return (WalkResult)checkCorrectAnyOf(anyOf); });
664  if (anyOfCorrects.wasInterrupted())
665  return op.emitError("any_of constraints are not in the correct form");
666 
667  // Preallocate all dialects, and type and attribute definitions.
668  // In particular, this allocates TypeIDs so type and attributes can have
669  // verifiers that refer to each other.
672  preallocateTypeDefs(op, dialects);
674  preallocateAttrDefs(op, dialects);
675 
676  // Set the verifier for types.
677  WalkResult res = op.walk([&](TypeOp typeOp) {
679  typeOp, dialects[typeOp.getParentOp()], types, attrs);
680  if (!verifier)
681  return WalkResult::interrupt();
682  types[typeOp]->setVerifyFn(std::move(verifier));
683  return WalkResult::advance();
684  });
685  if (res.wasInterrupted())
686  return failure();
687 
688  // Set the verifier for attributes.
689  res = op.walk([&](AttributeOp attrOp) {
691  attrOp, dialects[attrOp.getParentOp()], types, attrs);
692  if (!verifier)
693  return WalkResult::interrupt();
694  attrs[attrOp]->setVerifyFn(std::move(verifier));
695  return WalkResult::advance();
696  });
697  if (res.wasInterrupted())
698  return failure();
699 
700  // Define and load all operations.
701  res = op.walk([&](OperationOp opOp) {
702  return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
703  });
704  if (res.wasInterrupted())
705  return failure();
706 
707  // Load all types in their dialects.
708  for (auto &pair : types) {
709  ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
710  dialect->registerDynamicType(std::move(pair.second));
711  }
712 
713  // Load all attributes in their dialects.
714  for (auto &pair : attrs) {
715  ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
716  dialect->registerDynamicAttr(std::move(pair.second));
717  }
718 
719  return success();
720 }
static bool getBases(Operation *op, SmallPtrSet< TypeID, 4 > &paramIds, SmallPtrSet< Operation *, 4 > &paramIrdlOps, SmallPtrSet< TypeID, 4 > &isIds)
Get the possible bases of a constraint.
static LogicalResult checkCorrectAnyOf(AnyOfOp anyOf)
Check that an any_of is in the subset IRDL can handle.
static DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition > > preallocateAttrDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate attribute definitions objects with empty verifiers.
LogicalResult getSegmentSizesFromAttr(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Get the operand segment sizes from the attribute dictionary.
Definition: IRDLLoading.cpp:55
static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(Operation *attrOrTypeDef, ExtensibleDialect *dialect, DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Get the verifier of a type or attribute definition.
static LogicalResult irdlOpVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< size_t > operandConstrs, ArrayRef< Variadicity > operandVariadicity, ArrayRef< size_t > resultConstrs, ArrayRef< Variadicity > resultVariadicity, const DenseMap< StringAttr, size_t > &attributeConstrs)
Verify that the given operation satisfies the given constraints.
LogicalResult getOperandSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given operands.
static DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition > > preallocateTypeDefs(ModuleOp op, DenseMap< DialectOp, ExtensibleDialect * > dialects)
Preallocate type definitions objects with empty verifiers.
LogicalResult getSegmentSizes(Operation *op, StringRef elemName, StringRef attrName, unsigned numElements, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given element (operands, results).
static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect, const DenseMap< TypeOp, std::unique_ptr< DynamicTypeDefinition >> &types, const DenseMap< AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrs)
Define and load an operation represented by a irdl.operation operation.
static LogicalResult irdlAttrOrTypeVerifier(function_ref< InFlightDiagnostic()> emitError, ArrayRef< Attribute > params, ArrayRef< std::unique_ptr< Constraint >> constraints, ArrayRef< size_t > paramConstraints)
Verify that the given list of parameters satisfy the given constraints.
Definition: IRDLLoading.cpp:34
static LogicalResult irdlRegionVerifier(Operation *op, ConstraintVerifier &verifier, ArrayRef< std::unique_ptr< RegionConstraint >> regionsConstraints)
static DenseMap< DialectOp, ExtensibleDialect * > loadEmptyDialects(ModuleOp op)
Load all dialects in the given module, without loading any operation, type or attribute definitions.
LogicalResult getResultSegmentSizes(Operation *op, ArrayRef< Variadicity > variadicities, SmallVectorImpl< int > &segmentSizes)
Compute the segment sizes of the given results.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
Definition: Attributes.h:70
static std::unique_ptr< DynamicAttrDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new attribute definition at runtime.
llvm::unique_function< LogicalResult(function_ref< InFlightDiagnostic()>, ArrayRef< Attribute >) const > VerifierFn
A dialect that can be defined at runtime.
static std::unique_ptr< DynamicOpDefinition > get(StringRef name, ExtensibleDialect *dialect, OperationName::VerifyInvariantsFn &&verifyFn, OperationName::VerifyRegionInvariantsFn &&verifyRegionFn)
Create a new op at runtime.
static std::unique_ptr< DynamicTypeDefinition > get(StringRef name, ExtensibleDialect *dialect, VerifierFn &&verifier)
Create a new dynamic type definition.
A dialect that can be extended with new operations/types/attributes at runtime.
void registerDynamicOp(std::unique_ptr< DynamicOpDefinition > &&type)
Add a new operation defined at runtime to the dialect.
void registerDynamicType(std::unique_ptr< DynamicTypeDefinition > &&type)
Add a new type defined at runtime to the dialect.
void registerDynamicAttr(std::unique_ptr< DynamicAttrDefinition > &&attr)
Add a new attribute defined at runtime to the dialect.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)
Get (or create) a dynamic dialect for the given name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
iterator_range< OpIterator > getOps()
Definition: Region.h:172
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Provides context to the verification of constraints.
Definition: IRDLVerifiers.h:40
LogicalResult verify(function_ref< InFlightDiagnostic()> emitError, Attribute attr, unsigned variable)
Check that a constraint is satisfied by an attribute.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::LogicalResult loadDialects(ModuleOp op)
Load all the dialects defined in the module.
Operation * lookupSymbolNearDialect(SymbolTableCollection &symbolTable, Operation *source, SymbolRefAttr symbol)
Looks up a symbol from the symbol table containing the source operation's dialect definition operatio...
Definition: IRDLSymbols.cpp:28
llvm::unique_function< LogicalResult(Operation *) const > createVerifier(OperationOp operation, const DenseMap< irdl::TypeOp, std::unique_ptr< DynamicTypeDefinition >> &typeDefs, const DenseMap< irdl::AttributeOp, std::unique_ptr< DynamicAttrDefinition >> &attrDefs)
Generate an op verifier function from the given IRDL operation definition.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.