MLIR  19.0.0git
TransformInterfaces.h
Go to the documentation of this file.
1 //===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
11 
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
19 
20 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
21 
22 namespace mlir {
23 namespace transform {
24 
25 class TransformOpInterface;
26 class TransformResults;
27 class TransformRewriter;
28 class TransformState;
29 
30 using Param = Attribute;
32 
33 namespace detail {
34 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
35 /// to either the list of operations associated with its operand or the root of
36 /// the payload IR, depending on what is available in the context.
39  Operation *op, Region &region);
40 
41 /// Verification hook for PossibleTopLevelTransformOpTrait.
43 
44 /// Populates `effects` with side effects implied by
45 /// PossibleTopLevelTransformOpTrait for the given operation. The operation may
46 /// have an optional `root` operand, indicating it is not in fact top-level. It
47 /// is also expected to have a single-block body.
49  Operation *operation, Value root, Block &body,
51 
52 /// Verification hook for TransformOpInterface.
54 
55 /// Populates `mappings` with mapped values associated with the given transform
56 /// IR values in the given `state`.
59  ValueRange values, const transform::TransformState &state);
60 
61 /// Populates `results` with payload associations that match exactly those of
62 /// the operands to `block`'s terminator.
65 
66 /// Make a dummy transform state for testing purposes. This MUST NOT be used
67 /// outside of test cases.
69  Operation *payloadRoot);
70 
71 /// Returns all operands that are handles and being consumed by the given op.
73 getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
74 } // namespace detail
75 } // namespace transform
76 } // namespace mlir
77 
78 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
79 
80 namespace mlir {
81 namespace transform {
82 
83 /// Options controlling the application of transform operations by the
84 /// TransformState.
86 public:
87  TransformOptions() = default;
88  TransformOptions(const TransformOptions &) = default;
90 
91  /// Requests computationally expensive checks of the transform and payload IR
92  /// well-formedness to be performed before each transformation. In particular,
93  /// these ensure that the handles still point to valid operations when used.
94  TransformOptions &enableExpensiveChecks(bool enable = true) {
95  expensiveChecksEnabled = enable;
96  return *this;
97  }
98 
99  // Ensures that only a single top-level transform op is present in the IR.
101  enforceSingleToplevelTransformOp = enable;
102  return *this;
103  }
104 
105  /// Returns true if the expensive checks are requested.
106  bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
107 
108  // Returns true if enforcing a single top-level transform op is requested.
110  return enforceSingleToplevelTransformOp;
111  }
112 
113 private:
114  bool expensiveChecksEnabled = true;
115  bool enforceSingleToplevelTransformOp = true;
116 };
117 
118 /// Entry point to the Transform dialect infrastructure. Applies the
119 /// transformation specified by `transform` to payload IR contained in
120 /// `payloadRoot`. The `transform` operation may contain other operations that
121 /// will be executed following the internal logic of the operation. It must
122 /// have the `PossibleTopLevelTransformOp` trait and not have any operands.
123 /// This function internally keeps track of the transformation state.
125 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
126  const RaggedArray<MappedValue> &extraMapping = {},
127  const TransformOptions &options = TransformOptions(),
128  bool enforceToplevelTransformOp = true);
129 
130 /// The state maintained across applications of various ops implementing the
131 /// TransformOpInterface. The operations implementing this interface and the
132 /// surrounding structure are referred to as transform IR. The operations to
133 /// which transformations apply are referred to as payload IR. Transform IR
134 /// operates on values that can be associated either with a list of payload IR
135 /// operations (such values are referred to as handles) or with a list of
136 /// parameters represented as attributes. The state thus contains the mapping
137 /// between values defined in the transform IR ops and either payload IR ops or
138 /// parameters. For payload ops, the mapping is many-to-many and the reverse
139 /// mapping is also stored. The "expensive-checks" option can be passed to the
140 /// constructor to check, at transformation execution time, that transform IR
141 /// values used as operands by a transform IR operation are not associated with
142 /// dangling pointers to payload IR operations known to have been erased by a
143 /// previous transformation through the same or a different transform IR value.
144 ///
145 /// A reference to this class is passed as an argument to "apply" methods of the
146 /// transform op interface. Thus the "apply" method can call either
147 /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
148 /// or `state.getParams( getSomeOperand() )` to obtain the list of parameters
149 /// associated with its operand. The method is expected to populate the
150 /// `TransformResults` class instance in order to update the mapping. The
151 /// `applyTransform` method takes care of propagating the state of
152 /// `TransformResults` into the instance of this class.
153 ///
154 /// When applying transform IR operations with regions, the client is expected
155 /// to create a `RegionScope` RAII object to create a new "stack frame" for
156 /// values defined inside the region. The mappings from and to these values will
157 /// be automatically dropped when the object goes out of scope, typically at the
158 /// end of the `apply` function of the parent operation. If a region contains
159 /// blocks with arguments, the client can map those arguments to payload IR ops
160 /// using `mapBlockArguments`.
162 public:
164 
165 private:
166  /// Mapping between a Value in the transform IR and the corresponding set of
167  /// operations in the payload IR.
169 
170  /// Mapping between a payload IR operation and the transform IR values it is
171  /// associated with.
174 
175  /// Mapping between a Value in the transform IR and the corresponding list of
176  /// parameters.
178 
179  /// Mapping between a Value in the transform IR and the corresponding list of
180  /// values in the payload IR. Also works for reverse mappings.
182 
183  /// Mapping between a Value in the transform IR and an error message that
184  /// should be emitted when the value is used.
185  using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
186 
187 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
188  /// Debug only: A timestamp is associated with each transform IR value, so
189  /// that invalid iterator usage can be detected more reliably.
190  using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
191 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
192 
193  /// The bidirectional mappings between transform IR values and payload IR
194  /// operations, and the mapping between transform IR values and parameters.
195  struct Mappings {
196  TransformOpMapping direct;
198  ParamMapping params;
199  ValueMapping values;
200  ValueMapping reverseValues;
201 
202 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
203  TransformIRTimestampMapping timestamps;
204  void incrementTimestamp(Value value) { ++timestamps[value]; }
205 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
206  };
207 
208  friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
209  const RaggedArray<MappedValue> &,
210  const TransformOptions &, bool);
211 
212  friend TransformState
214 
215 public:
216  const TransformOptions &getOptions() const { return options; }
217 
218  /// Returns the op at which the transformation state is rooted. This is
219  /// typically helpful for transformations that apply globally.
220  Operation *getTopLevel() const;
221 
222  /// Returns the number of extra mappings for the top-level operation.
223  size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
224 
225  /// Returns the position-th extra mapping for the top-level operation.
226  ArrayRef<MappedValue> getTopLevelMapping(size_t position) const {
227  return topLevelMappedValues[position];
228  }
229 
230  /// Returns an iterator that enumerates all ops that the given transform IR
231  /// value corresponds to. Ops may be erased while iterating; erased ops are
232  /// not enumerated. This function is helpful for transformations that apply to
233  /// a particular handle.
234  auto getPayloadOps(Value value) const {
235  ArrayRef<Operation *> view = getPayloadOpsView(value);
236 
237 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
238  // Memorize the current timestamp and make sure that it has not changed
239  // when incrementing or dereferencing the iterator returned by this
240  // function. The timestamp is incremented when the "direct" mapping is
241  // resized; this would invalidate the iterator returned by this function.
242  int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
243 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
244 
245  // When ops are replaced/erased, they are replaced with nullptr (until
246  // the data structure is compacted). Do not enumerate these ops.
247  return llvm::make_filter_range(view, [=](Operation *op) {
248 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
249  [[maybe_unused]] bool sameTimestamp =
250  currentTimestamp == this->getMapping(value).timestamps.lookup(value);
251  assert(sameTimestamp && "iterator was invalidated during iteration");
252 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
253  return op != nullptr;
254  });
255  }
256 
257  /// Returns the list of parameters that the given transform IR value
258  /// corresponds to.
259  ArrayRef<Attribute> getParams(Value value) const;
260 
261  /// Returns an iterator that enumerates all payload IR values that the given
262  /// transform IR value corresponds to.
263  auto getPayloadValues(Value handleValue) const {
264  ArrayRef<Value> view = getPayloadValuesView(handleValue);
265 
266 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
267  // Memorize the current timestamp and make sure that it has not changed
268  // when incrementing or dereferencing the iterator returned by this
269  // function. The timestamp is incremented when the "values" mapping is
270  // resized; this would invalidate the iterator returned by this function.
271  int64_t currentTimestamp =
272  getMapping(handleValue).timestamps.lookup(handleValue);
273  return llvm::make_filter_range(view, [=](Value v) {
274  [[maybe_unused]] bool sameTimestamp =
275  currentTimestamp ==
276  this->getMapping(handleValue).timestamps.lookup(handleValue);
277  assert(sameTimestamp && "iterator was invalidated during iteration");
278  return true;
279  });
280 #else
281  return llvm::make_range(view.begin(), view.end());
282 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
283  }
284 
285  /// Populates `handles` with all handles pointing to the given Payload IR op.
286  /// Returns success if such handles exist, failure otherwise.
287  /// If `includeOutOfScope` is set to "true", handles that are defined in
288  /// regions beyond the most recent isolated from above region are included.
290  SmallVectorImpl<Value> &handles,
291  bool includeOutOfScope = false) const;
292 
293  /// Populates `handles` with all handles pointing to the given payload IR
294  /// value. Returns success if such handles exist, failure otherwise.
295  /// If `includeOutOfScope` is set to "true", handles that are defined in
296  /// regions beyond the most recent isolated from above region are included.
298  SmallVectorImpl<Value> &handles,
299  bool includeOutOfScope = false) const;
300 
301  /// Applies the transformation specified by the given transform op and updates
302  /// the state accordingly.
303  DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
304 
305  /// Records the mapping between a block argument in the transform IR and a
306  /// list of operations in the payload IR. The arguments must be defined in
307  /// blocks of the currently processed transform IR region, typically after a
308  /// region scope is defined.
309  ///
310  /// Returns failure if the payload does not satisfy the conditions associated
311  /// with the type of the handle value.
313  ArrayRef<Operation *> operations) {
314  assert(argument.getParentRegion() == regionStack.back()->region &&
315  "mapping block arguments from a region other than the active one");
316  return setPayloadOps(argument, operations);
317  }
319  ArrayRef<MappedValue> values);
320 
321  // Forward declarations to support limited visibility.
322  class RegionScope;
323 
324  /// Creates a new region scope for the given region. The region is expected to
325  /// be nested in the currently processed region.
326  // Implementation note: this method is inline but implemented outside of the
327  // class body to comply with visibility and full-declaration requirements.
328  inline RegionScope make_region_scope(Region &region);
329 
330  /// A RAII object maintaining a "stack frame" for a transform IR region. When
331  /// applying a transform IR operation that contains a region, the caller is
332  /// expected to create a RegionScope before applying the ops contained in the
333  /// region. This ensures that the mappings between values defined in the
334  /// transform IR region and payload IR operations are cleared when the region
335  /// processing ends; such values cannot be accessed outside the region.
336  class RegionScope {
337  public:
338  /// Forgets the mapping from or to values defined in the associated
339  /// transform IR region, and restores the mapping that existed before
340  /// entering this scope.
341  ~RegionScope();
342 
343  private:
344  /// Creates a new scope for mappings between values defined in the given
345  /// transform IR region and payload IR objects.
346  RegionScope(TransformState &state, Region &region)
347  : state(state), region(&region) {
348  auto res = state.mappings.insert(
349  std::make_pair(&region, std::make_unique<Mappings>()));
350  assert(res.second && "the region scope is already present");
351  (void)res;
352  state.regionStack.push_back(this);
353  }
354 
355  /// Back-reference to the transform state.
356  TransformState &state;
357 
358  /// The region this scope is associated with.
359  Region *region;
360 
361  /// The transform op within this region that is currently being applied.
362  TransformOpInterface currentTransform;
363 
365  };
366  friend class RegionScope;
367 
368  /// Base class for TransformState extensions that allow TransformState to
369  /// contain user-specified information in the state object. Clients are
370  /// expected to derive this class, add the desired fields, and make the
371  /// derived class compatible with the MLIR TypeID mechanism:
372  ///
373  /// ```mlir
374  /// class MyExtension final : public TransformState::Extension {
375  /// public:
376  /// MyExtension(TransformState &state, int myData)
377  /// : Extension(state) {...}
378  /// private:
379  /// int mySupplementaryData;
380  /// };
381  /// ```
382  ///
383  /// Instances of this and derived classes are not expected to be created by
384  /// the user, instead they are directly constructed within a TransformState. A
385  /// TransformState can only contain one extension with the given TypeID.
386  /// Extensions can be obtained from a TransformState instance, and can be
387  /// removed when they are no longer required.
388  ///
389  /// ```mlir
390  /// transformState.addExtension<MyExtension>(/*myData=*/42);
391  /// MyExtension *ext = transformState.getExtension<MyExtension>();
392  /// ext->doSomething();
393  /// ```
394  class Extension {
395  // Allow TransformState to allocate Extensions.
396  friend class TransformState;
397 
398  public:
399  /// Base virtual destructor.
400  // Out-of-line definition ensures symbols are emitted in a single object
401  // file.
402  virtual ~Extension();
403 
404  protected:
405  /// Constructs an extension of the given TransformState object.
406  Extension(TransformState &state) : state(state) {}
407 
408  /// Provides read-only access to the parent TransformState object.
409  const TransformState &getTransformState() const { return state; }
410 
411  /// Replaces the given payload op with another op. If the replacement op is
412  /// null, removes the association of the payload op with its handle. Returns
413  /// failure if the op is not associated with any handle.
414  ///
415  /// Note: This function does not update value handles. None of the original
416  /// op's results are allowed to be mapped to any value handle.
418 
419  /// Replaces the given payload value with another value. If the replacement
420  /// value is null, removes the association of the payload value with its
421  /// handle. Returns failure if the value is not associated with any handle.
422  LogicalResult replacePayloadValue(Value value, Value replacement);
423 
424  private:
425  /// Back-reference to the state that is being extended.
426  TransformState &state;
427  };
428 
429  /// Adds a new Extension of the type specified as template parameter,
430  /// constructing it with the arguments provided. The extension is owned by the
431  /// TransformState. It is expected that the state does not already have an
432  /// extension of the same type. Extension constructors are expected to take
433  /// a reference to TransformState as first argument, automatically supplied
434  /// by this call.
435  template <typename Ty, typename... Args>
436  Ty &addExtension(Args &&...args) {
437  static_assert(
438  std::is_base_of<Extension, Ty>::value,
439  "only an class derived from TransformState::Extension is allowed here");
440  auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
441  auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
442  assert(result.second && "extension already added");
443  return *static_cast<Ty *>(result.first->second.get());
444  }
445 
446  /// Returns the extension of the specified type.
447  template <typename Ty>
448  Ty *getExtension() {
449  static_assert(
450  std::is_base_of<Extension, Ty>::value,
451  "only an class derived from TransformState::Extension is allowed here");
452  auto iter = extensions.find(TypeID::get<Ty>());
453  if (iter == extensions.end())
454  return nullptr;
455  return static_cast<Ty *>(iter->second.get());
456  }
457 
458  /// Removes the extension of the specified type.
459  template <typename Ty>
461  static_assert(
462  std::is_base_of<Extension, Ty>::value,
463  "only an class derived from TransformState::Extension is allowed here");
464  extensions.erase(TypeID::get<Ty>());
465  }
466 
467 private:
468  /// Identifier for storing top-level value in the `operations` mapping.
469  static constexpr Value kTopLevelValue = Value();
470 
471  /// Creates a state for transform ops living in the given region. The second
472  /// argument points to the root operation in the payload IR being transformed,
473  /// which may or may not contain the region with transform ops. Additional
474  /// options can be provided through the trailing configuration object.
475  TransformState(Region *region, Operation *payloadRoot,
476  const RaggedArray<MappedValue> &extraMappings = {},
477  const TransformOptions &options = TransformOptions());
478 
479  /// Returns the mappings frame for the region in which the value is defined.
480  /// If `allowOutOfScope` is set to "false", asserts that the value is in
481  /// scope, based on the current stack of frames.
482  const Mappings &getMapping(Value value, bool allowOutOfScope = false) const {
483  return const_cast<TransformState *>(this)->getMapping(value,
484  allowOutOfScope);
485  }
486  Mappings &getMapping(Value value, bool allowOutOfScope = false) {
487  Region *region = value.getParentRegion();
488  auto it = mappings.find(region);
489  assert(it != mappings.end() &&
490  "trying to find a mapping for a value from an unmapped region");
491 #ifndef NDEBUG
492  if (!allowOutOfScope) {
493  for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
494  if (r == region)
495  break;
496  if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
497  llvm_unreachable("trying to get mapping beyond region that is "
498  "isolated from above");
499  }
500  }
501 #endif // NDEBUG
502  return *it->second;
503  }
504 
505  /// Returns the mappings frame for the region in which the operation resides.
506  /// If `allowOutOfScope` is set to "false", asserts that the operation is in
507  /// scope, based on the current stack of frames.
508  const Mappings &getMapping(Operation *operation,
509  bool allowOutOfScope = false) const {
510  return const_cast<TransformState *>(this)->getMapping(operation,
511  allowOutOfScope);
512  }
513  Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) {
514  Region *region = operation->getParentRegion();
515  auto it = mappings.find(region);
516  assert(it != mappings.end() &&
517  "trying to find a mapping for an operation from an unmapped region");
518 #ifndef NDEBUG
519  if (!allowOutOfScope) {
520  for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
521  if (r == region)
522  break;
523  if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
524  llvm_unreachable("trying to get mapping beyond region that is "
525  "isolated from above");
526  }
527  }
528 #endif // NDEBUG
529  return *it->second;
530  }
531 
532  /// Updates the state to include the associations between op results and the
533  /// provided result of applying a transform op.
534  LogicalResult updateStateFromResults(const TransformResults &results,
535  ResultRange opResults);
536 
537  /// Returns a list of all ops that the given transform IR value corresponds
538  /// to. In case an op was erased, the returned list contains nullptr. This
539  /// function is helpful for transformations that apply to a particular handle.
540  ArrayRef<Operation *> getPayloadOpsView(Value value) const;
541 
542  /// Returns a list of payload IR values that the given transform IR value
543  /// corresponds to.
544  ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
545 
546  /// Sets the payload IR ops associated with the given transform IR value
547  /// (handle). A payload op may be associated with multiple handles as long as
548  /// at most one of them gets consumed by further transformations.
549  /// For example, a hypothetical "find function by name" may be called twice in
550  /// a row to produce two handles pointing to the same function:
551  ///
552  /// %0 = transform.find_func_by_name { name = "myfunc" }
553  /// %1 = transform.find_func_by_name { name = "myfunc" }
554  ///
555  /// which is valid by itself. However, calling a hypothetical "rewrite and
556  /// rename function" transform on both handles:
557  ///
558  /// transform.rewrite_and_rename %0 { new_name = "func" }
559  /// transform.rewrite_and_rename %1 { new_name = "func" }
560  ///
561  /// is invalid given the transformation "consumes" the handle as expressed
562  /// by side effects. Practically, a transformation consuming a handle means
563  /// that the associated payload operation may no longer exist.
564  ///
565  /// Similarly, operation handles may be invalidated and should not be used
566  /// after a transform that consumed a value handle pointing to a payload value
567  /// defined by the operation as either block argument or op result. For
568  /// example, in the following sequence, the last transform operation rewrites
569  /// the callee to not return a specified result:
570  ///
571  /// %0 = transform.find_call "myfunc"
572  /// %1 = transform.find_results_of_calling "myfunc"
573  /// transform.drop_call_result_from_signature %1[0]
574  ///
575  /// which requires the call operations to be recreated. Therefore, the handle
576  /// %0 becomes associated with a dangling pointer and should not be used.
577  ///
578  /// Returns failure if the payload does not satisfy the conditions associated
579  /// with the type of the handle value. The value is expected to have a type
580  /// implementing TransformHandleTypeInterface.
581  LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
582 
583  /// Sets the payload IR values association with the given transform IR value
584  /// (handle). A payload value may be associated with multiple handles as long
585  /// as at most one of them is consumed by further transformations. For
586  /// example, a hypothetical "get results of calls to function with the given
587  /// name" transform may be performed twice in a row producing handles pointing
588  /// to the same values:
589  ///
590  /// %0 = transform.find_results_of_calling "myfunc"
591  /// %1 = transform.find_results_of_calling "myfunc"
592  ///
593  /// which is valid by itself. However, calling a hypothetical "erase value
594  /// producer" transform on both handles:
595  ///
596  /// transform.erase_value_producer %0
597  /// transform.erase_value_producer %1
598  ///
599  /// is invalid provided the transformation "consumes" the handle as expressed
600  /// by side effects (which themselves reflect the semantics of the transform
601  /// erasing the producer and making the handle dangling). Practically, a
602  /// transformation consuming a handle means the associated payload value may
603  /// no longer exist.
604  ///
605  /// Similarly, value handles are invalidated and should not be used after a
606  /// transform that consumed an operation handle pointing to the payload IR
607  /// operation defining the values associated with the value handle, as either block
608  /// arguments or op results, or any ancestor operation. For example,
609  ///
610  /// %0 = transform.find_call "myfunc"
611  /// %1 = transform.find_results_of_calling "myfunc"
612  /// transform.rewrite_and_rename %0 { new_name = "func" }
613  ///
614  /// makes %1 unusable after the last transformation if it consumes %0. When an
615  /// operation handle is consumed, it usually indicates that the operation was
616  /// destroyed or heavily modified, meaning that the values it defines may no
617  /// longer exist.
618  ///
619  /// Returns failure if the payload values do not satisfy the conditions
620  /// associated with the type of the handle value. The value is expected to
621  /// have a type implementing TransformValueHandleTypeInterface.
622  LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
623 
624  /// Sets the parameters associated with the given transform IR value. Returns
625  /// failure if the parameters do not satisfy the conditions associated with
626  /// the type of the value. The value is expected to have a type implementing
627  /// TransformParamTypeInterface.
628  LogicalResult setParams(Value value, ArrayRef<Param> params);
629 
630  /// Forgets the payload IR ops associated with the given transform IR value,
631  /// as well as any association between value handles and the results of said
632  /// payload IR op.
633  ///
634  /// If `allowOutOfScope` is set to "false", asserts that the handle is in
635  /// scope, based on the current stack of frames.
636  void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
637  bool allowOutOfScope = false);
638 
639  void forgetValueMapping(Value valueHandle,
640  ArrayRef<Operation *> payloadOperations);
641 
642  /// Replaces the given payload op with another op. If the replacement op is
643  /// null, removes the association of the payload op with its handle. Returns
644  /// failure if the op is not associated with any handle.
645  ///
646  /// Note: This function does not update value handles. None of the original
647  /// op's results are allowed to be mapped to any value handle.
648  LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
649 
650  /// Replaces the given payload value with another value. If the replacement
651  /// value is null, removes the association of the payload value with its
652  /// handle. Returns failure if the value is not associated with any handle.
653  LogicalResult replacePayloadValue(Value value, Value replacement);
654 
655  /// Records handle invalidation reporters into `newlyInvalidated`.
656  /// Specifically,
657  /// - `consumingHandle` is the op operand that consumes the handle,
658  /// - `potentialAncestors` is a list of ancestors of the payload operation
659  /// that the consumed handle is associated with, including itself,
660  /// - `throughValue` is the payload value the handle to which is consumed,
661  /// when it is the case, null when the operation handle is consumed
662  /// directly.
663  /// Iterates over all known operation and value handles and records reporters
664  /// for any potential future use of `consumingHandle` or any other handle that is
665  /// invalidated by its consumption, i.e., any handle pointing to any payload
666  /// IR entity (operation or value) associated with the same payload IR entity
667  /// as the consumed handle, or any nested payload IR entity. If
668  /// `potentialAncestors` is empty, records the reporter anyway. Does not
669  /// override existing reporters. This must remain a const method so it doesn't
670  /// inadvertently mutate `invalidatedHandles` too early.
671  void recordOpHandleInvalidation(OpOperand &consumingHandle,
672  ArrayRef<Operation *> potentialAncestors,
673  Value throughValue,
674  InvalidatedHandleMap &newlyInvalidated) const;
675 
676  /// Records handle invalidation reporters into `newlyInvalidated`.
677  /// Specifically,
678  /// - `consumingHandle` is the op operand that consumes the handle,
679  /// - `potentialAncestors` is a list of ancestors of the payload operation
680  /// that the consumed handle is associated with, including itself,
681  /// - `payloadOp` is the operation itself,
682  /// - `otherHandle` is another handle that may be associated with the affected
683  /// payload operations
684  /// - `throughValue` is the payload value the handle to which is consumed,
685  /// when it is the case, null when the operation handle is consumed
686  /// directly.
687  /// Looks at the payload operations associated with `otherHandle` and if any
688  /// of these operations has an ancestor (or is itself) listed in
689  /// `potentialAncestors`, records the error message describing the use of the
690  /// invalidated handle. Does nothing if `otherHandle` already has a reporter
691  /// associated with it. This must remain a const method so it doesn't
692  /// inadvertently mutate `invalidatedHandles` too early.
693  void recordOpHandleInvalidationOne(
694  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
695  Operation *payloadOp, Value otherHandle, Value throughValue,
696  InvalidatedHandleMap &newlyInvalidated) const;
697 
698  /// Records handle invalidation reporters into `newlyInvalidated`.
699  /// Specifically,
700  /// - `opHandle` is the op operand that consumes the handle;
701  /// - `potentialAncestors` is a list of ancestors of the payload operation
702  /// that the consumed handle is associated with, including itself;
703  /// - `payloadValue` is the value defined by the operation associated with
704  /// the consuming handle as either op result or block argument;
705  /// - `valueHandle` is another that may be associated with the payload value.
706  /// Looks at the payload values associated with `valueHandle` and if any of
707  /// these values is defined, as op result or block argument, by an operation
708  /// whose ancestor (or the operation itself) is listed in
709  /// `potentialAncestors`, records the error message describing the use of the
710  /// invalidated handle. Does nothing if `valueHandle` already has a reporter
711  /// associated with it. This must remain a const method so it doesn't
712  /// inadvertently mutate `invalidatedHandles` too early.
713  void recordValueHandleInvalidationByOpHandleOne(
714  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
715  Value payloadValue, Value valueHandle,
716  InvalidatedHandleMap &newlyInvalidated) const;
717 
718  /// Records handle invalidation reporters into `newlyInvalidated`.
719  /// Specifically,
720  /// - `valueHandle` is the op operand that consumes the handle,
721  /// - `throughValue` is the payload value the handle to which is consumed,
722  /// when it is the case, null when the operation handle is consumed
723  /// directly.
724  /// Iterates over all known operation and value handles and records reporters
725  /// for any potential future use of `valueHandle` or any other handle that is
726  /// invalidated by its consumption, i.e., any handle pointing to any payload
727  /// IR entity (operation or value) associated with the same payload IR entity
728  /// as the consumed handle, or any nested payload IR entity. Does not override
729  /// existing reporters. This must remain a const method so it doesn't
730  /// inadvertently mutate `invalidatedHandles` too early.
731  void
732  recordValueHandleInvalidation(OpOperand &valueHandle,
733  InvalidatedHandleMap &newlyInvalidated) const;
734 
735  /// Checks that the operation does not use invalidated handles as operands.
736  /// Reports errors and returns failure if it does. Otherwise, invalidates the
737  /// handles consumed by the operation as well as any handles pointing to
738  /// payload IR operations nested in the operations associated with the
739  /// consumed handles.
740  LogicalResult
741  checkAndRecordHandleInvalidation(TransformOpInterface transform);
742 
743  /// Implementation of the checkAndRecordHandleInvalidation. This must remain a
744  /// const method so it doesn't inadvertently mutate `invalidatedHandles` too
745  /// early.
746  LogicalResult checkAndRecordHandleInvalidationImpl(
747  transform::TransformOpInterface transform,
748  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const;
749 
750  /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
751  void compactOpHandles();
752 
753  /// A stack of mappings between transform IR values and payload IR ops,
754  /// aggregated by the region in which the transform IR values are defined.
755  /// We use a pointer to the Mappings struct so that reallocations inside
756  /// MapVector don't invalidate iterators when we apply nested transform ops
757  /// while also iterating over the mappings.
758  llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
759 
760  /// Op handles may be temporarily mapped to nullptr to avoid invalidating
761  /// payload op iterators. This set contains all op handles with nullptrs.
762  /// These handles are "compacted" (i.e., nullptrs removed) at the end of each
763  /// transform.
764  DenseSet<Value> opHandlesToCompact;
765 
766  /// Extensions attached to the TransformState, identified by the TypeID of
767  /// their type. Only one extension of any given type is allowed.
768  DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
769 
770  /// The top-level operation that contains all payload IR, typically a module.
771  Operation *topLevel;
772 
773  /// Extra mapped values (payload operations, values or parameters) to be
774  /// associated with additional entry block arguments of the top-level
775  /// transform operation.
776  RaggedArray<MappedValue> topLevelMappedValues;
777 
778  /// Additional options controlling the transformation state behavior.
779  TransformOptions options;
780 
781  /// The mapping from invalidated handles to the error-reporting functions that
782  /// describe when the handles were invalidated. Calling such a function emits
783  /// a user-visible diagnostic with an additional note pointing to the given
784  /// location.
785  InvalidatedHandleMap invalidatedHandles;
786 
787  /// A stack of nested regions that are being processed in the transform IR.
788  /// Each region must be an ancestor of the following regions in this list.
789  /// These are also the keys for "mappings".
790  SmallVector<RegionScope *> regionStack;
791 
792  /// The top-level region scope. The first (bottom) element of `regionStack`
793  /// is the top-level region scope object.
794  std::unique_ptr<RegionScope> topLevelRegionScope;
795 };
796 
797 /// Local mapping between values defined by a specific op implementing the
798 /// TransformOpInterface and the payload IR ops they correspond to.
800  friend class TransformState;
801 
802 public:
803  /// Indicates that the result of the transform IR op at the given position
804  /// corresponds to the given list of payload IR ops. Each result must be set
805  /// by the transformation exactly once in case of transformation succeeding.
806  /// The value must have a type implementing TransformHandleTypeInterface.
807  template <typename Range>
808  void set(OpResult value, Range &&ops) {
809  int64_t position = value.getResultNumber();
810  assert(position < static_cast<int64_t>(operations.size()) &&
811  "setting results for a non-existent handle");
812  assert(operations[position].data() == nullptr && "results already set");
813  assert(params[position].data() == nullptr &&
814  "another kind of results already set");
815  assert(values[position].data() == nullptr &&
816  "another kind of results already set");
817  operations.replace(position, std::forward<Range>(ops));
818  }
819 
820  /// Indicates that the result of the transform IR op at the given position
821  /// corresponds to the given list of payload IR ops. Each result must be set
822  /// by the transformation exactly once in case of transformation succeeding.
823  /// The value must have a type implementing TransformHandleTypeInterface.
824  void set(OpResult value, std::initializer_list<Operation *> ops) {
825  set(value, ArrayRef<Operation *>(ops));
826  }
827 
828  /// Indicates that the result of the transform IR op at the given position
829  /// corresponds to the given list of parameters. Each result must be set by
830  /// the transformation exactly once in case of transformation succeeding. The
831  /// value must have a type implementing TransformParamTypeInterface.
833 
834  /// Indicates that the result of the transform IR op at the given position
835  /// corresponds to the given range of payload IR values. Each result must be
836  /// set by the transformation exactly once in case of transformation
837  /// succeeding. The value must have a type implementing
838  /// TransformValueHandleTypeInterface.
839  template <typename Range>
840  void setValues(OpResult handle, Range &&values) {
841  int64_t position = handle.getResultNumber();
842  assert(position < static_cast<int64_t>(this->values.size()) &&
843  "setting values for a non-existent handle");
844  assert(this->values[position].data() == nullptr && "values already set");
845  assert(operations[position].data() == nullptr &&
846  "another kind of results already set");
847  assert(params[position].data() == nullptr &&
848  "another kind of results already set");
849  this->values.replace(position, std::forward<Range>(values));
850  }
851 
852  /// Indicates that the result of the transform IR op at the given position
853  /// corresponds to the given range of payload IR values. Each result must be
854  /// set by the transformation exactly once in case of transformation
855  /// succeeding. The value must have a type implementing
856  /// TransformValueHandleTypeInterface.
857  void setValues(OpResult handle, std::initializer_list<Value> values) {
858  setValues(handle, ArrayRef<Value>(values));
859  }
860 
861  /// Indicates that the result of the transform IR op at the given position
862  /// corresponds to the given range of mapped values. All mapped values are
863  /// expected to be compatible with the type of the result, e.g., if the result
864  /// is an operation handle, all mapped values are expected to be payload
865  /// operations.
866  void setMappedValues(OpResult handle, ArrayRef<MappedValue> values);
867 
868  /// Sets the currently unset results to empty lists of the kind expected by
869  /// the corresponding results of the given `transform` op.
870  void setRemainingToEmpty(TransformOpInterface transform);
871 
872 private:
873  /// Creates an instance of TransformResults that expects mappings for
874  /// `numSegments` values, which may be associated with payload operations or
875  /// parameters.
876  explicit TransformResults(unsigned numSegments);
877 
878  /// Gets the list of operations associated with the result identified by its
879  /// number in the list of operation results. The result must have been set to
880  /// be associated with payload IR operations.
881  ArrayRef<Operation *> get(unsigned resultNumber) const;
882 
883  /// Gets the list of parameters associated with the result identified by its
884  /// number in the list of operation results. The result must have been set to
885  /// be associated with parameters.
886  ArrayRef<TransformState::Param> getParams(unsigned resultNumber) const;
887 
888  /// Gets the list of payload IR values associated with the result identified
889  /// by its number in the list of operation results. The result must have been
890  /// set to be associated with payload IR values.
891  ArrayRef<Value> getValues(unsigned resultNumber) const;
892 
893  /// Returns `true` if the result identified by its number in the list of
894  /// operation results is associated with a list of parameters, `false`
895  /// otherwise.
896  bool isParam(unsigned resultNumber) const;
897 
898  /// Returns `true` if the result identified by its number in the list of
899  /// operation results is associated with a list of payload IR values, `false`
900  /// otherwise.
901  bool isValue(unsigned resultNumber) const;
902 
903  /// Returns `true` if the result identified by its number in the list of
904  /// operation results is associated with something.
905  bool isSet(unsigned resultNumber) const;
906 
907  /// Pointers to payload IR ops that are associated with results of a transform
908  /// IR op.
909  RaggedArray<Operation *> operations;
910 
911  /// Parameters that are associated with results of the transform IR op.
912  RaggedArray<Param> params;
913 
914  /// Payload IR values that are associated with results of a transform IR op.
915  RaggedArray<Value> values;
916 };
917 
918 /// Creates a RAII object the lifetime of which corresponds to the new mapping
919 /// for transform IR values defined in the given region. Values defined in
920 /// surrounding regions remain accessible.
922  return RegionScope(*this, region);
923 }
924 
925 /// A configuration object for customizing a `TrackingListener`.
927  using SkipHandleFn = std::function<bool(Value)>;
928 
929  /// An optional function that returns "true" for handles that do not have to
930  /// be updated. These are typically dead or consumed handles.
932 
933  /// If set to "true", the name of a replacement op must match the name of the
934  /// original op. If set to "false", the names of the payload ops tracked in a
935  /// handle may change as the tracking listener updates the transform state.
937 
938  /// If set to "true", cast ops (that implement the CastOpInterface) are
939  /// skipped and the replacement op search continues with the operands of the
940  /// cast op.
941  bool skipCastOps = true;
942 };
943 
944 /// A listener that updates a TransformState based on IR modifications. This
945 /// listener can be used during a greedy pattern rewrite to keep the transform
946 /// state up-to-date.
949 public:
950  /// Create a new TrackingListener for usage in the specified transform op.
951  /// Optionally, a function can be specified to identify handles that do not
952  /// have to be updated.
953  TrackingListener(TransformState &state, TransformOpInterface op,
955 
956 protected:
957  /// Return a replacement payload op for the given op, which is going to be
958  /// replaced with the given values. By default, if all values are defined by
959  /// the same op, which also has the same type as the given op, that defining
960  /// op is used as a replacement.
961  ///
962  /// A "failure" return value indicates that no replacement operation could be
963  /// found. A "nullptr" return value indicates that no replacement op is needed
964  /// (e.g., handle is dead or was consumed) and that the payload op should
965  /// be dropped from the mapping.
966  ///
967  /// Example: A tracked "linalg.generic" with two results is replaced with two
968  /// values defined by (another) "linalg.generic". It is reasonable to assume
969  /// that the replacement "linalg.generic" represents the same "computation".
970  /// Therefore, the payload op mapping is updated to the defining op of the
971  /// replacement values.
972  ///
973  /// Counter Example: A "linalg.generic" is replaced with values defined by an
974  /// "scf.for". Without further investigation, the relationship between the
975  /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
976  /// same computation; e.g., there may be tiled "linalg.generic" inside the
977  /// loop body that represents the original computation. Therefore, the
978  /// TrackingListener is conservative by default: it drops the mapping and
979  /// triggers the "payload replacement not found" notification. This default
980  /// behavior can be customized in `TrackingListenerConfig`.
981  ///
982  /// If no replacement op could be found according to the rules mentioned
983  /// above, this function tries to skip over cast-like ops that implement
984  /// `CastOpInterface`.
985  ///
986  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
987  /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
988  /// reasonable to assume that the wrapped "linalg.generic" represents the same
989  /// computation as the original "linalg.generic". The mapping is updated
990  /// accordingly.
991  ///
992  /// Certain ops (typically also metadata-only ops) are not considered casts,
993  /// but should be skipped nonetheless. Such ops should implement
994  /// `FindPayloadReplacementOpInterface` to specify with which operands the
995  /// lookup should continue.
996  ///
997  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
998  /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
999  /// not a cast. (Implementing `CastOpInterface` would be incorrect and cause
1000  /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
1001  /// implementation, the replacement op lookup continues with the wrapped
1002  /// "linalg.generic" and the mapping is updated accordingly.
1003  ///
1004  /// Derived classes may override `findReplacementOp` to specify custom
1005  /// replacement rules.
1007  findReplacementOp(Operation *&result, Operation *op,
1008  ValueRange newValues) const;
1009 
1010  /// Notify the listener that the pattern failed to match the given operation,
1011  /// and provide a callback to populate a diagnostic with the reason why the
1012  /// failure occurred.
1013  void
1015  function_ref<void(Diagnostic &)> reasonCallback) override;
1016 
1017  /// This function is called when a tracked payload op is dropped because no
1018  /// replacement op was found. Derived classes can implement this function for
1019  /// custom error handling.
1020  virtual void
1023 
1024  /// Return the single op that defines all given values (if any).
1025  static Operation *getCommonDefiningOp(ValueRange values);
1026 
1027  /// Return the transform op in which this TrackingListener is used.
1028  TransformOpInterface getTransformOp() const { return transformOp; }
1029 
1030 private:
1031  friend class TransformRewriter;
1032 
1033  void notifyOperationErased(Operation *op) override;
1034 
1035  void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
1036  using Listener::notifyOperationReplaced;
1037 
1038  /// The transform op in which this TrackingListener is used.
1039  TransformOpInterface transformOp;
1040 
1041  /// The handles that are consumed by the transform op.
1042  DenseSet<Value> consumedHandles;
1043 
1044  /// Tracking listener configuration.
1045  TrackingListenerConfig config;
1046 };
1047 
1048 /// A specialized listener that keeps track of cases in which no replacement
1049 /// payload could be found. The error state of this listener must be checked
1050 /// before the end of its lifetime.
1052 public:
1054 
1055  ~ErrorCheckingTrackingListener() override;
1056 
1057  /// Check and return the current error state of this listener. Afterwards,
1058  /// resets the error state to "success".
1060 
1061  /// Return "true" if this tracking listener had a failure.
1062  bool failed() const;
1063 
1064 protected:
1065  void
1067  DiagnosedSilenceableFailure &&diag) override;
1068 
1069 private:
1070  /// The error state of this listener. "Success" indicates that no error
1071  /// happened so far.
1073 
1074  /// The number of errors that have been encountered.
1075  int64_t errorCounter = 0;
1076 };
1077 
1078 /// This is a special rewriter to be used in transform op implementations,
1079 /// providing additional helper functions to update the transform state, etc.
1080 // TODO: Helper functions will be added in a subsequent change.
1082 protected:
1083  friend class TransformState;
1084 
1085  /// Create a new TransformRewriter.
1086  explicit TransformRewriter(MLIRContext *ctx,
1087  ErrorCheckingTrackingListener *listener);
1088 
1089 public:
1090  /// Return "true" if the tracking listener had failures.
1091  bool hasTrackingFailures() const;
1092 
1093  /// Silence all tracking failures that have been encountered so far.
1094  void silenceTrackingFailure();
1095 
1096  /// Notify the transform dialect interpreter that the given op has been
1097  /// replaced with another op and that the mapping between handles and payload
1098  /// ops/values should be updated. This function should be called before the
1099  /// original op is erased. It fails if the operation could not be replaced,
1100  /// e.g., because the original operation is not tracked.
1101  ///
1102  /// Note: As long as IR modifications are performed through this rewriter,
1103  /// the transform state is usually updated automatically. This function should
1104  /// be used when unsupported rewriter API is used; e.g., updating all uses of
1105  /// a tracked operation one-by-one instead of using `RewriterBase::replaceOp`.
1107  Operation *replacement);
1108 
1109 private:
1110  ErrorCheckingTrackingListener *const listener;
1111 };
1112 
1113 /// This trait is supposed to be attached to Transform dialect operations that
1114 /// can be standalone top-level transforms. Such operations typically contain
1115 /// other Transform dialect operations that can be executed following some
1116 /// control flow logic specific to the current operation. The operations with
1117 /// this trait are expected to have at least one single-block region with at
1118 /// least one argument of type implementing TransformHandleTypeInterface. The
1119 /// operations are also expected to be valid without operands, in which case
1120 /// they are considered top-level, and with one or more arguments, in which case
1121 /// they are considered nested. Top-level operations have the block argument of
1122 /// the entry block in the Transform IR correspond to the root operation of
1123 /// Payload IR. Nested operations have the block argument of the entry block in
1124 /// the Transform IR correspond to a list of Payload IR operations mapped to the
1125 /// first operand of the Transform IR operation. The operation must implement
1126 /// TransformOpInterface.
1127 template <typename OpTy>
1129  : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
1130 public:
1131  /// Verifies that `op` satisfies the invariants of this trait. Not expected to
1132  /// be called directly.
1135  }
1136 
1137  /// Returns the single block of the given region.
1138  Block *getBodyBlock(unsigned region = 0) {
1139  return &this->getOperation()->getRegion(region).front();
1140  }
1141 
1142  /// Populates `effects` with side effects implied by this trait.
1146  this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(),
1147  *getBodyBlock(), effects);
1148  }
1149 
1150  /// Sets up the mapping between the entry block of the given region of this op
1151  /// and the relevant list of Payload IR operations in the given state. The
1152  /// state is expected to be already scoped at the region of this operation.
1154  assert(region.getParentOp() == this->getOperation() &&
1155  "op comes from the wrong region");
1157  state, this->getOperation(), region);
1158  }
1160  assert(
1161  this->getOperation()->getNumRegions() == 1 &&
1162  "must indicate the region to map if the operation has more than one");
1163  return mapBlockArguments(state, this->getOperation()->getRegion(0));
1164  }
1165 };
1166 
1167 class ApplyToEachResultList;
1168 
1169 /// Trait implementing the TransformOpInterface for operations applying a
1170 /// transformation to a single operation handle and producing an arbitrary
1171 /// number of handles and parameter values.
1172 /// The op must implement a method with the following signature:
1173 /// - DiagnosedSilenceableFailure applyToOne(OpTy,
1174 /// ApplyToEachResultList &results, TransformState &state)
1175 /// to perform a transformation that is applied in turn to all payload IR
1176 /// operations that correspond to the handle of the transform IR operation.
1177 /// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
1178 /// that the transformation is applied to (and NOT the class of the transform IR
1179 /// op).
1180 /// The `applyToOne` method takes an empty `results` vector that it fills with
1181 /// zero, one or multiple operations depending on the number of results expected
1182 /// by the transform op.
1183 /// The number of results must match the number of results of the transform op.
1184 /// `applyToOne` is allowed to fill the `results` with all null elements to
1185 /// signify that the transformation did not apply to the payload IR operations.
1186 /// Such null elements are filtered out from results before return.
1187 ///
1188 /// The transform op having this trait is expected to have a single operand.
1189 template <typename OpTy>
1191  : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
1192 public:
1193  /// Calls `applyToOne` for every payload operation associated with the operand
1194  /// of this transform IR op. The following case disjunction happens:
1195  /// 1. If no target payload ops are associated with the operand, then fill the
1196  /// results vector with the expected number of null elements and return
1197  /// success. This is the corner case handling that allows propagating
1198  /// the "no-op" case gracefully to improve usability.
1199  /// 2. If any `applyToOne` returns definiteFailure, the transformation is
1200  /// immediately considered definitely failed and we return.
1201  /// 3. All applications of `applyToOne` are checked to return a number of
1202  /// results expected by the transform IR op. If not, this is a definite
1203  /// failure and we return early.
1204  /// 4. If `applyToOne` produces ops, associate them with the result of this
1205  /// transform op.
1206  /// 5. If any `applyToOne` returns silenceableFailure, the transformation is
1207  /// considered silenceable.
1208  /// 6. Otherwise the transformation is considered successful.
1210  TransformResults &transformResults,
1211  TransformState &state);
1212 
1213  /// Checks that the op matches the expectations of this trait.
1214  static LogicalResult verifyTrait(Operation *op);
1215 };
1216 
1217 /// Side effect resource corresponding to the mapping between Transform IR
1218 /// values and Payload IR operations. An Allocate effect from this resource
1219 /// means creating a new mapping entry, it is always accompanied by a Write
1220 /// effect. A Read effect from this resource means accessing the mapping. A Free
1221 /// effect on this resource indicates the removal of the mapping entry,
1222 /// typically after a transformation that modifies the Payload IR operations
1223 /// associated with one of the Transform IR operation's operands. It is always
1224 /// accompanied by a Read effect. Read-after-Free and double-Free are not
1225 /// allowed (they would be problematic with "regular" memory effects too) as
1226 /// they indicate an attempt to access Payload IR operations that have been
1227 /// modified, potentially erased, by the previous transformations.
1228 // TODO: consider custom effects if these are not enabling generic passes such
1229 // as CSE/DCE to work.
1231  : public SideEffects::Resource::Base<TransformMappingResource> {
1232  StringRef getName() override { return "transform.mapping"; }
1233 };
1234 
1235 /// Side effect resource corresponding to the Payload IR itself. Only Read and
1236 /// Write effects are expected on this resource, with Write always accompanied
1237 /// by a Read (short of fully replacing the top-level Payload IR operation, one
1238 /// cannot modify the Payload IR without reading it first). This is intended
1239 /// to disallow reordering of Transform IR operations that mutate the Payload IR
1240 /// while still allowing the reordering of those that only access it.
1242  : public SideEffects::Resource::Base<PayloadIRResource> {
1243  StringRef getName() override { return "transform.payload_ir"; }
1244 };
1245 
1246 /// Populates `effects` with the memory effects indicating the operation on the
1247 /// given handle value:
1248 /// - consumes = Read + Free,
1249 /// - produces = Allocate + Write,
1250 /// - onlyReads = Read.
1251 void consumesHandle(ValueRange handles,
1253 void producesHandle(ValueRange handles,
1255 void onlyReadsHandle(ValueRange handles,
1257 
1258 /// Checks whether the transform op consumes the given handle.
1259 bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
1260 
1261 /// Populates `effects` with the memory effects indicating the access to payload
1262 /// IR resource.
1265 
1266 /// Checks whether the transform op modifies the payload.
1267 bool doesModifyPayload(transform::TransformOpInterface transform);
1268 /// Checks whether the transform op reads the payload.
1269 bool doesReadPayload(transform::TransformOpInterface transform);
1270 
1271 /// Populates `consumedArguments` with positions of `block` arguments that are
1272 /// consumed by the operations in the `block`.
1274  Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1275 
1276 /// Trait implementing the MemoryEffectOpInterface for operations that "consume"
1277 /// their operands and produce new results.
1278 template <typename OpTy>
1280  : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
1281 public:
1282  /// This op "consumes" the operands by reading and freeing them, "produces"
1283  /// the results by allocating and writing them and reads/writes the payload IR
1284  /// in the process.
1286  consumesHandle(this->getOperation()->getOperands(), effects);
1287  producesHandle(this->getOperation()->getResults(), effects);
1288  modifiesPayload(effects);
1289  }
1290 
1291  /// Checks that the op matches the expectations of this trait.
1293  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1294  op->emitError()
1295  << "FunctionalStyleTransformOpTrait should only be attached to ops "
1296  "that implement MemoryEffectOpInterface";
1297  }
1298  return success();
1299  }
1300 };
1301 
1302 /// Trait implementing the MemoryEffectOpInterface for operations that use their
1303 /// operands without consuming and without modifying the Payload IR to
1304 /// potentially produce new handles.
1305 template <typename OpTy>
1307  : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
1308 public:
1309  /// This op produces handles to the Payload IR without consuming the original
1310  /// handles and without modifying the IR itself.
1312  onlyReadsHandle(this->getOperation()->getOperands(), effects);
1313  producesHandle(this->getOperation()->getResults(), effects);
1314  if (llvm::any_of(this->getOperation()->getOperandTypes(), [](Type t) {
1315  return isa<TransformHandleTypeInterface,
1316  TransformValueHandleTypeInterface>(t);
1317  })) {
1318  onlyReadsPayload(effects);
1319  }
1320  }
1321 
1322  /// Checks that the op matches the expectation of this trait.
1324  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1325  op->emitError() << "NavigationTransformOpTrait should only be attached "
1326  "to ops that implement MemoryEffectOpInterface";
1327  }
1328  return success();
1329  }
1330 };
1331 
1332 namespace detail {
1333 /// Non-template implementation of ParamProducerTransformOpTrait::getEffects().
1336 /// Non-template implementation of ParamProducerTransformOpTrait::verify().
1338 } // namespace detail
1339 
1340 /// Trait implementing the MemoryEffectsOpInterface for operations that produce
1341 /// transform dialect parameters. It marks all op results of
1342 /// TransformHandleTypeInterface as produced by the op, all operands as only
1343 /// read by the op and, if at least one of the operand is a handle to payload
1344 /// ops, the entire payload as potentially read. The op must only produce
1345 /// parameter-typed results.
1346 template <typename OpTy>
1348  : public OpTrait::TraitBase<OpTy, ParamProducerTransformOpTrait> {
1349 public:
1350  /// Populates `effects` with effect instances described in the trait
1351  /// documentation.
1354  effects);
1355  }
1356 
1357  /// Checks that the op matches the expectation of this trait, i.e., that it
1358  /// implements the MemoryEffectsOpInterface and only produces parameter-typed
1359  /// results.
1362  }
1363 };
1364 
1365 /// `TrackingListener` failures are reported only for ops that have this trait.
1366 /// The purpose of this trait is to give users more time to update their custom
1367 /// transform ops to use the provided `TransformRewriter` for all IR
1368 /// modifications. This trait will eventually be removed, and failures will be
1369 /// reported for all transform ops.
1370 template <typename OpTy>
1372  : public OpTrait::TraitBase<OpTy, ReportTrackingListenerFailuresOpTrait> {};
1373 
1374 /// A single result of applying a transform op with `TransformEachOpTrait` to a
1375 /// single payload operation.
1377 
1378 /// A list of results of applying a transform op with `TransformEachOpTrait` to a
1379 /// single payload operation, co-indexed with the results of the transform op.
1381 public:
1383  explicit ApplyToEachResultList(unsigned size) : results(size) {}
1384 
1385  /// Sets the list of results to `size` null pointers.
1386  void assign(unsigned size, std::nullptr_t) { results.assign(size, nullptr); }
1387 
1388  /// Sets the list of results to the given range of values.
1389  template <typename Range>
1390  void assign(Range &&range) {
1391  // This is roughly the implementation of SmallVectorImpl::assign.
1392  // Dispatching to it with map_range and template type inference would result
1393  // in more complex code here.
1394  results.clear();
1395  results.reserve(llvm::size(range));
1396  for (auto element : range) {
1397  if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1398  Operation *>) {
1399  results.push_back(static_cast<Operation *>(element));
1400  } else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1401  Value>) {
1402  results.push_back(element.template get<Value>());
1403  } else {
1404  results.push_back(static_cast<Attribute>(element));
1405  }
1406  }
1407  }
1408 
1409  /// Appends an element to the list.
1410  // Using ApplyToEachResult that can be implicitly constructed from a Value but
1411  // not from a concrete Op that is implicitly convertible to a Value to avoid
1412  // ambiguity.
1413  void push_back(Operation *op) { results.push_back(op); }
1414  void push_back(Attribute attr) { results.push_back(attr); }
1415  void push_back(ApplyToEachResult r) { results.push_back(r); }
1416 
1417  /// Reserves space for `size` elements in the list.
1418  void reserve(unsigned size) { results.reserve(size); }
1419 
1420  /// Iterators over the list.
1421  auto begin() { return results.begin(); }
1422  auto end() { return results.end(); }
1423  auto begin() const { return results.begin(); }
1424  auto end() const { return results.end(); }
1425 
1426  /// Returns the number of elements in the list.
1427  size_t size() const { return results.size(); }
1428 
1429  /// Element access. Expects the index to be in bounds.
1430  ApplyToEachResult &operator[](size_t index) { return results[index]; }
1431  const ApplyToEachResult &operator[](size_t index) const {
1432  return results[index];
1433  }
1434 
1435 private:
1436  /// Underlying storage.
1438 };
1439 
1440 namespace detail {
1441 
1442 /// Check that the contents of `partialResult` matches the number, kind (payload
1443 /// op or parameter) and nullity (either all or none) requirements of
1444 /// `transformOp`. Report errors and return failure otherwise.
1445 LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc,
1446  const ApplyToEachResultList &partialResult);
1447 
1448 /// "Transpose" the results produced by individual applications, arranging them
1449 /// per result value of the transform op, and populate `transformResults` with
1450 /// that. The number, kind and nullity of per-application results are assumed to
1451 /// have been verified.
1452 void setApplyToOneResults(Operation *transformOp,
1453  TransformResults &transformResults,
1455 
1456 /// Applies a one-to-one or a one-to-many transform to each of the given
1457 /// targets. Puts the results of transforms, if any, in `results` in the same
1458 /// order. Fails if any of the applications fails. The transform op must provide
1459 /// an `applyToOne` member callable with the following signature:
1460 /// - DiagnosedSilenceableFailure(TransformRewriter &rewriter, OpTy target,
1461 /// ApplyToEachResultList &results, TransformState &state)
1462 /// where OpTy is either
1463 /// - Operation *, in which case the transform is always applied;
1464 /// - a concrete Op class, in which case a check is performed whether
1465 /// `targets` contains operations of the same class and a silenceable failure
1466 /// is reported if it does not.
// NOTE(review): this listing omits several lines of the function (the embedded
// line numbers skip 1468, 1470, 1493, 1496, 1505, 1510 and 1513), including the
// return type/name line and the statements that follow some of the conditions
// below; consult the original header before relying on the exact control flow.
1467 template <typename TransformOpTy, typename Range>
1469  TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets,
  // `OpTy` is the type of the argument at index 1 of
  // `TransformOpTy::applyToOne`, i.e., the type of `specificOp` passed to it
  // below.
1471  using OpTy = typename llvm::function_traits<
1472  decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1473  static_assert(std::is_convertible<OpTy, Operation *>::value,
1474  "expected transform function to take an operation");
  // The guard keeps the insertion point changes made in the loop local to this
  // function.
1475  OpBuilder::InsertionGuard g(rewriter);
1476 
  // Diagnostics are accumulated here so that the remaining targets are still
  // processed after a target is skipped or fails silenceably.
1477  SmallVector<Diagnostic> silenceableStack;
1478  unsigned expectedNumResults = transformOp->getNumResults();
1479  for (Operation *target : targets) {
    // A target that is not of kind `OpTy` is not transformed: a diagnostic
    // pointing at it is recorded and the loop moves on to the next target.
1480  auto specificOp = dyn_cast<OpTy>(target);
1481  if (!specificOp) {
1482  Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
1483  diag << "transform applied to the wrong op kind";
1484  diag.attachNote(target->getLoc()) << "when applied to this op";
1485  silenceableStack.push_back(std::move(diag));
1486  continue;
1487  }
1488 
1489  ApplyToEachResultList partialResults;
1490  partialResults.reserve(expectedNumResults);
    // The location is read before `applyToOne` runs and is used for the result
    // check afterwards.
1491  Location specificOpLoc = specificOp->getLoc();
    // Position the rewriter at the payload op before invoking the transform.
1492  rewriter.setInsertionPoint(specificOp);
1494  transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1495  if (res.isDefiniteFailure())
1497 
    // On a silenceable failure, take over its diagnostics and do not record
    // any results for this target.
1498  if (res.isSilenceableFailure()) {
1499  res.takeDiagnostics(silenceableStack);
1500  continue;
1501  }
1502 
    // Validate the results produced for this target before appending them to
    // `results`.
1503  if (failed(detail::checkApplyToOne(transformOp, specificOpLoc,
1504  partialResults))) {
1506  }
1507  results.push_back(std::move(partialResults));
1508  }
1509  if (!silenceableStack.empty()) {
1511  std::move(silenceableStack));
1512  }
1514 }
1515 
1516 /// Reports an error and returns failure if `targets` contains an ancestor
1517 /// operation before its descendant (or a copy of itself). Implementation detail
1518 /// for expensive checks during `TransformEachOpTrait::apply`.
1520  ArrayRef<Operation *> targets);
1521 
1522 } // namespace detail
1523 } // namespace transform
1524 } // namespace mlir
1525 
/// Applies the transform to the payload ops associated with the first operand
/// of this op: looks up the payload ops, optionally checks them when expensive
/// checks are enabled and the handle is consumed, fills every result with an
/// empty list if there are no targets, and otherwise forwards the targets to
/// the per-target application and "transposes" its results into
/// `transformResults`.
// NOTE(review): this listing omits several lines of the function (the embedded
// line numbers skip 1527-1528, 1559 and 1564-1565), including the return
// type/name line, the statement ending the empty-targets case, and the
// declarations of `results` and `result`; consult the original header.
1526 template <typename OpTy>
1529  TransformRewriter &rewriter, TransformResults &transformResults,
1530  TransformState &state) {
1531  Value handle = this->getOperation()->getOperand(0);
1532  auto targets = state.getPayloadOps(handle);
1533 
1534  // If the operand is consumed, check if it is associated with operations that
1535  // may be erased before their nested operations are.
1536  if (state.getOptions().getExpensiveChecksEnabled() &&
1537  isHandleConsumed(handle, cast<transform::TransformOpInterface>(
1538  this->getOperation())) &&
1539  failed(detail::checkNestedConsumption(this->getOperation()->getLoc(),
1540  llvm::to_vector(targets)))) {
1541  return DiagnosedSilenceableFailure::definiteFailure();
1542  }
1543 
1544  // Step 1. Handle the corner case where no target is specified.
1545  // This is typically the case when the matcher fails to apply and we need to
1546  // propagate gracefully.
1547  // In this case, we fill all results with an empty vector.
1548  if (std::empty(targets)) {
1549  SmallVector<Operation *> emptyPayload;
1550  SmallVector<Attribute> emptyParams;
    // The kind of empty list depends on the result type: parameters, payload
    // values, or (for every other result type) payload operations.
1551  for (OpResult r : this->getOperation()->getResults()) {
1552  if (isa<TransformParamTypeInterface>(r.getType()))
1553  transformResults.setParams(r, emptyParams);
1554  else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1555  transformResults.setValues(r, ValueRange());
1556  else
1557  transformResults.set(r, emptyPayload);
1558  }
1560  }
1561 
1562  // Step 2. Call applyToOne on each target and record newly produced ops in its
1563  // corresponding results entry.
1566  cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1567 
1568  // Step 3. Propagate the definite failure if any and bail out.
1569  if (result.isDefiniteFailure())
1570  return result;
1571 
1572  // Step 4. "Transpose" the results produced by individual applications,
1573  // arranging them per result value of the transform op. The number, kind and
1574  // nullity of per-application results have been verified by the callback
1575  // above.
1576  detail::setApplyToOneResults(this->getOperation(), transformResults, results);
1577 
1578  // Step 5. ApplyToOne may have returned silenceableFailure, propagate it.
1579  return result;
1580 }
1581 
/// Verifies the requirements of the trait on `op`: at compile time, that `OpTy`
/// has the `OpTrait::OneOperand` trait; at run time, that the operation
/// implements TransformOpInterface. Emits an error on `op` and returns its
/// result if the interface is missing, and returns success otherwise.
// NOTE(review): the lines holding the return type and the function name
// (embedded lines 1583-1584) are missing from this listing.
1582 template <typename OpTy>
  // Compile-time requirement: the op class must be declared with exactly one
  // operand.
1585  static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1586  "expected single-operand op");
  // Run-time requirement: the interface is looked up through the registered
  // operation name.
1587  if (!op->getName().getInterface<TransformOpInterface>()) {
1588  return op->emitError() << "TransformEachOpTrait should only be attached to "
1589  "ops that implement TransformOpInterface";
1590  }
1591 
1592  return success();
1593 }
1594 
1595 #endif // MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:30
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Helper class for implementing traits.
Definition: OpDefinition.h:373
Operation * getOperation()
Return the ultimate Operation being worked on.
Definition: OpDefinition.h:376
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This base class is used for derived effects that are non-parametric.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
auto begin()
Iterators over the list.
const ApplyToEachResult & operator[](size_t index) const
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
ApplyToEachResult & operator[](size_t index)
Element access. Expects the index to be in bounds.
void push_back(Operation *op)
Appends an element to the list.
void assign(Range &&range)
Sets the list of results to the given range of values.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
Trait implementing the MemoryEffectOpInterface for operations that "consume" their operands and produ...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op "consumes" the operands by reading and freeing them, "produces" the results by allocating and...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
Trait implementing the MemoryEffectOpInterface for operations that use their operands without consumi...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op produces handles to the Payload IR without consuming the original handles and without modifyi...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait.
Trait implementing the MemoryEffectsOpInterface for operations that produce transform dialect paramet...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait, i.e., that it implements the MemoryEffectsO...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with effect instances described in the trait documentation.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
LogicalResult mapBlockArguments(TransformState &state)
static LogicalResult verifyTrait(Operation *op)
Verifies that op satisfies the invariants of this trait.
void getPotentialTopLevelEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by this trait.
LogicalResult mapBlockArguments(TransformState &state, Region &region)
Sets up the mapping between the entry block of the given region of this op and the relevant list of P...
Block * getBodyBlock(unsigned region=0)
Returns the single block of the given region.
TrackingListener failures are reported only for ops that have this trait.
A listener that updates a TransformState based on IR modifications.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TransformOpInterface getTransformOp() const
Return the transform op in which this TrackingListener is used.
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag)
This function is called when a tracked payload op is dropped because no replacement op was found.
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Trait implementing the TransformOpInterface for operations applying a transformation to a single oper...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state)
Calls applyToOne for every payload operation associated with the operand of this transform IR op,...
Options controlling the application of transform operations by the TransformState.
TransformOptions & enableExpensiveChecks(bool enable=true)
Requests computationally expensive checks of the transform and payload IR well-formedness to be perfo...
TransformOptions & enableEnforceSingleToplevelTransformOp(bool enable=true)
TransformOptions & operator=(const TransformOptions &)=default
TransformOptions(const TransformOptions &)=default
bool getExpensiveChecksEnabled() const
Returns true if the expensive checks are requested.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, std::initializer_list< Operation * > ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setValues(OpResult handle, std::initializer_list< Value > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
Base class for TransformState extensions that allow TransformState to contain user-specified informat...
Extension(TransformState &state)
Constructs an extension of the given TransformState object.
const TransformState & getTransformState() const
Provides read-only access to the parent TransformState object.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
const TransformOptions & getOptions() const
friend LogicalResult applyTransforms(Operation *, TransformOpInterface, const RaggedArray< MappedValue > &, const TransformOptions &, bool)
Entry point to the Transform dialect infrastructure.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
size_t getNumTopLevelMappings() const
Returns the number of extra mappings for the top-level operation.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
Ty * getExtension()
Returns the extension of the specified type.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void removeExtension()
Removes the extension of the specified type.
ArrayRef< MappedValue > getTopLevelMapping(size_t position) const
Returns the position-th extra mapping for the top-level operation.
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
DiagnosedSilenceableFailure applyTransformToEach(TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, SmallVectorImpl< ApplyToEachResultList > &results, TransformState &state)
Applies a one-to-one or a one-to-many transform to each of the given targets.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true)
Entry point to the Transform dialect infrastructure.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Side effect resource corresponding to the Payload IR itself.
StringRef getName() override
Return a string name of the resource.
A configuration object for customizing a TrackingListener.
bool requireMatchingReplacementOpName
If set to "true", the name of a replacement op must match the name of the original op.
bool skipCastOps
If set to "true", cast ops (that implement the CastOpInterface) are skipped and the replacement op se...
SkipHandleFn skipHandleFn
An optional function that returns "true" for handles that do not have to be updated.
std::function< bool(Value)> SkipHandleFn
Side effect resource corresponding to the mapping between Transform IR values and Payload IR operatio...
StringRef getName() override
Return a string name of the resource.