MLIR  16.0.0git
TransformInterfaces.h
Go to the documentation of this file.
1 //===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
11 
12 #include "mlir/IR/OpDefinition.h"
13 
15 #include "llvm/ADT/ScopeExit.h"
16 
17 namespace mlir {
18 
19 /// The result of a transform IR operation application. This can have one of the
20 /// three states:
21 /// - success;
22 /// - silenceable (recoverable) failure with yet-unreported diagnostic;
23 /// - definite failure.
24 /// Silenceable failure is intended to communicate information about
25 /// transformations that did not apply but in a way that supports recovery,
26 /// for example, they did not modify the payload IR or modified it in some
27 /// predictable way. They are associated with a Diagnostic that provides more
28 /// details on the failure. Silenceable failure can be discarded, turning the
29 /// result into success, or "reported", emitting the diagnostic and turning the
30 /// result into definite failure.
31 /// Transform IR operations containing other operations are allowed to do either
32 /// with the results of the nested transformations, but must propagate definite
33 /// failures as their diagnostics have been already reported to the user.
34 class [[nodiscard]] DiagnosedSilenceableFailure {
35 public:
36  explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {}
39  operator=(const DiagnosedSilenceableFailure &) = delete;
42  operator=(DiagnosedSilenceableFailure &&) = default;
43 
44  /// Constructs a DiagnosedSilenceableFailure in the success state.
47  }
48 
49  /// Constructs a DiagnosedSilenceableFailure in the failure state. Typically,
50  /// a diagnostic has been emitted before this.
53  }
54 
55  /// Constructs a DiagnosedSilenceableFailure in the silenceable failure state,
56  /// ready to emit the given diagnostic. This is considered a failure
57  /// regardless of the diagnostic severity.
59  return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag));
60  }
64  std::forward<SmallVector<Diagnostic>>(diag));
65  }
66 
67  /// Converts all kinds of failure into a LogicalResult failure, emitting the
68  /// diagnostic if necessary. Must not be called more than once.
70 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
71  assert(!reported && "attempting to report a diagnostic more than once");
72  reported = true;
73 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
74  if (!diagnostics.empty()) {
75  for (auto &&diagnostic : diagnostics) {
76  diagnostic.getLocation().getContext()->getDiagEngine().emit(
77  std::move(diagnostic));
78  }
79  diagnostics.clear();
80  result = ::mlir::failure();
81  }
82  return result;
83  }
84 
85  /// Returns `true` if this is a success.
86  bool succeeded() const {
87  return ::mlir::succeeded(result) && diagnostics.empty();
88  }
89 
90  /// Returns `true` if this is a definite failure.
91  bool isDefiniteFailure() const {
92  return ::mlir::failed(result) && diagnostics.empty();
93  }
94 
95  /// Returns `true` if this is a silenceable failure.
96  bool isSilenceableFailure() const { return !diagnostics.empty(); }
97 
98  /// Returns the diagnostic message without emitting it. Expects this object
99  /// to be a silenceable failure.
100  std::string getMessage() const {
101  std::string res;
102  for (auto &diagnostic : diagnostics) {
103  res.append(diagnostic.str());
104  res.append("\n");
105  }
106  return res;
107  }
108 
109  /// Returns a string representation of the failure mode (for error reporting).
110  std::string getStatusString() const {
111  if (succeeded())
112  return "success";
113  if (isSilenceableFailure())
114  return "silenceable failure";
115  return "definite failure";
116  }
117 
118  /// Converts silenceable failure into LogicalResult success without reporting
119  /// the diagnostic, preserves the other states.
121  if (!diagnostics.empty()) {
122  diagnostics.clear();
123  result = ::mlir::success();
124  }
125  return result;
126  }
127 
128  /// Take the diagnostic and silence.
130  assert(!diagnostics.empty() && "expected a diagnostic to be present");
131  auto guard = llvm::make_scope_exit([&]() { diagnostics.clear(); });
132  return std::move(diagnostics);
133  }
134 
/// Streams the given values into the last diagnostic.
136  /// Expects this object to be a silenceable failure.
137  template <typename T>
139  assert(isSilenceableFailure() &&
140  "can only append output in silenceable failure state");
141  diagnostics.back() << std::forward<T>(value);
142  return *this;
143  }
144  template <typename T>
146  return std::move(this->operator<<(std::forward<T>(value)));
147  }
148 
149  /// Attaches a note to the last diagnostic.
150  /// Expects this object to be a silenceable failure.
152  assert(isSilenceableFailure() &&
153  "can only attach notes to silenceable failures");
154  return diagnostics.back().attachNote(loc);
155  }
156 
157 private:
158  explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic)
159  : result(failure()) {
160  diagnostics.emplace_back(std::move(diagnostic));
161  }
163  : diagnostics(std::move(diagnostics)), result(failure()) {}
164 
165  /// The diagnostics associated with this object. If non-empty, the object is
166  /// considered to be in the silenceable failure state regardless of the
167  /// `result` field.
168  SmallVector<Diagnostic, 1> diagnostics;
169 
170  /// The "definite" logical state, either success or failure.
171  /// Ignored if the diagnostics message is present.
172  LogicalResult result;
173 
174 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
175  /// Whether the associated diagnostics have been reported.
176  /// Diagnostics reporting consumes the diagnostics, so we need a mechanism to
177  /// differentiate reported diagnostics from a state where it was never
178  /// created.
179  bool reported = false;
180 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
181 };
182 
183 namespace transform {
184 
185 class TransformOpInterface;
186 
187 /// Options controlling the application of transform operations by the
188 /// TransformState.
190 public:
192 
193  /// Requests computationally expensive checks of the transform and payload IR
194  /// well-formedness to be performed before each transformation. In particular,
195  /// these ensure that the handles still point to valid operations when used.
197  expensiveChecksEnabled = enable;
198  return *this;
199  }
200 
201  /// Returns true if the expensive checks are requested.
202  bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
203 
204 private:
205  bool expensiveChecksEnabled = true;
206 };
207 
208 /// The state maintained across applications of various ops implementing the
209 /// TransformOpInterface. The operations implementing this interface and the
210 /// surrounding structure are referred to as transform IR. The operations to
211 /// which transformations apply are referred to as payload IR. The state thus
212 /// contains the mapping between values defined in the transform IR ops and
213 /// payload IR ops. It assumes that each value in the transform IR can be used
214 /// at most once (since transformations are likely to change the payload IR ops
215 /// the value corresponds to). Checks that transform IR values correspond to
216 /// disjoint sets of payload IR ops throughout the transformation.
217 ///
218 /// A reference to this class is passed as an argument to "apply" methods of the
219 /// transform op interface. Thus the "apply" method can call
220 /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
221 /// associated with its operand and subject to transformation. The method is
222 /// expected to populate the `TransformResults` class instance in order to
223 /// update the mapping. The `applyTransform` method takes care of propagating
224 /// the state of `TransformResults` into the instance of this class.
225 ///
226 /// When applying transform IR operations with regions, the client is expected
227 /// to create a RegionScope RAII object to create a new "stack frame" for
228 /// values defined inside the region. The mappings from and to these values will
229 /// be automatically dropped when the object goes out of scope, typically at the
230 /// end of the "apply" function of the parent operation. If a region contains
231 /// blocks with arguments, the client can map those arguments to payload IR ops
232 /// using "mapBlockArguments".
234  /// Mapping between a Value in the transform IR and the corresponding set of
235  /// operations in the payload IR.
237 
238  /// Mapping between a payload IR operation and the transform IR value it is
239  /// currently associated with.
241 
242  /// Bidirectional mappings between transform IR values and payload IR
243  /// operations.
244  struct Mappings {
245  TransformOpMapping direct;
247  };
248 
249 public:
/// Creates a state for transform ops living in the given region. The second
/// argument points to the root operation in the payload IR being transformed,
/// which may or may not contain the region with transform ops. Additional
/// options can be provided through the trailing configuration object.
255  TransformState(Region &region, Operation *root,
257 
258  /// Returns the op at which the transformation state is rooted. This is
259  /// typically helpful for transformations that apply globally.
260  Operation *getTopLevel() const;
261 
262  /// Returns the list of ops that the given transform IR value corresponds to.
263  /// This is helpful for transformations that apply to a particular handle.
264  ArrayRef<Operation *> getPayloadOps(Value value) const;
265 
266  /// Returns the Transform IR handle for the given Payload IR op if it exists
267  /// in the state, null otherwise.
268  Value getHandleForPayloadOp(Operation *op) const;
269 
270  /// Applies the transformation specified by the given transform op and updates
271  /// the state accordingly.
272  DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
273 
274  /// Records the mapping between a block argument in the transform IR and a
275  /// list of operations in the payload IR. The arguments must be defined in
276  /// blocks of the currently processed transform IR region, typically after a
277  /// region scope is defined.
279  ArrayRef<Operation *> operations) {
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
281  assert(argument.getParentRegion() == regionStack.back() &&
282  "mapping block arguments from a region other than the active one");
283 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
284  return setPayloadOps(argument, operations);
285  }
286 
287  // Forward declarations to support limited visibility.
288  class RegionScope;
289 
290  /// Creates a new region scope for the given region. The region is expected to
291  /// be nested in the currently processed region.
292  // Implementation note: this method is inline but implemented outside of the
293  // class body to comply with visibility and full-declaration requirements.
294  inline RegionScope make_region_scope(Region &region);
295 
296  /// A RAII object maintaining a "stack frame" for a transform IR region. When
297  /// applying a transform IR operation that contains a region, the caller is
298  /// expected to create a RegionScope before applying the ops contained in the
299  /// region. This ensures that the mappings between values defined in the
300  /// transform IR region and payload IR operations are cleared when the region
301  /// processing ends; such values cannot be accessed outside the region.
302  class RegionScope {
303  public:
304  /// Forgets the mapping from or to values defined in the associated
305  /// transform IR region.
307  state.mappings.erase(region);
308 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
309  state.regionStack.pop_back();
310 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
311  }
312 
313  private:
314  /// Creates a new scope for mappings between values defined in the given
315  /// transform IR region and payload IR operations.
316  RegionScope(TransformState &state, Region &region)
317  : state(state), region(&region) {
318  auto res = state.mappings.try_emplace(this->region);
319  assert(res.second && "the region scope is already present");
320  (void)res;
321 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
322  assert(state.regionStack.back()->isProperAncestor(&region) &&
323  "scope started at a non-nested region");
324  state.regionStack.push_back(&region);
325 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
326  }
327 
328  /// Back-reference to the transform state.
329  TransformState &state;
330 
331  /// The region this scope is associated with.
332  Region *region;
333 
334  friend RegionScope TransformState::make_region_scope(Region &);
335  };
336  friend class RegionScope;
337 
338  /// Base class for TransformState extensions that allow TransformState to
339  /// contain user-specified information in the state object. Clients are
340  /// expected to derive this class, add the desired fields, and make the
341  /// derived class compatible with the MLIR TypeID mechanism:
342  ///
343  /// ```mlir
344  /// class MyExtension final : public TransformState::Extension {
345  /// public:
346  /// MyExtension(TranfsormState &state, int myData)
347  /// : Extension(state) {...}
348  /// private:
349  /// int mySupplementaryData;
350  /// };
351  /// ```
352  ///
353  /// Instances of this and derived classes are not expected to be created by
354  /// the user, instead they are directly constructed within a TransformState. A
355  /// TransformState can only contain one extension with the given TypeID.
356  /// Extensions can be obtained from a TransformState instance, and can be
357  /// removed when they are no longer required.
358  ///
359  /// ```mlir
360  /// transformState.addExtension<MyExtension>(/*myData=*/42);
361  /// MyExtension *ext = transformState.getExtension<MyExtension>();
362  /// ext->doSomething();
363  /// ```
364  class Extension {
365  // Allow TransformState to allocate Extensions.
366  friend class TransformState;
367 
368  public:
369  /// Base virtual destructor.
370  // Out-of-line definition ensures symbols are emitted in a single object
371  // file.
372  virtual ~Extension();
373 
374  protected:
375  /// Constructs an extension of the given TransformState object.
376  Extension(TransformState &state) : state(state) {}
377 
378  /// Provides read-only access to the parent TransformState object.
379  const TransformState &getTransformState() const { return state; }
380 
381  /// Replaces the given payload op with another op. If the replacement op is
382  /// null, removes the association of the payload op with its handle.
383  LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
384 
385  private:
386  /// Back-reference to the state that is being extended.
387  TransformState &state;
388  };
389 
390  /// Adds a new Extension of the type specified as template parameter,
391  /// constructing it with the arguments provided. The extension is owned by the
392  /// TransformState. It is expected that the state does not already have an
393  /// extension of the same type. Extension constructors are expected to take
394  /// a reference to TransformState as first argument, automatically supplied
395  /// by this call.
396  template <typename Ty, typename... Args>
397  Ty &addExtension(Args &&...args) {
398  static_assert(
400  "only an class derived from TransformState::Extension is allowed here");
401  auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
402  auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
403  assert(result.second && "extension already added");
404  return *static_cast<Ty *>(result.first->second.get());
405  }
406 
407  /// Returns the extension of the specified type.
408  template <typename Ty>
409  Ty *getExtension() {
410  static_assert(
412  "only an class derived from TransformState::Extension is allowed here");
413  auto iter = extensions.find(TypeID::get<Ty>());
414  if (iter == extensions.end())
415  return nullptr;
416  return static_cast<Ty *>(iter->second.get());
417  }
418 
419  /// Removes the extension of the specified type.
420  template <typename Ty>
422  static_assert(
424  "only an class derived from TransformState::Extension is allowed here");
425  extensions.erase(TypeID::get<Ty>());
426  }
427 
428 private:
429  /// Identifier for storing top-level value in the `operations` mapping.
430  static constexpr Value kTopLevelValue = Value();
431 
432  /// Returns the mappings frame for the reigon in which the value is defined.
433  const Mappings &getMapping(Value value) const {
434  return const_cast<TransformState *>(this)->getMapping(value);
435  }
436  Mappings &getMapping(Value value) {
437  auto it = mappings.find(value.getParentRegion());
438  assert(it != mappings.end() &&
439  "trying to find a mapping for a value from an unmapped region");
440  return it->second;
441  }
442 
443  /// Returns the mappings frame for the region in which the operation resides.
444  const Mappings &getMapping(Operation *operation) const {
445  return const_cast<TransformState *>(this)->getMapping(operation);
446  }
447  Mappings &getMapping(Operation *operation) {
448  auto it = mappings.find(operation->getParentRegion());
449  assert(it != mappings.end() &&
450  "trying to find a mapping for an operation from an unmapped region");
451  return it->second;
452  }
453 
454  /// Sets the payload IR ops associated with the given transform IR value.
455  /// Fails if this would result in multiple transform IR values with uses
456  /// corresponding to the same payload IR ops. For example, a hypothetical
457  /// "find function by name" transform op would (indirectly) call this
/// function for its result. Having two such calls in a row for different
/// values, e.g. coming from different ops:
460  ///
461  /// %0 = transform.find_func_by_name { name = "myfunc" }
462  /// %1 = transform.find_func_by_name { name = "myfunc" }
463  ///
464  /// would lead to both values pointing to the same operation. The second call
465  /// to setPayloadOps will fail, unless the association with the %0 value is
466  /// removed first by calling update/removePayloadOps.
467  LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
468 
469  /// Forgets the payload IR ops associated with the given transform IR value.
470  void removePayloadOps(Value value);
471 
472  /// Updates the payload IR ops associated with the given transform IR value.
473  /// The callback function is called once per associated operation and is
474  /// expected to return the modified operation or nullptr. In the latter case,
475  /// the corresponding operation is no longer associated with the transform IR
476  /// value. May fail if the operation produced by the update callback is
477  /// already associated with a different Transform IR handle value.
479  updatePayloadOps(Value value,
480  function_ref<Operation *(Operation *)> callback);
481 
482  /// Attempts to record the mapping between the given Payload IR operation and
483  /// the given Transform IR handle. Fails and reports an error if the operation
484  /// is already tracked by another handle.
485  static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op,
486  Value handle);
487 
488  /// If the operand is a handle consumed by the operation, i.e. has the "free"
489  /// memory effect associated with it, identifies other handles that are
490  /// pointing to payload IR operations nested in the operations pointed to by
/// the consumed handle. Marks all such handles as invalidated in order to
/// trigger errors if they are used.
493  void recordHandleInvalidation(OpOperand &handle);
494 
495  /// Checks that the operation does not use invalidated handles as operands.
496  /// Reports errors and returns failure if it does. Otherwise, invalidates the
497  /// handles consumed by the operation as well as any handles pointing to
498  /// payload IR operations nested in the operations associated with the
499  /// consumed handles.
501  checkAndRecordHandleInvalidation(TransformOpInterface transform);
502 
503  /// The mappings between transform IR values and payload IR ops, aggregated by
504  /// the region in which the transform IR values are defined.
505  llvm::SmallDenseMap<Region *, Mappings> mappings;
506 
507  /// Extensions attached to the TransformState, identified by the TypeID of
508  /// their type. Only one extension of any given type is allowed.
510 
511  /// The top-level operation that contains all payload IR, typically a module.
512  Operation *topLevel;
513 
514  /// Additional options controlling the transformation state behavior.
516 
517  /// The mapping from invalidated handles to the error-reporting functions that
518  /// describe when the handles were invalidated. Calling such a function emits
519  /// a user-visible diagnostic.
520  DenseMap<Value, std::function<void()>> invalidatedHandles;
521 
522 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
523  /// A stack of nested regions that are being processed in the transform IR.
524  /// Each region must be an ancestor of the following regions in this list.
525  /// These are also the keys for "mappings".
526  SmallVector<Region *> regionStack;
527 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
528 };
529 
530 /// Local mapping between values defined by a specific op implementing the
531 /// TransformOpInterface and the payload IR ops they correspond to.
533  friend class TransformState;
534 
535 public:
536  /// Indicates that the result of the transform IR op at the given position
537  /// corresponds to the given list of payload IR ops. Each result must be set
538  /// by the transformation exactly once.
539  void set(OpResult value, ArrayRef<Operation *> ops);
540 
541 private:
542  /// Creates an instance of TransformResults that expects mappings for
543  /// `numSegments` values.
544  explicit TransformResults(unsigned numSegments);
545 
546  /// Gets the list of operations associated with the result identified by its
547  /// number in the list of operation results.
548  ArrayRef<Operation *> get(unsigned resultNumber) const;
549 
550  /// Storage for pointers to payload IR ops that are associated with results of
551  /// a transform IR op. `segments` contains as many entries as the transform IR
552  /// op has results. Each entry is a reference to a contiguous segment in
553  /// the `operations` list that contains the pointers to operations. This
554  /// allows for operations to be stored contiguously without nested vectors and
555  /// for different segments to be set in any order.
557  SmallVector<Operation *> operations;
558 };
559 
560 TransformState::RegionScope TransformState::make_region_scope(Region &region) {
561  return RegionScope(*this, region);
562 }
563 
564 namespace detail {
565 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
566 /// to either the list of operations associated with its operand or the root of
567 /// the payload IR, depending on what is available in the context.
570  Operation *op, Region &region);
571 
572 /// Verification hook for PossibleTopLevelTransformOpTrait.
574 } // namespace detail
575 
576 /// This trait is supposed to be attached to Transform dialect operations that
577 /// can be standalone top-level transforms. Such operations typically contain
578 /// other Transform dialect operations that can be executed following some
579 /// control flow logic specific to the current operation. The operations with
580 /// this trait are expected to have at least one single-block region with one
581 /// argument of PDL Operation type. The operations are also expected to be valid
582 /// without operands, in which case they are considered top-level, and with one
583 /// or more arguments, in which case they are considered nested. Top-level
584 /// operations have the block argument of the entry block in the Transform IR
585 /// correspond to the root operation of Payload IR. Nested operations have the
586 /// block argument of the entry block in the Transform IR correspond to a list
587 /// of Payload IR operations mapped to the first operand of the Transform IR
588 /// operation. The operation must implement TransformOpInterface.
589 template <typename OpTy>
591  : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
592 public:
593  /// Verifies that `op` satisfies the invariants of this trait. Not expected to
594  /// be called directly.
597  }
598 
599  /// Returns the single block of the given region.
600  Block *getBodyBlock(unsigned region = 0) {
601  return &this->getOperation()->getRegion(region).front();
602  }
603 
604  /// Sets up the mapping between the entry block of the given region of this op
605  /// and the relevant list of Payload IR operations in the given state. The
606  /// state is expected to be already scoped at the region of this operation.
607  /// Returns failure if the mapping failed, e.g., the value is already mapped.
609  assert(region.getParentOp() == this->getOperation() &&
610  "op comes from the wrong region");
612  state, this->getOperation(), region);
613  }
615  assert(
616  this->getOperation()->getNumRegions() == 1 &&
617  "must indicate the region to map if the operation has more than one");
618  return mapBlockArguments(state, this->getOperation()->getRegion(0));
619  }
620 };
621 
622 /// Trait implementing the TransformOpInterface for operations applying a
623 /// transformation to a single operation handle and producing zero, one or
624 /// multiple operation handles.
625 /// The op must implement a method with the following signature:
626 /// - DiagnosedSilenceableFailure applyToOne(OpTy,
627 /// SmallVector<Operation*> &results, state)
628 /// to perform a transformation that is applied in turn to all payload IR
629 /// operations that correspond to the handle of the transform IR operation.
630 /// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
631 /// that the transformation is applied to (and NOT the class of the transform IR
632 /// op).
633 /// The `applyToOne` method takes an empty `results` vector that it fills with
/// zero, one or multiple operations depending on the number of results expected
635 /// by the transform op.
636 /// The number of results must match the number of results of the transform op.
637 /// `applyToOne` is allowed to fill the `results` with all null elements to
638 /// signify that the transformation did not apply to the payload IR operations.
639 /// Such null elements are filtered out from results before return.
640 ///
641 /// The transform op having this trait is expected to have a single operand.
642 template <typename OpTy>
644  : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
645 public:
646  /// Calls `applyToOne` for every payload operation associated with the operand
647  /// of this transform IR op, the following case disjunction happens:
/// 1. If no target payload ops are associated with the operand then fill the
649  /// results vector with the expected number of null elements and return
650  /// success. This is the corner case handling that allows propagating
651  /// the "no-op" case gracefully to improve usability.
652  /// 2. If any `applyToOne` returns definiteFailure, the transformation is
653  /// immediately considered definitely failed and we return.
654  /// 3. All applications of `applyToOne` are checked to return a number of
655  /// results expected by the transform IR op. If not, this is a definite
656  /// failure and we return early.
657  /// 4. If `applyToOne` produces ops, associate them with the result of this
658  /// transform op.
659  /// 5. If any `applyToOne` return silenceableFailure, the transformation is
660  /// considered silenceable.
661  /// 6. Otherwise the transformation is considered successful.
662  DiagnosedSilenceableFailure apply(TransformResults &transformResults,
663  TransformState &state);
664 
665  /// Checks that the op matches the expectations of this trait.
666  static LogicalResult verifyTrait(Operation *op);
667 };
668 
669 /// Side effect resource corresponding to the mapping between Transform IR
670 /// values and Payload IR operations. An Allocate effect from this resource
671 /// means creating a new mapping entry, it is always accompanied by a Write
/// effect. A Read effect from this resource means accessing the mapping. A Free
673 /// effect on this resource indicates the removal of the mapping entry,
674 /// typically after a transformation that modifies the Payload IR operations
675 /// associated with one of the Transform IR operation's operands. It is always
676 /// accompanied by a Read effect. Read-after-Free and double-Free are not
677 /// allowed (they would be problematic with "regular" memory effects too) as
678 /// they indicate an attempt to access Payload IR operations that have been
/// modified, potentially erased, by the previous transformations.
680 // TODO: consider custom effects if these are not enabling generic passes such
681 // as CSE/DCE to work.
683  : public SideEffects::Resource::Base<TransformMappingResource> {
684  StringRef getName() override { return "transform.mapping"; }
685 };
686 
687 /// Side effect resource corresponding to the Payload IR itself. Only Read and
688 /// Write effects are expected on this resource, with Write always accompanied
689 /// by a Read (short of fully replacing the top-level Payload IR operation, one
690 /// cannot modify the Payload IR without reading it first). This is intended
691 /// to disallow reordering of Transform IR operations that mutate the Payload IR
692 /// while still allowing the reordering of those that only access it.
694  : public SideEffects::Resource::Base<PayloadIRResource> {
695  StringRef getName() override { return "transform.payload_ir"; }
696 };
697 
698 /// Populates `effects` with the memory effects indicating the operation on the
699 /// given handle value:
700 /// - consumes = Read + Free,
701 /// - produces = Allocate + Write,
702 /// - onlyReads = Read.
703 void consumesHandle(ValueRange handles,
705 void producesHandle(ValueRange handles,
707 void onlyReadsHandle(ValueRange handles,
709 
710 /// Checks whether the transform op consumes the given handle.
711 bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
712 
713 /// Populates `effects` with the memory effects indicating the access to payload
714 /// IR resource.
717 
718 /// Trait implementing the MemoryEffectOpInterface for operations that "consume"
719 /// their operands and produce new results.
720 template <typename OpTy>
722  : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
723 public:
/// This op "consumes" the operands by reading and freeing them, "produces"
/// the results by allocating and writing them and reads/writes the payload IR
726  /// in the process.
728  consumesHandle(this->getOperation()->getOperands(), effects);
729  producesHandle(this->getOperation()->getResults(), effects);
730  modifiesPayload(effects);
731  }
732 
733  /// Checks that the op matches the expectations of this trait.
735  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
736  op->emitError()
737  << "FunctionalStyleTransformOpTrait should only be attached to ops "
738  "that implement MemoryEffectOpInterface";
739  }
740  return success();
741  }
742 };
743 
744 /// Trait implementing the MemoryEffectOpInterface for single-operand
745 /// single-result operations that use their operand without consuming and
746 /// without modifying the Payload IR to produce a new handle.
747 template <typename OpTy>
749  : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
750 public:
751  /// This op produces handles to the Payload IR without consuming the original
752  /// handles and without modifying the IR itself.
754  onlyReadsHandle(this->getOperation()->getOperands(), effects);
755  producesHandle(this->getOperation()->getResults(), effects);
756  onlyReadsPayload(effects);
757  }
758 
759  /// Checks that the op matches the expectation of this trait.
761  static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
762  "expected single-operand op");
763  static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
764  "expected single-result op");
765  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
766  op->emitError() << "NavigationTransformOpTrait should only be attached "
767  "to ops that implement MemoryEffectOpInterface";
768  }
769  return success();
770  }
771 };
772 
773 } // namespace transform
774 } // namespace mlir
775 
776 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
777 
778 namespace mlir {
779 namespace transform {
780 namespace detail {
781 /// Applies a one-to-one or a one-to-many transform to each of the given
782 /// targets. Puts the results of transforms, if any, in `results` in the same
783 /// order. Fails if any of the applications fails. Individual transforms must
784 /// be callable with the following signature:
785 /// - DiagnosedSilenceableFailure(OpTy,
786 /// SmallVector<Operation *> &results)
787 /// where OpTy is either
788 /// - Operation *, in which case the transform is always applied;
789 /// - a concrete Op class, in which case a check is performed whether
790 /// `targets` contains operations of the same class and a silenceable failure
791 /// is reported if it does not.
/// Exactly one entry is appended to `results` per visited target, in target
/// order. A definite failure returned by `transform` is propagated immediately,
/// leaving `results` only partially populated. Silenceable failures (including
/// the wrong-op-kind case) do not stop the loop: their diagnostics are
/// collected and combined into a single silenceable failure once all targets
/// have been visited.
792 template <typename FnTy>
794  Location loc, int expectedNumResults, ArrayRef<Operation *> targets,
795  SmallVectorImpl<SmallVector<Operation *>> &results, FnTy transform) {
796  SmallVector<Diagnostic> silenceableStack;
797  using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
799  "expected transform function to take an operation");
800  for (Operation *target : targets) {
801  // Emplace back a placeholder for the returned new ops.
802  // This is filled with `expectedNumResults` nullptrs if `target` is not an
  // OpTy; otherwise `transform` fills it.
803  results.push_back(SmallVector<Operation *>());
804 
805  auto specificOp = dyn_cast<OpTy>(target);
806  if (!specificOp) {
808  diag << "transform applied to the wrong op kind";
809  diag.attachNote(target->getLoc()) << "when applied to this op";
810  // Producing `expectedNumResults` nullptr is a silenceableFailure mode.
811  // TODO: encode this implicit `expectedNumResults` nullptr ==
812  // silenceableFailure with a proper trait.
813  results.back().assign(expectedNumResults, nullptr);
814  silenceableStack.push_back(std::move(diag));
815  continue;
816  }
817 
818  DiagnosedSilenceableFailure result = transform(specificOp, results.back());
819  if (result.isDefiniteFailure())
820  return result;
  // Keep going on silenceable failures; their diagnostics are reported
  // together after the loop.
821  if (result.isSilenceableFailure())
822  for (auto &&diag : result.takeDiagnostics())
823  silenceableStack.push_back(std::move(diag));
824  }
825  if (!silenceableStack.empty()) {
827  std::move(silenceableStack));
828  }
830 }
831 
832 /// Helper function: transpose MxN into NxM; assumes that the input is a valid.
836  if (m.empty())
837  return res;
838  int64_t rows = m.size(), cols = m[0].size();
839  for (int64_t j = 0; j < cols; ++j)
840  res.push_back(SmallVector<Operation *, 1>(rows, nullptr));
841  for (int64_t i = 0; i < rows; ++i) {
842  assert(static_cast<int64_t>(m[i].size()) == cols);
843  for (int64_t j = 0; j < cols; ++j) {
844  res[j][i] = m[i][j];
845  }
846  }
847  return res;
848 }
849 } // namespace detail
850 } // namespace transform
851 } // namespace mlir
852 
/// Out-of-line definition of TransformEachOpTrait<OpTy>::apply: runs
/// `OpTy::applyToOne` on every payload op mapped to operand #0 and packs the
/// per-target results into this op's result handles. Definite failures are
/// propagated immediately; silenceable failures are returned after the result
/// handles have been populated.
853 template <typename OpTy>
856  TransformResults &transformResults, TransformState &state) {
857  using TransformOpType = typename llvm::function_traits<
858  decltype(&OpTy::applyToOne)>::template arg_t<0>;
859  ArrayRef<Operation *> targets =
860  state.getPayloadOps(this->getOperation()->getOperand(0));
861 
862  // Step 1. Handle the corner case where no target is specified.
863  // This is typically the case when the matcher fails to apply and we need to
864  // propagate gracefully.
865  // In this case, we fill all results with an empty vector.
866  if (targets.empty()) {
868  for (auto r : this->getOperation()->getResults())
869  transformResults.set(r.template cast<OpResult>(), empty);
871  }
872 
873  // Step 2. Call applyToOne on each target and record newly produced ops in its
874  // corresponding results entry.
875  int expectedNumResults = this->getOperation()->getNumResults();
878  this->getOperation()->getLoc(), expectedNumResults, targets, results,
879  [&](TransformOpType specificOp, SmallVector<Operation *> &partialResult) {
880  auto res = static_cast<OpTy *>(this)->applyToOne(specificOp,
881  partialResult, state);
882  if (res.isDefiniteFailure())
883  return res;
884 
885  // TODO: encode this implicit must always produce `expectedNumResults`
886  // and nullptr is fine with a proper trait.
887  if (static_cast<int>(partialResult.size()) != expectedNumResults) {
888  auto loc = this->getOperation()->getLoc();
889  auto diag = mlir::emitError(loc, "applications of ")
890  << OpTy::getOperationName() << " expected to produce "
891  << expectedNumResults << " results (actually produced "
892  << partialResult.size() << ").";
893  diag.attachNote(loc)
894  << "If you need variadic results, consider a generic `apply` "
895  << "instead of the specialized `applyToOne`.";
896  diag.attachNote(loc)
897  << "Producing " << expectedNumResults << " null results is "
898  << "allowed if the use case warrants it.";
899  diag.attachNote(specificOp->getLoc()) << "when applied to this op";
901  }
902  // Check that either all or none of the produced results are null.
903  // TODO: relax this behavior and encode with a proper trait.
904  if (llvm::any_of(partialResult, [](Operation *op) { return op; }) &&
905  llvm::any_of(partialResult, [](Operation *op) { return !op; })) {
906  auto loc = this->getOperation()->getLoc();
907  auto diag = mlir::emitError(loc, "unexpected application of ")
908  << OpTy::getOperationName()
909  << " produces both null and non null results.";
910  diag.attachNote(specificOp->getLoc()) << "when applied to this op";
912  }
913  return res;
914  });
915 
916  // Step 3. Propagate the definite failure if any and bail out.
917  if (result.isDefiniteFailure())
918  return result;
919 
920  // Step 4. If there are no results, return early.
921  if (OpTy::template hasTrait<OpTrait::ZeroResults>())
922  return result;
923 
924  // Step 5. Perform transposition of M applications producing N results each
925  // into N results for each of the M applications.
926  SmallVector<SmallVector<Operation *, 1>> transposedResults =
927  detail::transposeResults(results);
928 
929  // Step 6. A single-result op applied to M targets produces one handle
  // holding M payload ops.
  // NOTE(review): unlike Step 7, null entries (e.g. the placeholders written
  // for silenceable failures) are not filtered out here — confirm intended.
930  if (OpTy::template hasTrait<OpTrait::OneResult>()) {
931  assert(transposedResults.size() == 1 && "Expected single result");
932  transformResults.set(
933  this->getOperation()->getResult(0).template cast<OpResult>(),
934  transposedResults[0]);
935  // ApplyToOne may have returned silenceableFailure, propagate it.
936  return result;
937  }
938 
939  // Step 7. Filter out null entries and set the transformResults.
940  for (const auto &it :
941  llvm::zip(this->getOperation()->getResults(), transposedResults)) {
943  llvm::copy_if(std::get<1>(it), std::back_inserter(filtered),
944  [](Operation *op) { return op; });
945  transformResults.set(std::get<0>(it).template cast<OpResult>(), filtered);
946  }
947 
948  // Step 8. ApplyToOne may have returned silenceableFailure, propagate it.
949  return result;
950 }
951 
/// Out-of-line verifier for TransformEachOpTrait: requires a single operand
/// (checked at compile time) and fails verification, emitting an error, if
/// the op does not implement TransformOpInterface.
952 template <typename OpTy>
955  static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
956  "expected single-operand op");
957  if (!op->getName().getInterface<TransformOpInterface>()) {
958  return op->emitError() << "TransformEachOpTrait should only be attached to "
959  "ops that implement TransformOpInterface";
960  }
961 
962  return success();
963 }
964 
965 #endif // DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
Diagnostic & attachNote(Optional< Location > loc=llvm::None)
Attaches a note to the last diagnostic.
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
void set(OpResult value, ArrayRef< Operation *> ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
static std::string diag(llvm::Value &v)
Side effect resource corresponding to the Payload IR itself.
The result of a transform IR operation application.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
This is a value defined by a result of an operation.
Definition: Value.h:425
Block represents an ordered list of Operations.
Definition: Block.h:29
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
Options controlling the application of transform operations by the TransformState.
SmallVector< Diagnostic > && takeDiagnostics()
Take the diagnostic and silence.
bool getExpensiveChecksEnabled() const
Returns true if the expensive checks are requested.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
Extension(TransformState &state)
Constructs an extension of the given TransformState object.
void removeExtension()
Removes the extension of the specified type.
Operation & front()
Definition: Block.h:144
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
StringRef getName() override
Return a string name of the resource.
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait.
DiagnosedSilenceableFailure applyTransformToEach(Location loc, int expectedNumResults, ArrayRef< Operation *> targets, SmallVectorImpl< SmallVector< Operation *>> &results, FnTy transform)
Applies a one-to-one or a one-to-many transform to each of the given targets.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic, preserves the other states.
std::string getStatusString() const
Returns a string representation of the failure mode (for error reporting).
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Trait implementing the TransformOpInterface for operations applying a transformation to a single oper...
Trait implementing the MemoryEffectOpInterface for operations that "consume" their operands and produ...
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Side effect resource corresponding to the mapping between Transform IR values and Payload IR operatio...
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value: ...
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
DiagnosedSilenceableFailure apply(TransformResults &transformResults, TransformState &state)
Calls applyToOne for every payload operation associated with the operand of this transform IR op...
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Trait implementing the MemoryEffectOpInterface for single-operand single-result operations that use t...
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure && operator<<(T &&value) &&
LogicalResult mapBlockArguments(TransformState &state)
static LogicalResult verifyTrait(Operation *op)
Verifies that op satisfies the invariants of this trait.
LogicalResult mapBlockArguments(TransformState &state, Region &region)
Sets up the mapping between the entry block of the given region of this op and the relevant list of P...
Block & back()
Definition: Region.h:64
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region.
DiagnosedSilenceableFailure(LogicalResult result)
The state maintained across applications of various ops implementing the TransformOpInterface.
DiagnosedSilenceableFailure & operator<<(T &&value) &
Streams the given values into the last diagnostic.
This base class is used for derived effects that are non-parametric.
bool succeeded() const
Returns true if this is a success.
This class represents an argument of a Block.
Definition: Value.h:300
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
static DiagnosedSilenceableFailure silenceableFailure(SmallVector< Diagnostic > &&diag)
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static llvm::ManagedStatic< PassManagerOptions > options
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation *> operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
Helper class for implementing traits.
Definition: OpDefinition.h:316
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
Ty * getExtension()
Returns the extension of the specified type.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
StringRef getName() override
Return a string name of the resource.
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op "consumes" the operands by reading and freeing them, "produces" the results by allocating and...
static SmallVector< SmallVector< Operation *, 1 > > transposeResults(const SmallVector< SmallVector< Operation *>, 1 > &m)
Helper function: transpose MxN into NxM; assumes that the input is a valid MxN matrix (all rows have the same size).
Base class for TransformState extensions that allow TransformState to contain user-specified informat...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op produces handles to the Payload IR without consuming the original handles and without modifyi...
This class represents an operand of an operation.
Definition: Value.h:251
LogicalResult verifyTrait(ConcreteOp op)
This function defines the internal implementation of the verifyTrait method on FunctionOpInterface::T...
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:161
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
const TransformState & getTransformState() const
Provides read-only access to the parent TransformState object.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:221
TransformOptions & enableExpensiveChecks(bool enable=true)
Requests computationally expensive checks of the transform and payload IR well-formedness to be perfo...
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
Block * getBodyBlock(unsigned region=0)
Returns the single block of the given region.
std::string getMessage() const
Returns the diagnostic message without emitting it.