MLIR  20.0.0git
TransformInterfaces.h
Go to the documentation of this file.
1 //===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
11 
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
18 
19 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
20 
21 namespace mlir {
22 namespace transform {
23 
24 class TransformOpInterface;
25 class TransformResults;
26 class TransformRewriter;
27 class TransformState;
28 
29 using Param = Attribute;
31 
32 namespace detail {
33 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
34 /// to either the list of operations associated with its operand or the root of
35 /// the payload IR, depending on what is available in the context.
36 LogicalResult
38  Operation *op, Region &region);
39 
40 /// Verification hook for PossibleTopLevelTransformOpTrait.
42 
43 /// Populates `effects` with side effects implied by
44 /// PossibleTopLevelTransformOpTrait for the given operation. The operation may
45 /// have an optional `root` operand, indicating it is not in fact top-level. It
46 /// is also expected to have a single-block body.
48  Operation *operation, Value root, Block &body,
50 
51 /// Verification hook for TransformOpInterface.
52 LogicalResult verifyTransformOpInterface(Operation *op);
53 
54 /// Appends the entities associated with the given transform values in `state`
55 /// to the pre-existing list of mappings. The array of mappings must have as
56 /// many elements as values. If `flatten` is set, multiple values may be
57 /// associated with each transform value, and this always succeeds. Otherwise,
58 /// checks that each value has exactly one mapping associated and return failure
59 /// otherwise.
60 LogicalResult appendValueMappings(
62  ValueRange values, const transform::TransformState &state,
63  bool flatten = true);
64 
65 /// Populates `mappings` with mapped values associated with the given transform
66 /// IR values in the given `state`.
69  ValueRange values, const transform::TransformState &state);
70 
71 /// Populates `results` with payload associations that match exactly those of
72 /// the operands to `block`'s terminator.
75 
76 /// Make a dummy transform state for testing purposes. This MUST NOT be used
77 /// outside of test cases.
79  Operation *payloadRoot);
80 
81 /// Returns all operands that are handles and being consumed by the given op.
83 getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
84 } // namespace detail
85 } // namespace transform
86 } // namespace mlir
87 
88 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
89 
90 namespace mlir {
91 namespace transform {
92 
93 /// Options controlling the application of transform operations by the
94 /// TransformState.
96 public:
97  TransformOptions() = default;
98  TransformOptions(const TransformOptions &) = default;
100 
101  /// Requests computationally expensive checks of the transform and payload IR
102  /// well-formedness to be performed before each transformation. In particular,
103  /// these ensure that the handles still point to valid operations when used.
105  expensiveChecksEnabled = enable;
106  return *this;
107  }
108 
109  // Ensures that only a single top-level transform op is present in the IR.
111  enforceSingleToplevelTransformOp = enable;
112  return *this;
113  }
114 
115  /// Returns true if the expensive checks are requested.
116  bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
117 
118  // Returns true if enforcing a single top-level transform op is requested.
120  return enforceSingleToplevelTransformOp;
121  }
122 
123 private:
124  bool expensiveChecksEnabled = true;
125  bool enforceSingleToplevelTransformOp = true;
126 };
127 
128 /// Entry point to the Transform dialect infrastructure. Applies the
129 /// transformation specified by `transform` to payload IR contained in
130 /// `payloadRoot`. The `transform` operation may contain other operations that
131 /// will be executed following the internal logic of the operation. It must
132 /// have the `PossibleTopLevelTransformOp` trait and not have any operands.
133 /// This function internally keeps track of the transformation state.
134 LogicalResult
135 applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
136  const RaggedArray<MappedValue> &extraMapping = {},
137  const TransformOptions &options = TransformOptions(),
138  bool enforceToplevelTransformOp = true);
139 
140 /// The state maintained across applications of various ops implementing the
141 /// TransformOpInterface. The operations implementing this interface and the
142 /// surrounding structure are referred to as transform IR. The operations to
143 /// which transformations apply are referred to as payload IR. Transform IR
144 /// operates on values that can be associated either with a list of payload IR
145 /// operations (such values are referred to as handles) or with a list of
146 /// parameters represented as attributes. The state thus contains the mapping
147 /// between values defined in the transform IR ops and either payload IR ops or
148 /// parameters. For payload ops, the mapping is many-to-many and the reverse
149 /// mapping is also stored. The "expensive-checks" option can be passed to the
150 /// constructor to check, at transformation execution time, that transform IR values used
151 /// as operands by a transform IR operation are not associated with dangling
152 /// pointers to payload IR operations that are known to have been erased by
153 /// previous transformation through the same or a different transform IR value.
154 ///
155 /// A reference to this class is passed as an argument to "apply" methods of the
156 /// transform op interface. Thus the "apply" method can call either
157 /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
158 /// or `state.getParams( getSomeOperand() )` to obtain the list of parameters
159 /// associated with its operand. The method is expected to populate the
160 /// `TransformResults` class instance in order to update the mapping. The
161 /// `applyTransform` method takes care of propagating the state of
162 /// `TransformResults` into the instance of this class.
163 ///
164 /// When applying transform IR operations with regions, the client is expected
165 /// to create a `RegionScope` RAII object to create a new "stack frame" for
166 /// values defined inside the region. The mappings from and to these values will
167 /// be automatically dropped when the object goes out of scope, typically at the
168 /// end of the `apply` function of the parent operation. If a region contains
169 /// blocks with arguments, the client can map those arguments to payload IR ops
170 /// using `mapBlockArguments`.
172 public:
174 
175 private:
176  /// Mapping between a Value in the transform IR and the corresponding set of
177  /// operations in the payload IR.
179 
180  /// Mapping between a payload IR operation and the transform IR values it is
181  /// associated with.
184 
185  /// Mapping between a Value in the transform IR and the corresponding list of
186  /// parameters.
188 
189  /// Mapping between a Value in the transform IR and the corresponding list of
190  /// values in the payload IR. Also works for reverse mappings.
192 
193  /// Mapping between a Value in the transform IR and an error message that
194  /// should be emitted when the value is used.
195  using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
196 
197 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
198  /// Debug only: A timestamp is associated with each transform IR value, so
199  /// that invalid iterator usage can be detected more reliably.
200  using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
201 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
202 
203  /// The bidirectional mappings between transform IR values and payload IR
204  /// operations, and the mapping between transform IR values and parameters.
205  struct Mappings {
206  TransformOpMapping direct;
208  ParamMapping params;
209  ValueMapping values;
210  ValueMapping reverseValues;
211 
212 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
213  TransformIRTimestampMapping timestamps;
214  void incrementTimestamp(Value value) { ++timestamps[value]; }
215 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
216  };
217 
218  friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
219  const RaggedArray<MappedValue> &,
220  const TransformOptions &, bool);
221 
222  friend TransformState
224 
225 public:
226  const TransformOptions &getOptions() const { return options; }
227 
228  /// Returns the op at which the transformation state is rooted. This is
229  /// typically helpful for transformations that apply globally.
230  Operation *getTopLevel() const;
231 
232  /// Returns the number of extra mappings for the top-level operation.
233  size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
234 
235  /// Returns the position-th extra mapping for the top-level operation.
236  ArrayRef<MappedValue> getTopLevelMapping(size_t position) const {
237  return topLevelMappedValues[position];
238  }
239 
240  /// Returns an iterator that enumerates all ops that the given transform IR
241  /// value corresponds to. Ops may be erased while iterating; erased ops are
242  /// not enumerated. This function is helpful for transformations that apply to
243  /// a particular handle.
244  auto getPayloadOps(Value value) const {
245  ArrayRef<Operation *> view = getPayloadOpsView(value);
246 
247 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
248  // Memorize the current timestamp and make sure that it has not changed
249  // when incrementing or dereferencing the iterator returned by this
250  // function. The timestamp is incremented when the "direct" mapping is
251  // resized; this would invalidate the iterator returned by this function.
252  int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
253 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
254 
255  // When ops are replaced/erased, they are replaced with nullptr (until
256  // the data structure is compacted). Do not enumerate these ops.
257  return llvm::make_filter_range(view, [=](Operation *op) {
258 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
259  [[maybe_unused]] bool sameTimestamp =
260  currentTimestamp == this->getMapping(value).timestamps.lookup(value);
261  assert(sameTimestamp && "iterator was invalidated during iteration");
262 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
263  return op != nullptr;
264  });
265  }
266 
267  /// Returns the list of parameters that the given transform IR value
268  /// corresponds to.
269  ArrayRef<Attribute> getParams(Value value) const;
270 
271  /// Returns an iterator that enumerates all payload IR values that the given
272  /// transform IR value corresponds to.
273  auto getPayloadValues(Value handleValue) const {
274  ArrayRef<Value> view = getPayloadValuesView(handleValue);
275 
276 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
277  // Memorize the current timestamp and make sure that it has not changed
278  // when incrementing or dereferencing the iterator returned by this
279  // function. The timestamp is incremented when the "values" mapping is
280  // resized; this would invalidate the iterator returned by this function.
281  int64_t currentTimestamp =
282  getMapping(handleValue).timestamps.lookup(handleValue);
283  return llvm::make_filter_range(view, [=](Value v) {
284  [[maybe_unused]] bool sameTimestamp =
285  currentTimestamp ==
286  this->getMapping(handleValue).timestamps.lookup(handleValue);
287  assert(sameTimestamp && "iterator was invalidated during iteration");
288  return true;
289  });
290 #else
291  return llvm::make_range(view.begin(), view.end());
292 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
293  }
294 
295  /// Populates `handles` with all handles pointing to the given Payload IR op.
296  /// Returns success if such handles exist, failure otherwise.
297  /// If `includeOutOfScope` is set to "true", handles that are defined in
298  /// regions beyond the most recent isolated from above region are included.
299  LogicalResult getHandlesForPayloadOp(Operation *op,
300  SmallVectorImpl<Value> &handles,
301  bool includeOutOfScope = false) const;
302 
303  /// Populates `handles` with all handles pointing to the given payload IR
304  /// value. Returns success if such handles exist, failure otherwise.
305  /// If `includeOutOfScope` is set to "true", handles that are defined in
306  /// regions beyond the most recent isolated from above region are included.
307  LogicalResult getHandlesForPayloadValue(Value payloadValue,
308  SmallVectorImpl<Value> &handles,
309  bool includeOutOfScope = false) const;
310 
311  /// Applies the transformation specified by the given transform op and updates
312  /// the state accordingly.
313  DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
314 
315  /// Records the mapping between a block argument in the transform IR and a
316  /// list of operations in the payload IR. The arguments must be defined in
317  /// blocks of the currently processed transform IR region, typically after a
318  /// region scope is defined.
319  ///
320  /// Returns failure if the payload does not satisfy the conditions associated
321  /// with the type of the handle value.
322  LogicalResult mapBlockArguments(BlockArgument argument,
323  ArrayRef<Operation *> operations) {
324  assert(argument.getParentRegion() == regionStack.back()->region &&
325  "mapping block arguments from a region other than the active one");
326  return setPayloadOps(argument, operations);
327  }
328  LogicalResult mapBlockArgument(BlockArgument argument,
329  ArrayRef<MappedValue> values);
330  LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
332 
333  // Forward declarations to support limited visibility.
334  class RegionScope;
335 
336  /// Creates a new region scope for the given region. The region is expected to
337  /// be nested in the currently processed region.
338  // Implementation note: this method is inline but implemented outside of the
339  // class body to comply with visibility and full-declaration requirements.
340  inline RegionScope make_region_scope(Region &region);
341 
342  /// A RAII object maintaining a "stack frame" for a transform IR region. When
343  /// applying a transform IR operation that contains a region, the caller is
344  /// expected to create a RegionScope before applying the ops contained in the
345  /// region. This ensures that the mappings between values defined in the
346  /// transform IR region and payload IR operations are cleared when the region
347  /// processing ends; such values cannot be accessed outside the region.
348  class RegionScope {
349  public:
350  /// Forgets the mapping from or to values defined in the associated
351  /// transform IR region, and restores the mapping that existed before
352  /// entering this scope.
353  ~RegionScope();
354 
355  private:
356  /// Creates a new scope for mappings between values defined in the given
357  /// transform IR region and payload IR objects.
358  RegionScope(TransformState &state, Region &region)
359  : state(state), region(&region) {
360  auto res = state.mappings.insert(
361  std::make_pair(&region, std::make_unique<Mappings>()));
362  assert(res.second && "the region scope is already present");
363  (void)res;
364  state.regionStack.push_back(this);
365  }
366 
367  /// Back-reference to the transform state.
368  TransformState &state;
369 
370  /// The region this scope is associated with.
371  Region *region;
372 
373  /// The transform op within this region that is currently being applied.
374  TransformOpInterface currentTransform;
375 
377  };
378  friend class RegionScope;
379 
380  /// Base class for TransformState extensions that allow TransformState to
381  /// contain user-specified information in the state object. Clients are
382  /// expected to derive this class, add the desired fields, and make the
383  /// derived class compatible with the MLIR TypeID mechanism:
384  ///
385  /// ```mlir
386  /// class MyExtension final : public TransformState::Extension {
387  /// public:
388  /// MyExtension(TranfsormState &state, int myData)
389  /// : Extension(state) {...}
390  /// private:
391  /// int mySupplementaryData;
392  /// };
393  /// ```
394  ///
395  /// Instances of this and derived classes are not expected to be created by
396  /// the user, instead they are directly constructed within a TransformState. A
397  /// TransformState can only contain one extension with the given TypeID.
398  /// Extensions can be obtained from a TransformState instance, and can be
399  /// removed when they are no longer required.
400  ///
401  /// ```mlir
402  /// transformState.addExtension<MyExtension>(/*myData=*/42);
403  /// MyExtension *ext = transformState.getExtension<MyExtension>();
404  /// ext->doSomething();
405  /// ```
406  class Extension {
407  // Allow TransformState to allocate Extensions.
408  friend class TransformState;
409 
410  public:
411  /// Base virtual destructor.
412  // Out-of-line definition ensures symbols are emitted in a single object
413  // file.
414  virtual ~Extension();
415 
416  protected:
417  /// Constructs an extension of the given TransformState object.
418  Extension(TransformState &state) : state(state) {}
419 
420  /// Provides read-only access to the parent TransformState object.
421  const TransformState &getTransformState() const { return state; }
422 
423  /// Replaces the given payload op with another op. If the replacement op is
424  /// null, removes the association of the payload op with its handle. Returns
425  /// failure if the op is not associated with any handle.
426  ///
427  /// Note: This function does not update value handles. None of the original
428  /// op's results are allowed to be mapped to any value handle.
429  LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
430 
431  /// Replaces the given payload value with another value. If the replacement
432  /// value is null, removes the association of the payload value with its
433  /// handle. Returns failure if the value is not associated with any handle.
434  LogicalResult replacePayloadValue(Value value, Value replacement);
435 
436  private:
437  /// Back-reference to the state that is being extended.
438  TransformState &state;
439  };
440 
441  /// Adds a new Extension of the type specified as template parameter,
442  /// constructing it with the arguments provided. The extension is owned by the
443  /// TransformState. It is expected that the state does not already have an
444  /// extension of the same type. Extension constructors are expected to take
445  /// a reference to TransformState as first argument, automatically supplied
446  /// by this call.
447  template <typename Ty, typename... Args>
448  Ty &addExtension(Args &&...args) {
449  static_assert(
450  std::is_base_of<Extension, Ty>::value,
451  "only an class derived from TransformState::Extension is allowed here");
452  auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
453  auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
454  assert(result.second && "extension already added");
455  return *static_cast<Ty *>(result.first->second.get());
456  }
457 
458  /// Returns the extension of the specified type.
459  template <typename Ty>
460  Ty *getExtension() {
461  static_assert(
462  std::is_base_of<Extension, Ty>::value,
463  "only an class derived from TransformState::Extension is allowed here");
464  auto iter = extensions.find(TypeID::get<Ty>());
465  if (iter == extensions.end())
466  return nullptr;
467  return static_cast<Ty *>(iter->second.get());
468  }
469 
470  /// Removes the extension of the specified type.
471  template <typename Ty>
473  static_assert(
474  std::is_base_of<Extension, Ty>::value,
475  "only an class derived from TransformState::Extension is allowed here");
476  extensions.erase(TypeID::get<Ty>());
477  }
478 
479 private:
480  /// Identifier for storing top-level value in the `operations` mapping.
481  static constexpr Value kTopLevelValue = Value();
482 
483  /// Creates a state for transform ops living in the given region. The second
484  /// argument points to the root operation in the payload IR being transformed,
485  /// which may or may not contain the region with transform ops. Additional
486  /// options can be provided through the trailing configuration object.
487  TransformState(Region *region, Operation *payloadRoot,
488  const RaggedArray<MappedValue> &extraMappings = {},
489  const TransformOptions &options = TransformOptions());
490 
491  /// Returns the mappings frame for the region in which the value is defined.
492  /// If `allowOutOfScope` is set to "false", asserts that the value is in
493  /// scope, based on the current stack of frames.
494  const Mappings &getMapping(Value value, bool allowOutOfScope = false) const {
495  return const_cast<TransformState *>(this)->getMapping(value,
496  allowOutOfScope);
497  }
498  Mappings &getMapping(Value value, bool allowOutOfScope = false) {
499  Region *region = value.getParentRegion();
500  auto it = mappings.find(region);
501  assert(it != mappings.end() &&
502  "trying to find a mapping for a value from an unmapped region");
503 #ifndef NDEBUG
504  if (!allowOutOfScope) {
505  for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
506  if (r == region)
507  break;
508  if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
509  llvm_unreachable("trying to get mapping beyond region that is "
510  "isolated from above");
511  }
512  }
513 #endif // NDEBUG
514  return *it->second;
515  }
516 
517  /// Returns the mappings frame for the region in which the operation resides.
518  /// If `allowOutOfScope` is set to "false", asserts that the operation is in
519  /// scope, based on the current stack of frames.
520  const Mappings &getMapping(Operation *operation,
521  bool allowOutOfScope = false) const {
522  return const_cast<TransformState *>(this)->getMapping(operation,
523  allowOutOfScope);
524  }
525  Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) {
526  Region *region = operation->getParentRegion();
527  auto it = mappings.find(region);
528  assert(it != mappings.end() &&
529  "trying to find a mapping for an operation from an unmapped region");
530 #ifndef NDEBUG
531  if (!allowOutOfScope) {
532  for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
533  if (r == region)
534  break;
535  if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
536  llvm_unreachable("trying to get mapping beyond region that is "
537  "isolated from above");
538  }
539  }
540 #endif // NDEBUG
541  return *it->second;
542  }
543 
544  /// Updates the state to include the associations between op results and the
545  /// provided result of applying a transform op.
546  LogicalResult updateStateFromResults(const TransformResults &results,
547  ResultRange opResults);
548 
549  /// Returns a list of all ops that the given transform IR value corresponds
550  /// to. In case an op was erased, the returned list contains nullptr. This
551  /// function is helpful for transformations that apply to a particular handle.
552  ArrayRef<Operation *> getPayloadOpsView(Value value) const;
553 
554  /// Returns a list of payload IR values that the given transform IR value
555  /// corresponds to.
556  ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
557 
558  /// Sets the payload IR ops associated with the given transform IR value
559  /// (handle). A payload op may be associated with multiple handles as long as
560  /// at most one of them gets consumed by further transformations.
561  /// For example, a hypothetical "find function by name" may be called twice in
562  /// a row to produce two handles pointing to the same function:
563  ///
564  /// %0 = transform.find_func_by_name { name = "myfunc" }
565  /// %1 = transform.find_func_by_name { name = "myfunc" }
566  ///
567  /// which is valid by itself. However, calling a hypothetical "rewrite and
568  /// rename function" transform on both handles:
569  ///
570  /// transform.rewrite_and_rename %0 { new_name = "func" }
571  /// transform.rewrite_and_rename %1 { new_name = "func" }
572  ///
573  /// is invalid given the transformation "consumes" the handle as expressed
574  /// by side effects. Practically, a transformation consuming a handle means
575  /// that the associated payload operation may no longer exist.
576  ///
577  /// Similarly, operation handles may be invalidated and should not be used
578  /// after a transform that consumed a value handle pointing to a payload value
579  /// defined by the operation as either block argument or op result. For
580  /// example, in the following sequence, the last transform operation rewrites
581  /// the callee to not return a specified result:
582  ///
583  /// %0 = transform.find_call "myfunc"
584  /// %1 = transform.find_results_of_calling "myfunc"
585  /// transform.drop_call_result_from_signature %1[0]
586  ///
587  /// which requires the call operations to be recreated. Therefore, the handle
588  /// %0 becomes associated with a dangling pointer and should not be used.
589  ///
590  /// Returns failure if the payload does not satisfy the conditions associated
591  /// with the type of the handle value. The value is expected to have a type
592  /// implementing TransformHandleTypeInterface.
593  LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
594 
595  /// Sets the payload IR values association with the given transform IR value
596  /// (handle). A payload value may be associated with multiple handles as long
597  /// as at most one of them is consumed by further transformations. For
598  /// example, a hypothetical "get results of calls to function with the given
599  /// name" transform may be performed twice in a row producing handles pointing
600  /// to the same values:
601  ///
602  /// %0 = transform.find_results_of_calling "myfunc"
603  /// %1 = transform.find_results_of_calling "myfunc"
604  ///
605  /// which is valid by itself. However, calling a hypothetical "erase value
606  /// producer" transform on both handles:
607  ///
608  /// transform.erase_value_produce %0
609  /// transform.erase_value_produce %1
610  ///
611  /// is invalid provided the transformation "consumes" the handle as expressed
612  /// by side effects (which themselves reflect the semantics of the transform
613  /// erasing the producer and making the handle dangling). Practically, a
614  /// transformation consuming a handle means the associated payload value may
615  /// no longer exist.
616  ///
617  /// Similarly, value handles are invalidated and should not be used after a
618  /// transform that consumed an operation handle pointing to the payload IR
619  /// operation defining the values associated with the value handle, as either block
620  /// arguments or op results, or any ancestor operation. For example,
621  ///
622  /// %0 = transform.find_call "myfunc"
623  /// %1 = transform.find_results_of_calling "myfunc"
624  /// transform.rewrite_and_rename %0 { new_name = "func" }
625  ///
626  /// makes %1 unusable after the last transformation if it consumes %0. When an
627  /// operation handle is consumed, it usually indicates that the operation was
628  /// destroyed or heavily modified, meaning that the values it defines may no
629  /// longer exist.
630  ///
631  /// Returns failure if the payload values do not satisfy the conditions
632  /// associated with the type of the handle value. The value is expected to
633  /// have a type implementing TransformValueHandleTypeInterface.
634  LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
635 
636  /// Sets the parameters associated with the given transform IR value. Returns
637  /// failure if the parameters do not satisfy the conditions associated with
638  /// the type of the value. The value is expected to have a type implementing
639  /// TransformParamTypeInterface.
640  LogicalResult setParams(Value value, ArrayRef<Param> params);
641 
642  /// Forgets the payload IR ops associated with the given transform IR value,
643  /// as well as any association between value handles and the results of said
644  /// payload IR op.
645  ///
646  /// If `allowOutOfScope` is set to "false", asserts that the handle is in
647  /// scope, based on the current stack of frames.
648  void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
649  bool allowOutOfScope = false);
650 
651  void forgetValueMapping(Value valueHandle,
652  ArrayRef<Operation *> payloadOperations);
653 
654  /// Replaces the given payload op with another op. If the replacement op is
655  /// null, removes the association of the payload op with its handle. Returns
656  /// failure if the op is not associated with any handle.
657  ///
658  /// Note: This function does not update value handles. None of the original
659  /// op's results are allowed to be mapped to any value handle.
660  LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
661 
662  /// Replaces the given payload value with another value. If the replacement
663  /// value is null, removes the association of the payload value with its
664  /// handle. Returns failure if the value is not associated with any handle.
665  LogicalResult replacePayloadValue(Value value, Value replacement);
666 
667  /// Records handle invalidation reporters into `newlyInvalidated`.
668  /// Specifically,
669  /// - `consumingHandle` is the op operand that consumes the handle,
670  /// - `potentialAncestors` is a list of ancestors of the payload operation
671  /// that the consumed handle is associated with, including itself,
672  /// - `throughValue` is the payload value the handle to which is consumed,
673  /// when it is the case, null when the operation handle is consumed
674  /// directly.
675  /// Iterates over all known operation and value handles and records reporters
676  /// for any potential future use of `consumingHandle` or any other handle that is
677  /// invalidated by its consumption, i.e., any handle pointing to any payload
678  /// IR entity (operation or value) associated with the same payload IR entity
679  /// as the consumed handle, or any nested payload IR entity. If
680  /// `potentialAncestors` is empty, records the reporter anyway. Does not
681  /// override existing reporters. This must remain a const method so it doesn't
682  /// inadvertently mutate `invalidatedHandles` too early.
683  void recordOpHandleInvalidation(OpOperand &consumingHandle,
684  ArrayRef<Operation *> potentialAncestors,
685  Value throughValue,
686  InvalidatedHandleMap &newlyInvalidated) const;
687 
688  /// Records handle invalidation reporters into `newlyInvalidated`.
689  /// Specifically,
690  /// - `consumingHandle` is the op operand that consumes the handle,
691  /// - `potentialAncestors` is a list of ancestors of the payload operation
692  /// that the consumed handle is associated with, including itself,
693  /// - `payloadOp` is the operation itself,
694  /// - `otherHandle` is another handle that may be associated with the affected
695  /// payload operations
696  /// - `throughValue` is the payload value the handle to which is consumed,
697  /// when it is the case, null when the operation handle is consumed
698  /// directly.
699  /// Looks at the payload opreations associated with `otherHandle` and if any
700  /// of these operations has an ancestor (or is itself) listed in
701  /// `potentialAncestors`, records the error message describing the use of the
702  /// invalidated handle. Does nothing if `otherHandle` already has a reporter
703  /// associated with it. This must remain a const method so it doesn't
704  /// inadvertently mutate `invalidatedHandles` too early.
705  void recordOpHandleInvalidationOne(
706  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
707  Operation *payloadOp, Value otherHandle, Value throughValue,
708  InvalidatedHandleMap &newlyInvalidated) const;
709 
710  /// Records handle invalidation reporters into `newlyInvalidated`.
711  /// Specifically,
712  /// - `opHandle` is the op operand that consumes the handle;
713  /// - `potentialAncestors` is a list of ancestors of the payload operation
714  /// that the consumed handle is associated with, including itself;
715  /// - `payloadValue` is the value defined by the operation associated with
716  /// the consuming handle as either op result or block argument;
717  /// - `valueHandle` is another handle that may be associated with that value.
718  /// Looks at the payload values associated with `valueHandle` and if any of
719  /// these values is defined, as op result or block argument, by an operation
720  /// whose ancestor (or the operation itself) is listed in
721  /// `potentialAncestors`, records the error message describing the use of the
722  /// invalidated handle. Does nothing if `valueHandle` already has a reporter
723  /// associated with it. This must remain a const method so it doesn't
724  /// inadvertently mutate `invalidatedHandles` too early.
725  void recordValueHandleInvalidationByOpHandleOne(
726  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
727  Value payloadValue, Value valueHandle,
728  InvalidatedHandleMap &newlyInvalidated) const;
729 
730  /// Records handle invalidation reporters into `newlyInvalidated`.
731  /// Specifically,
732  /// - `valueHandle` is the op operand that consumes the handle,
733  /// - `newlyInvalidated` is the map into which the reporters are recorded.
734  /// Unlike `recordOpHandleInvalidation`, this method has no `throughValue`
735  /// parameter.
736  /// Iterates over all known operation and value handles and records reporters
737  /// for any potential future use of `valueHandle` or any other handle that is
738  /// invalidated by its consumption, i.e., any handle pointing to any payload
739  /// IR entity (operation or value) associated with the same payload IR entity
740  /// as the consumed handle, or any nested payload IR entity. Does not override
741  /// existing reporters. This must remain a const method so it doesn't
742  /// inadvertently mutate `invalidatedHandles` too early.
743  void
744  recordValueHandleInvalidation(OpOperand &valueHandle,
745  InvalidatedHandleMap &newlyInvalidated) const;
746 
747  /// Checks that the operation does not use invalidated handles as operands.
748  /// Reports errors and returns failure if it does. Otherwise, invalidates the
749  /// handles consumed by the operation as well as any handles pointing to
750  /// payload IR operations nested in the operations associated with the
751  /// consumed handles.
752  LogicalResult
753  checkAndRecordHandleInvalidation(TransformOpInterface transform);
754 
755  /// Implementation of the checkAndRecordHandleInvalidation. This must remain a
756  /// const method so it doesn't inadvertently mutate `invalidatedHandles` too
757  /// early.
758  LogicalResult checkAndRecordHandleInvalidationImpl(
759  transform::TransformOpInterface transform,
760  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const;
761 
762  /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
763  void compactOpHandles();
764 
765  /// A stack of mappings between transform IR values and payload IR ops,
766  /// aggregated by the region in which the transform IR values are defined.
767  /// We use a pointer to the Mappings struct so that reallocations inside
768  /// MapVector don't invalidate iterators when we apply nested transform ops
769  /// while also iterating over the mappings.
770  llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
771 
772  /// Op handles may be temporarily mapped to nullptr to avoid invalidating
773  /// payload op iterators. This set contains all op handles with nullptrs.
774  /// These handles are "compacted" (i.e., nullptrs removed) at the end of each
775  /// transform.
776  DenseSet<Value> opHandlesToCompact;
777 
778  /// Extensions attached to the TransformState, identified by the TypeID of
779  /// their type. Only one extension of any given type is allowed.
780  DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
781 
782  /// The top-level operation that contains all payload IR, typically a module.
783  Operation *topLevel;
784 
785  /// Extra mapped values (payload operations, values or parameters) to be
786  /// associated with additional entry block arguments of the top-level
787  /// transform operation.
788  RaggedArray<MappedValue> topLevelMappedValues;
789 
790  /// Additional options controlling the transformation state behavior.
791  TransformOptions options;
792 
793  /// The mapping from invalidated handles to the error-reporting functions that
794  /// describe when the handles were invalidated. Calling such a function emits
795  /// a user-visible diagnostic with an additional note pointing to the given
796  /// location.
797  InvalidatedHandleMap invalidatedHandles;
798 
799  /// A stack of nested regions that are being processed in the transform IR.
800  /// Each region must be an ancestor of the following regions in this list.
801  /// These are also the keys for "mappings".
802  SmallVector<RegionScope *> regionStack;
803 
804  /// The top-level region scope. The first (bottom) element of `regionStack`
805  /// is the top-level region scope object.
806  std::unique_ptr<RegionScope> topLevelRegionScope;
807 };
808 
809 /// Local mapping between values defined by a specific op implementing the
810 /// TransformOpInterface and the payload IR ops they correspond to.
812  friend class TransformState;
813 
814 public:
815  /// Indicates that the result of the transform IR op at the given position
816  /// corresponds to the given list of payload IR ops. Each result must be set
817  /// by the transformation exactly once in case of transformation succeeding.
818  /// The value must have a type implementing TransformHandleTypeInterface.
819  template <typename Range>
820  void set(OpResult value, Range &&ops) {
821  int64_t position = value.getResultNumber();
822  assert(position < static_cast<int64_t>(operations.size()) &&
823  "setting results for a non-existent handle");
824  assert(operations[position].data() == nullptr && "results already set");
825  assert(params[position].data() == nullptr &&
826  "another kind of results already set");
827  assert(values[position].data() == nullptr &&
828  "another kind of results already set");
829  operations.replace(position, std::forward<Range>(ops));
830  }
831 
832  /// Indicates that the result of the transform IR op at the given position
833  /// corresponds to the given list of payload IR ops. Each result must be set
834  /// by the transformation exactly once in case of transformation succeeding.
835  /// The value must have a type implementing TransformHandleTypeInterface.
836  void set(OpResult value, std::initializer_list<Operation *> ops) {
837  set(value, ArrayRef<Operation *>(ops));
838  }
839 
840  /// Indicates that the result of the transform IR op at the given position
841  /// corresponds to the given list of parameters. Each result must be set by
842  /// the transformation exactly once in case of transformation succeeding. The
843  /// value must have a type implementing TransformParamTypeInterface.
845 
846  /// Indicates that the result of the transform IR op at the given position
847  /// corresponds to the given range of payload IR values. Each result must be
848  /// set by the transformation exactly once in case of transformation
849  /// succeeding. The value must have a type implementing
850  /// TransformValueHandleTypeInterface.
851  template <typename Range>
852  void setValues(OpResult handle, Range &&values) {
853  int64_t position = handle.getResultNumber();
854  assert(position < static_cast<int64_t>(this->values.size()) &&
855  "setting values for a non-existent handle");
856  assert(this->values[position].data() == nullptr && "values already set");
857  assert(operations[position].data() == nullptr &&
858  "another kind of results already set");
859  assert(params[position].data() == nullptr &&
860  "another kind of results already set");
861  this->values.replace(position, std::forward<Range>(values));
862  }
863 
864  /// Indicates that the result of the transform IR op at the given position
865  /// corresponds to the given range of payload IR values. Each result must be
866  /// set by the transformation exactly once in case of transformation
867  /// succeeding. The value must have a type implementing
868  /// TransformValueHandleTypeInterface.
869  void setValues(OpResult handle, std::initializer_list<Value> values) {
870  setValues(handle, ArrayRef<Value>(values));
871  }
872 
873  /// Indicates that the result of the transform IR op at the given position
874  /// corresponds to the given range of mapped values. All mapped values are
875  /// expected to be compatible with the type of the result, e.g., if the result
876  /// is an operation handle, all mapped values are expected to be payload
877  /// operations.
878  void setMappedValues(OpResult handle, ArrayRef<MappedValue> values);
879 
880  /// Sets the currently unset results to empty lists of the kind expected by
881  /// the corresponding results of the given `transform` op.
882  void setRemainingToEmpty(TransformOpInterface transform);
883 
884 private:
885  /// Creates an instance of TransformResults that expects mappings for
886  /// `numSegments` values, which may be associated with payload operations,
887  /// payload values or parameters.
888  explicit TransformResults(unsigned numSegments);
889 
890  /// Gets the list of operations associated with the result identified by its
891  /// number in the list of operation results. The result must have been set to
892  /// be associated with payload IR operations.
893  ArrayRef<Operation *> get(unsigned resultNumber) const;
894 
895  /// Gets the list of parameters associated with the result identified by its
896  /// number in the list of operation results. The result must have been set to
897  /// be associated with parameters.
898  ArrayRef<TransformState::Param> getParams(unsigned resultNumber) const;
899 
900  /// Gets the list of payload IR values associated with the result identified
901  /// by its number in the list of operation results. The result must have been
902  /// set to be associated with payload IR values.
903  ArrayRef<Value> getValues(unsigned resultNumber) const;
904 
905  /// Returns `true` if the result identified by its number in the list of
906  /// operation results is associated with a list of parameters, `false`
907  /// otherwise.
908  bool isParam(unsigned resultNumber) const;
909 
910  /// Returns `true` if the result identified by its number in the list of
911  /// operation results is associated with a list of payload IR values, `false`
912  /// otherwise.
913  bool isValue(unsigned resultNumber) const;
914 
915  /// Returns `true` if the result identified by its number in the list of
916  /// operation results is associated with something.
917  bool isSet(unsigned resultNumber) const;
918 
919  /// Pointers to payload IR ops that are associated with results of a transform
920  /// IR op.
921  RaggedArray<Operation *> operations;
922 
923  /// Parameters that are associated with results of the transform IR op.
924  RaggedArray<Param> params;
925 
926  /// Payload IR values that are associated with results of a transform IR op.
927  RaggedArray<Value> values;
928 };
929 
930 /// Creates a RAII object the lifetime of which corresponds to the new mapping
931 /// for transform IR values defined in the given region. Values defined in
932 /// surrounding regions remain accessible.
934  return RegionScope(*this, region);
935 }
936 
937 /// A configuration object for customizing a `TrackingListener`.
939  using SkipHandleFn = std::function<bool(Value)>;
940 
941  /// An optional function that returns "true" for handles that do not have to
942  /// be updated. These are typically dead or consumed handles.
944 
945  /// If set to "true", the name of a replacement op must match the name of the
946  /// original op. If set to "false", the names of the payload ops tracked in a
947  /// handle may change as the tracking listener updates the transform state.
949 
950  /// If set to "true", cast ops (that implement the CastOpInterface) are
951  /// skipped and the replacement op search continues with the operands of the
952  /// cast op.
953  bool skipCastOps = true;
954 };
955 
956 /// A listener that updates a TransformState based on IR modifications. This
957 /// listener can be used during a greedy pattern rewrite to keep the transform
958 /// state up-to-date.
961 public:
962  /// Create a new TrackingListener for usage in the specified transform op.
963  /// Optionally, a function can be specified to identify handles that do not
964  /// have to be updated.
965  TrackingListener(TransformState &state, TransformOpInterface op,
967 
968 protected:
969  /// Return a replacement payload op for the given op, which is going to be
970  /// replaced with the given values. By default, if all values are defined by
971  /// the same op, which also has the same type as the given op, that defining
972  /// op is used as a replacement.
973  ///
974  /// A "failure" return value indicates that no replacement operation could be
975  /// found. A "nullptr" return value indicates that no replacement op is needed
976  /// (e.g., handle is dead or was consumed) and that the payload op should
977  /// be dropped from the mapping.
978  ///
979  /// Example: A tracked "linalg.generic" with two results is replaced with two
980  /// values defined by (another) "linalg.generic". It is reasonable to assume
981  /// that the replacement "linalg.generic" represents the same "computation".
982  /// Therefore, the payload op mapping is updated to the defining op of the
983  /// replacement values.
984  ///
985  /// Counter Example: A "linalg.generic" is replaced with values defined by an
986  /// "scf.for". Without further investigation, the relationship between the
987  /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
988  /// same computation; e.g., there may be tiled "linalg.generic" inside the
989  /// loop body that represents the original computation. Therefore, the
990  /// TrackingListener is conservative by default: it drops the mapping and
991  /// triggers the "payload replacement not found" notification. This default
992  /// behavior can be customized in `TrackingListenerConfig`.
993  ///
994  /// If no replacement op could be found according to the rules mentioned
995  /// above, this function tries to skip over cast-like ops that implement
996  /// `CastOpInterface`.
997  ///
998  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
999  /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
1000  /// reasonable to assume that the wrapped "linalg.generic" represents the same
1001  /// computation as the original "linalg.generic". The mapping is updated
1002  /// accordingly.
1003  ///
1004  /// Certain ops (typically also metadata-only ops) are not considered casts,
1005  /// but should be skipped nonetheless. Such ops should implement
1006  /// `FindPayloadReplacementOpInterface` to specify with which operands the
1007  /// lookup should continue.
1008  ///
1009  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
1010  /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
1011  /// not a cast. (Implementing `CastOpInterface` would be incorrect and cause
1012  /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
1013  /// implementation, the replacement op lookup continues with the wrapped
1014  /// "linalg.generic" and the mapping is updated accordingly.
1015  ///
1016  /// Derived classes may override `findReplacementOp` to specify custom
1017  /// replacement rules.
1019  findReplacementOp(Operation *&result, Operation *op,
1020  ValueRange newValues) const;
1021 
1022  /// Notify the listener that the pattern failed to match the given operation,
1023  /// and provide a callback to populate a diagnostic with the reason why the
1024  /// failure occurred.
1025  void
1027  function_ref<void(Diagnostic &)> reasonCallback) override;
1028 
1029  /// This function is called when a tracked payload op is dropped because no
1030  /// replacement op was found. Derived classes can implement this function for
1031  /// custom error handling.
1032  virtual void
1035 
1036  /// Return the single op that defines all given values (if any).
1037  static Operation *getCommonDefiningOp(ValueRange values);
1038 
1039  /// Return the transform op in which this TrackingListener is used.
1040  TransformOpInterface getTransformOp() const { return transformOp; }
1041 
1042 private:
1043  friend class TransformRewriter;
1044 
1045  void notifyOperationErased(Operation *op) override;
1046 
1047  void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
1048  using Listener::notifyOperationReplaced;
1049 
1050  /// The transform op in which this TrackingListener is used.
1051  TransformOpInterface transformOp;
1052 
1053  /// The handles that are consumed by the transform op.
1054  DenseSet<Value> consumedHandles;
1055 
1056  /// Tracking listener configuration.
1057  TrackingListenerConfig config;
1058 };
1059 
1060 /// A specialized listener that keeps track of cases in which no replacement
1061 /// payload could be found. The error state of this listener must be checked
1062 /// before the end of its lifetime.
1064 public:
1066 
1067  ~ErrorCheckingTrackingListener() override;
1068 
1069  /// Check and return the current error state of this listener. Afterwards,
1070  /// resets the error state to "success".
1072 
1073  /// Return "true" if this tracking listener had a failure.
1074  bool failed() const;
1075 
1076 protected:
1077  void
1079  DiagnosedSilenceableFailure &&diag) override;
1080 
1081 private:
1082  /// The error state of this listener. "Success" indicates that no error
1083  /// happened so far.
1085 
1086  /// The number of errors that have been encountered.
1087  int64_t errorCounter = 0;
1088 };
1089 
1090 /// This is a special rewriter to be used in transform op implementations,
1091 /// providing additional helper functions to update the transform state, etc.
1092 // TODO: Helper functions will be added in a subsequent change.
1094 protected:
1095  friend class TransformState;
1096 
1097  /// Create a new TransformRewriter.
1098  explicit TransformRewriter(MLIRContext *ctx,
1099  ErrorCheckingTrackingListener *listener);
1100 
1101 public:
1102  /// Return "true" if the tracking listener had failures.
1103  bool hasTrackingFailures() const;
1104 
1105  /// Silence all tracking failures that have been encountered so far.
1106  void silenceTrackingFailure();
1107 
1108  /// Notify the transform dialect interpreter that the given op has been
1109  /// replaced with another op and that the mapping between handles and payload
1110  /// ops/values should be updated. This function should be called before the
1111  /// original op is erased. It fails if the operation could not be replaced,
1112  /// e.g., because the original operation is not tracked.
1113  ///
1114  /// Note: As long as IR modifications are performed through this rewriter,
1115  /// the transform state is usually updated automatically. This function should
1116  /// be used when unsupported rewriter API is used; e.g., updating all uses of
1117  /// a tracked operation one-by-one instead of using `RewriterBase::replaceOp`.
1118  LogicalResult notifyPayloadOperationReplaced(Operation *op,
1119  Operation *replacement);
1120 
1121 private:
1122  ErrorCheckingTrackingListener *const listener;
1123 };
1124 
1125 /// This trait is supposed to be attached to Transform dialect operations that
1126 /// can be standalone top-level transforms. Such operations typically contain
1127 /// other Transform dialect operations that can be executed following some
1128 /// control flow logic specific to the current operation. The operations with
1129 /// this trait are expected to have at least one single-block region with at
1130 /// least one argument of type implementing TransformHandleTypeInterface. The
1131 /// operations are also expected to be valid without operands, in which case
1132 /// they are considered top-level, and with one or more arguments, in which case
1133 /// they are considered nested. Top-level operations have the block argument of
1134 /// the entry block in the Transform IR correspond to the root operation of
1135 /// Payload IR. Nested operations have the block argument of the entry block in
1136 /// the Transform IR correspond to a list of Payload IR operations mapped to the
1137 /// first operand of the Transform IR operation. The operation must implement
1138 /// TransformOpInterface.
1139 template <typename OpTy>
1141  : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
1142 public:
1143  /// Verifies that `op` satisfies the invariants of this trait. Not expected to
1144  /// be called directly.
1145  static LogicalResult verifyTrait(Operation *op) {
1147  }
1148 
1149  /// Returns the single block of the given region.
1150  Block *getBodyBlock(unsigned region = 0) {
1151  return &this->getOperation()->getRegion(region).front();
1152  }
1153 
1154  /// Populates `effects` with side effects implied by this trait.
1158  this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(),
1159  *getBodyBlock(), effects);
1160  }
1161 
1162  /// Sets up the mapping between the entry block of the given region of this op
1163  /// and the relevant list of Payload IR operations in the given state. The
1164  /// state is expected to be already scoped at the region of this operation.
1165  LogicalResult mapBlockArguments(TransformState &state, Region &region) {
1166  assert(region.getParentOp() == this->getOperation() &&
1167  "op comes from the wrong region");
1169  state, this->getOperation(), region);
1170  }
1171  LogicalResult mapBlockArguments(TransformState &state) {
1172  assert(
1173  this->getOperation()->getNumRegions() == 1 &&
1174  "must indicate the region to map if the operation has more than one");
1175  return mapBlockArguments(state, this->getOperation()->getRegion(0));
1176  }
1177 };
1178 
1179 class ApplyToEachResultList;
1180 
1181 /// Trait implementing the TransformOpInterface for operations applying a
1182 /// transformation to a single operation handle and producing an arbitrary
1183 /// number of handles and parameter values.
1184 /// The op must implement a method with the following signature:
1185 /// - DiagnosedSilenceableFailure applyToOne(OpTy,
1186 /// ApplyToEachResultList &results, TransformState &state)
1187 /// to perform a transformation that is applied in turn to all payload IR
1188 /// operations that correspond to the handle of the transform IR operation.
1189 /// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
1190 /// that the transformation is applied to (and NOT the class of the transform IR
1191 /// op).
1192 /// The `applyToOne` method takes an empty `results` vector that it fills with
1193 /// zero, one or multiple operations depending on the number of results expected
1194 /// by the transform op.
1195 /// The number of results must match the number of results of the transform op.
1196 /// `applyToOne` is allowed to fill the `results` with all null elements to
1197 /// signify that the transformation did not apply to the payload IR operations.
1198 /// Such null elements are filtered out from results before return.
1199 ///
1200 /// The transform op having this trait is expected to have a single operand.
1201 template <typename OpTy>
1203  : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
1204 public:
1205  /// Calls `applyToOne` for every payload operation associated with the operand
1206  /// of this transform IR op. The following case disjunction happens:
1207  /// 1. If no target payload ops are associated to the operand then fill the
1208  /// results vector with the expected number of null elements and return
1209  /// success. This is the corner case handling that allows propagating
1210  /// the "no-op" case gracefully to improve usability.
1211  /// 2. If any `applyToOne` returns definiteFailure, the transformation is
1212  /// immediately considered definitely failed and we return.
1213  /// 3. All applications of `applyToOne` are checked to return a number of
1214  /// results expected by the transform IR op. If not, this is a definite
1215  /// failure and we return early.
1216  /// 4. If `applyToOne` produces ops, associate them with the result of this
1217  /// transform op.
1218  /// 5. If any `applyToOne` returns silenceableFailure, the transformation is
1219  /// considered silenceable.
1220  /// 6. Otherwise the transformation is considered successful.
1222  TransformResults &transformResults,
1223  TransformState &state);
1224 
1225  /// Checks that the op matches the expectations of this trait.
1226  static LogicalResult verifyTrait(Operation *op);
1227 };
1228 
1229 /// Side effect resource corresponding to the mapping between Transform IR
1230 /// values and Payload IR operations. An Allocate effect from this resource
1231 /// means creating a new mapping entry, it is always accompanied by a Write
1232 /// effect. A Read effect from this resource means accessing the mapping. A Free
1233 /// effect on this resource indicates the removal of the mapping entry,
1234 /// typically after a transformation that modifies the Payload IR operations
1235 /// associated with one of the Transform IR operation's operands. It is always
1236 /// accompanied by a Read effect. Read-after-Free and double-Free are not
1237 /// allowed (they would be problematic with "regular" memory effects too) as
1238 /// they indicate an attempt to access Payload IR operations that have been
1239 /// modified, potentially erased, by the previous transformations.
1240 // TODO: consider custom effects if these are not enabling generic passes such
1241 // as CSE/DCE to work.
1243  : public SideEffects::Resource::Base<TransformMappingResource> {
1244  StringRef getName() override { return "transform.mapping"; }
1245 };
1246 
1247 /// Side effect resource corresponding to the Payload IR itself. Only Read and
1248 /// Write effects are expected on this resource, with Write always accompanied
1249 /// by a Read (short of fully replacing the top-level Payload IR operation, one
1250 /// cannot modify the Payload IR without reading it first). This is intended
1251 /// to disallow reordering of Transform IR operations that mutate the Payload IR
1252 /// while still allowing the reordering of those that only access it.
1254  : public SideEffects::Resource::Base<PayloadIRResource> {
1255  StringRef getName() override { return "transform.payload_ir"; }
1256 };
1257 
1258 /// Populates `effects` with the memory effects indicating the operation on the
1259 /// given handle value:
1260 /// - consumes = Read + Free,
1261 /// - produces = Allocate + Write,
1262 /// - onlyReads = Read.
1265 void producesHandle(ResultRange handles,
1271 
1272 /// Checks whether the transform op consumes the given handle.
1273 bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
1274 
1275 /// Populates `effects` with the memory effects indicating the access to payload
1276 /// IR resource.
1279 
1280 /// Checks whether the transform op modifies the payload.
1281 bool doesModifyPayload(transform::TransformOpInterface transform);
1282 /// Checks whether the transform op reads the payload.
1283 bool doesReadPayload(transform::TransformOpInterface transform);
1284 
1285 /// Populates `consumedArguments` with positions of `block` arguments that are
1286 /// consumed by the operations in the `block`.
1288  Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1289 
1290 /// Trait implementing the MemoryEffectOpInterface for operations that "consume"
1291 /// their operands and produce new results.
1292 template <typename OpTy>
1294  : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
1295 public:
1296  /// This op "consumes" the operands by reading and freeing them, "produces"
1297  /// the results by allocating and writing them and reads/writes the payload IR
1298  /// in the process.
1300  consumesHandle(this->getOperation()->getOpOperands(), effects);
1301  producesHandle(this->getOperation()->getOpResults(), effects);
1302  modifiesPayload(effects);
1303  }
1304 
1305  /// Checks that the op implements MemoryEffectOpInterface; fails otherwise.
1306  static LogicalResult verifyTrait(Operation *op) {
1307  if (!op->getName().getInterface<MemoryEffectOpInterface>()) { // report as failure
1308  return op->emitError()
1309  << "FunctionalStyleTransformOpTrait should only be attached to ops "
1310  "that implement MemoryEffectOpInterface";
1311  }
1312  return success();
1313  }
1314 };
1315 
1316 /// Trait implementing the MemoryEffectOpInterface for operations that use their
1317 /// operands without consuming and without modifying the Payload IR to
1318 /// potentially produce new handles.
1319 template <typename OpTy>
1321  : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
1322 public:
1323  /// This op produces handles to the Payload IR without consuming the original
1324  /// handles and without modifying the IR itself.
1326  onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
1327  producesHandle(this->getOperation()->getOpResults(), effects);
1328  if (llvm::any_of(this->getOperation()->getOperandTypes(), [](Type t) {
1329  return isa<TransformHandleTypeInterface,
1330  TransformValueHandleTypeInterface>(t);
1331  })) {
1332  onlyReadsPayload(effects);
1333  }
1334  }
1335 
1336  /// Checks that the op implements MemoryEffectOpInterface; fails otherwise.
1337  static LogicalResult verifyTrait(Operation *op) {
1338  if (!op->getName().getInterface<MemoryEffectOpInterface>()) { // report as failure
1339  return op->emitError() << "NavigationTransformOpTrait should only be attached "
1340  "to ops that implement MemoryEffectOpInterface";
1341  }
1342  return success();
1343  }
1344 };
1345 
1346 namespace detail {
1347 /// Non-template implementation of ParamProducerTransformOpTrait::getEffects().
1350 /// Non-template implementation of ParamProducerTransformOpTrait::verify().
1352 } // namespace detail
1353 
1354 /// Trait implementing the MemoryEffectOpInterface for operations that produce
1355 /// transform dialect parameters. It marks all op results of
1356 /// TransformHandleTypeInterface as produced by the op, all operands as only
1357 /// read by the op and, if at least one of the operands is a handle to payload
1358 /// ops, the entire payload as potentially read. The op must only produce
1359 /// parameter-typed results.
1360 template <typename OpTy>
1362  : public OpTrait::TraitBase<OpTy, ParamProducerTransformOpTrait> {
1363 public:
1364  /// Populates `effects` with effect instances described in the trait
1365  /// documentation.
1368  effects);
1369  }
1370 
1371  /// Checks that the op matches the expectation of this trait, i.e., that it
1372  /// implements the MemoryEffectOpInterface and only produces parameter-typed
1373  /// results.
1374  static LogicalResult verifyTrait(Operation *op) {
1376  }
1377 };
1378 
1379 /// `TrackingListener` failures are reported only for ops that have this trait.
1380 /// The purpose of this trait is to give users more time to update their custom
1381 /// transform ops to use the provided `TransformRewriter` for all IR
1382 /// modifications. This trait will eventually be removed, and failures will be
1383 /// reported for all transform ops.
1384 template <typename OpTy>
1386  : public OpTrait::TraitBase<OpTy, ReportTrackingListenerFailuresOpTrait> {};
1387 
1388 /// A single result of applying a transform op with `ApplyEachOpTrait` to a
1389 /// single payload operation.
1391 
1392 /// A list of results of applying a transform op with `ApplyEachOpTrait` to a
1393 /// single payload operation, co-indexed with the results of the transform op.
1395 public:
1397  explicit ApplyToEachResultList(unsigned size) : results(size) {}
1398 
1399  /// Sets the list of results to `size` null pointers.
1400  void assign(unsigned size, std::nullptr_t) { results.assign(size, nullptr); }
1401 
1402  /// Sets the list of results to the given range of values.
1403  template <typename Range>
1404  void assign(Range &&range) {
1405  // This is roughly the implementation of SmallVectorImpl::assign.
1406  // Dispatching to it with map_range and template type inference would result
1407  // in more complex code here.
1408  results.clear();
1409  results.reserve(llvm::size(range));
1410  for (auto element : range) {
1411  if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1412  Operation *>) {
1413  results.push_back(static_cast<Operation *>(element));
1414  } else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1415  Value>) {
1416  results.push_back(element.template get<Value>());
1417  } else {
1418  results.push_back(static_cast<Attribute>(element));
1419  }
1420  }
1421  }
1422 
1423  /// Appends an element to the list.
1424  // Using ApplyToEachResult that can be implicitly constructed from a Value but
1425  // not from a concrete Op that is implicitly convertible to a Value to avoid
1426  // ambiguity.
1427  void push_back(Operation *op) { results.push_back(op); }
1428  void push_back(Attribute attr) { results.push_back(attr); }
1429  void push_back(ApplyToEachResult r) { results.push_back(r); }
1430 
1431  /// Reserves space for `size` elements in the list.
1432  void reserve(unsigned size) { results.reserve(size); }
1433 
1434  /// Iterators over the list.
1435  auto begin() { return results.begin(); }
1436  auto end() { return results.end(); }
1437  auto begin() const { return results.begin(); }
1438  auto end() const { return results.end(); }
1439 
1440  /// Returns the number of elements in the list.
1441  size_t size() const { return results.size(); }
1442 
1443  /// Element access. Expects the index to be in bounds.
1444  ApplyToEachResult &operator[](size_t index) { return results[index]; }
1445  const ApplyToEachResult &operator[](size_t index) const {
1446  return results[index];
1447  }
1448 
1449 private:
1450  /// Underlying storage.
1452 };
1453 
1454 namespace detail {
1455 
1456 /// Check that the contents of `partialResult` matches the number, kind (payload
1457 /// op or parameter) and nullity (either all or none) requirements of
1458 /// `transformOp`. Report errors and return failure otherwise.
1459 LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc,
1460  const ApplyToEachResultList &partialResult);
1461 
1462 /// "Transpose" the results produced by individual applications, arranging them
1463 /// per result value of the transform op, and populate `transformResults` with
1464 /// that. The number, kind and nullity of per-application results are assumed to
1465 /// have been verified.
1466 void setApplyToOneResults(Operation *transformOp,
1467  TransformResults &transformResults,
1469 
1470 /// Applies a one-to-one or a one-to-many transform to each of the given
1471 /// targets. Puts the results of transforms, if any, in `results` in the same
1472 /// order. Fails if any of the application fails. Individual transforms must be
1473 /// callable with the following signature:
1474 /// - DiagnosedSilenceableFailure(TransformRewriter &rewriter, OpTy,
1475 /// ApplyToEachResultList &results, TransformState &state)
1476 /// where OpTy is either
1477 /// - Operation *, in which case the transform is always applied;
1478 /// - a concrete Op class, in which case a check is performed whether
1479 /// `targets` contains operations of the same class and a silenceable failure
1480 /// is reported if it does not.
1481 template <typename TransformOpTy, typename Range>
1483  TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets,
1485  using OpTy = typename llvm::function_traits<
1486  decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1487  static_assert(std::is_convertible<OpTy, Operation *>::value,
1488  "expected transform function to take an operation");
1489  OpBuilder::InsertionGuard g(rewriter);
1490 
1491  SmallVector<Diagnostic> silenceableStack;
1492  unsigned expectedNumResults = transformOp->getNumResults();
1493  for (Operation *target : targets) {
1494  auto specificOp = dyn_cast<OpTy>(target);
1495  if (!specificOp) {
1496  Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
1497  diag << "transform applied to the wrong op kind";
1498  diag.attachNote(target->getLoc()) << "when applied to this op";
1499  silenceableStack.push_back(std::move(diag));
1500  continue;
1501  }
1502 
1503  ApplyToEachResultList partialResults;
1504  partialResults.reserve(expectedNumResults);
1505  Location specificOpLoc = specificOp->getLoc();
1506  rewriter.setInsertionPoint(specificOp);
1508  transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1509  if (res.isDefiniteFailure())
1511 
1512  if (res.isSilenceableFailure()) {
1513  res.takeDiagnostics(silenceableStack);
1514  continue;
1515  }
1516 
1517  if (failed(detail::checkApplyToOne(transformOp, specificOpLoc,
1518  partialResults))) {
1520  }
1521  results.push_back(std::move(partialResults));
1522  }
1523  if (!silenceableStack.empty()) {
1525  std::move(silenceableStack));
1526  }
1528 }
1529 
1530 /// Reports an error and returns failure if `targets` contains an ancestor
1531 /// operation before its descendant (or a copy of itself). Implementation detail
1532 /// for expensive checks during `TransformEachOpTrait::apply`.
1533 LogicalResult checkNestedConsumption(Location loc,
1534  ArrayRef<Operation *> targets);
1535 
1536 } // namespace detail
1537 } // namespace transform
1538 } // namespace mlir
1539 
1540 template <typename OpTy>
1543  TransformRewriter &rewriter, TransformResults &transformResults,
1544  TransformState &state) {
1545  Value handle = this->getOperation()->getOperand(0);
1546  auto targets = state.getPayloadOps(handle);
1547 
1548  // If the operand is consumed, check if it is associated with operations that
1549  // may be erased before their nested operations are.
1550  if (state.getOptions().getExpensiveChecksEnabled() &&
1551  isHandleConsumed(handle, cast<transform::TransformOpInterface>(
1552  this->getOperation())) &&
1553  failed(detail::checkNestedConsumption(this->getOperation()->getLoc(),
1554  llvm::to_vector(targets)))) {
1555  return DiagnosedSilenceableFailure::definiteFailure();
1556  }
1557 
1558  // Step 1. Handle the corner case where no target is specified.
1559  // This is typically the case when the matcher fails to apply and we need to
1560  // propagate gracefully.
1561  // In this case, we fill all results with an empty vector.
1562  if (std::empty(targets)) {
1563  SmallVector<Operation *> emptyPayload;
1564  SmallVector<Attribute> emptyParams;
1565  for (OpResult r : this->getOperation()->getResults()) {
1566  if (isa<TransformParamTypeInterface>(r.getType()))
1567  transformResults.setParams(r, emptyParams);
1568  else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1569  transformResults.setValues(r, ValueRange());
1570  else
1571  transformResults.set(r, emptyPayload);
1572  }
1573  return DiagnosedSilenceableFailure::success();
1574  }
1575 
1576  // Step 2. Call applyToOne on each target and record newly produced ops in its
1577  // corresponding results entry.
1580  cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1581 
1582  // Step 3. Propagate the definite failure if any and bail out.
1583  if (result.isDefiniteFailure())
1584  return result;
1585 
1586  // Step 4. "Transpose" the results produced by individual applications,
1587  // arranging them per result value of the transform op. The number, kind and
1588  // nullity of per-application results have been verified by the callback
1589  // above.
1590  detail::setApplyToOneResults(this->getOperation(), transformResults, results);
1591 
1592  // Step 5. ApplyToOne may have returned silenceableFailure, propagate it.
1593  return result;
1594 }
1595 
1596 template <typename OpTy>
1597 llvm::LogicalResult
1599  static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1600  "expected single-operand op");
1601  if (!op->getName().getInterface<TransformOpInterface>()) {
1602  return op->emitError() << "TransformEachOpTrait should only be attached to "
1603  "ops that implement TransformOpInterface";
1604  }
1605 
1606  return success();
1607 }
1608 
1609 #endif // MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Helper class for implementing traits.
Definition: OpDefinition.h:373
Operation * getOperation()
Return the ultimate Operation being worked on.
Definition: OpDefinition.h:376
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This base class is used for derived effects that are non-parametric.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
auto begin()
Iterators over the list.
const ApplyToEachResult & operator[](size_t index) const
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
ApplyToEachResult & operator[](size_t index)
Element access. Expects the index to be in bounds.
void push_back(Operation *op)
Appends an element to the list.
void assign(Range &&range)
Sets the list of results to the given range of values.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
Trait implementing the MemoryEffectOpInterface for operations that "consume" their operands and produ...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op "consumes" the operands by reading and freeing them, "produces" the results by allocating and...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
Trait implementing the MemoryEffectOpInterface for operations that use their operands without consumi...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op produces handles to the Payload IR without consuming the original handles and without modifyi...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait.
Trait implementing the MemoryEffectsOpInterface for operations that produce transform dialect paramet...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait, i.e., that it implements the MemoryEffectsO...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with effect instances described in the trait documentation.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
LogicalResult mapBlockArguments(TransformState &state)
static LogicalResult verifyTrait(Operation *op)
Verifies that op satisfies the invariants of this trait.
void getPotentialTopLevelEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by this trait.
LogicalResult mapBlockArguments(TransformState &state, Region &region)
Sets up the mapping between the entry block of the given region of this op and the relevant list of P...
Block * getBodyBlock(unsigned region=0)
Returns the single block of the given region.
TrackingListener failures are reported only for ops that have this trait.
A listener that updates a TransformState based on IR modifications.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TransformOpInterface getTransformOp() const
Return the transform op in which this TrackingListener is used.
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag)
This function is called when a tracked payload op is dropped because no replacement op was found.
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Trait implementing the TransformOpInterface for operations applying a transformation to a single oper...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state)
Calls applyToOne for every payload operation associated with the operand of this transform IR op,...
Options controlling the application of transform operations by the TransformState.
TransformOptions & enableExpensiveChecks(bool enable=true)
Requests computationally expensive checks of the transform and payload IR well-formedness to be perfo...
TransformOptions & enableEnforceSingleToplevelTransformOp(bool enable=true)
TransformOptions & operator=(const TransformOptions &)=default
TransformOptions(const TransformOptions &)=default
bool getExpensiveChecksEnabled() const
Returns true if the expensive checks are requested.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, std::initializer_list< Operation * > ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setValues(OpResult handle, std::initializer_list< Value > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
Base class for TransformState extensions that allow TransformState to contain user-specified informat...
Extension(TransformState &state)
Constructs an extension of the given TransformState object.
const TransformState & getTransformState() const
Provides read-only access to the parent TransformState object.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
const TransformOptions & getOptions() const
friend LogicalResult applyTransforms(Operation *, TransformOpInterface, const RaggedArray< MappedValue > &, const TransformOptions &, bool)
Entry point to the Transform dialect infrastructure.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
size_t getNumTopLevelMappings() const
Returns the number of extra mappings for the top-level operation.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
Ty * getExtension()
Returns the extension of the specified type.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void removeExtension()
Removes the extension of the specified type.
ArrayRef< MappedValue > getTopLevelMapping(size_t position) const
Returns the position-th extra mapping for the top-level operation.
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult matches the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
DiagnosedSilenceableFailure applyTransformToEach(TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, SmallVectorImpl< ApplyToEachResultList > &results, TransformState &state)
Applies a one-to-one or a one-to-many transform to each of the given targets.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true)
Entry point to the Transform dialect infrastructure.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Side effect resource corresponding to the Payload IR itself.
StringRef getName() override
Return a string name of the resource.
A configuration object for customizing a TrackingListener.
bool requireMatchingReplacementOpName
If set to "true", the name of a replacement op must match the name of the original op.
bool skipCastOps
If set to "true", cast ops (that implement the CastOpInterface) are skipped and the replacement op se...
SkipHandleFn skipHandleFn
An optional function that returns "true" for handles that do not have to be updated.
std::function< bool(Value)> SkipHandleFn
Side effect resource corresponding to the mapping between Transform IR values and Payload IR operatio...
StringRef getName() override
Return a string name of the resource.