MLIR  20.0.0git
TransformInterfaces.h
Go to the documentation of this file.
1 //===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
11 
14 #include "mlir/IR/OpDefinition.h"
15 #include "mlir/IR/PatternMatch.h"
18 
19 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
20 
21 namespace mlir {
22 namespace transform {
23 
24 class TransformOpInterface;
25 class TransformResults;
26 class TransformRewriter;
27 class TransformState;
28 
29 using Param = Attribute;
31 
32 namespace detail {
33 /// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
34 /// to either the list of operations associated with its operand or the root of
35 /// the payload IR, depending on what is available in the context.
36 LogicalResult
38  Operation *op, Region &region);
39 
40 /// Verification hook for PossibleTopLevelTransformOpTrait.
42 
43 /// Populates `effects` with side effects implied by
44 /// PossibleTopLevelTransformOpTrait for the given operation. The operation may
45 /// have an optional `root` operand, indicating it is not in fact top-level. It
46 /// is also expected to have a single-block body.
48  Operation *operation, Value root, Block &body,
50 
51 /// Verification hook for TransformOpInterface.
52 LogicalResult verifyTransformOpInterface(Operation *op);
53 
54 /// Appends the entities associated with the given transform values in `state`
55 /// to the pre-existing list of mappings. The array of mappings must have as
56 /// many elements as values. If `flatten` is set, multiple values may be
57 /// associated with each transform value, and this always succeeds. Otherwise,
58 /// checks that each value has exactly one mapping associated and returns failure
59 /// otherwise.
60 LogicalResult appendValueMappings(
62  ValueRange values, const transform::TransformState &state,
63  bool flatten = true);
64 
65 /// Populates `mappings` with mapped values associated with the given transform
66 /// IR values in the given `state`.
69  ValueRange values, const transform::TransformState &state);
70 
71 /// Populates `results` with payload associations that match exactly those of
72 /// the operands to `block`'s terminator.
75 
76 /// Make a dummy transform state for testing purposes. This MUST NOT be used
77 /// outside of test cases.
79  Operation *payloadRoot);
80 
81 /// Returns all operands that are handles and being consumed by the given op.
83 getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
84 } // namespace detail
85 } // namespace transform
86 } // namespace mlir
87 
88 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
89 
90 namespace mlir {
91 namespace transform {
92 
93 /// Options controlling the application of transform operations by the
94 /// TransformState.
96 public:
97  TransformOptions() = default;
98  TransformOptions(const TransformOptions &) = default;
100 
101  /// Requests computationally expensive checks of the transform and payload IR
102  /// well-formedness to be performed before each transformation. In particular,
103  /// these ensure that the handles still point to valid operations when used.
105  expensiveChecksEnabled = enable;
106  return *this;
107  }
108 
109  /// Ensures that only a single top-level transform op is present in the IR.
111  enforceSingleToplevelTransformOp = enable;
112  return *this;
113  }
114 
115  /// Returns true if the expensive checks are requested.
116  bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
117 
118  /// Returns true if enforcing a single top-level transform op is requested.
120  return enforceSingleToplevelTransformOp;
121  }
122 
123 private:
124  bool expensiveChecksEnabled = true;
125  bool enforceSingleToplevelTransformOp = true;
126 };
127 
128 /// Entry point to the Transform dialect infrastructure. Applies the
129 /// transformation specified by `transform` to payload IR contained in
130 /// `payloadRoot`. The `transform` operation may contain other operations that
131 /// will be executed following the internal logic of the operation. It must
132 /// have the `PossibleTopLevelTransformOp` trait and not have any operands.
133 /// This function internally keeps track of the transformation state.
134 LogicalResult applyTransforms(
135  Operation *payloadRoot, TransformOpInterface transform,
136  const RaggedArray<MappedValue> &extraMapping = {},
137  const TransformOptions &options = TransformOptions(),
138  bool enforceToplevelTransformOp = true,
139  function_ref<void(TransformState &)> stateInitializer = nullptr,
140  function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);
141 
142 /// The state maintained across applications of various ops implementing the
143 /// TransformOpInterface. The operations implementing this interface and the
144 /// surrounding structure are referred to as transform IR. The operations to
145 /// which transformations apply are referred to as payload IR. Transform IR
146 /// operates on values that can be associated either with a list of payload IR
147 /// operations (such values are referred to as handles) or with a list of
148 /// parameters represented as attributes. The state thus contains the mapping
149 /// between values defined in the transform IR ops and either payload IR ops or
150 /// parameters. For payload ops, the mapping is many-to-many and the reverse
151 /// mapping is also stored. The "expensive-checks" option can be passed to the
152 /// constructor to check, at transformation execution time, that transform IR
153 /// values used as operands by a transform IR operation are not associated
154 /// with dangling pointers to payload IR operations that are known to have been
155 /// erased by previous transformations through the same or a different transform IR value.
156 ///
157 /// A reference to this class is passed as an argument to "apply" methods of the
158 /// transform op interface. Thus the "apply" method can call either
159 /// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
160 /// or `state.getParams( getSomeOperand() )` to obtain the list of parameters
161 /// associated with its operand. The method is expected to populate the
162 /// `TransformResults` class instance in order to update the mapping. The
163 /// `applyTransform` method takes care of propagating the state of
164 /// `TransformResults` into the instance of this class.
165 ///
166 /// When applying transform IR operations with regions, the client is expected
167 /// to create a `RegionScope` RAII object to create a new "stack frame" for
168 /// values defined inside the region. The mappings from and to these values will
169 /// be automatically dropped when the object goes out of scope, typically at the
170 /// end of the `apply` function of the parent operation. If a region contains
171 /// blocks with arguments, the client can map those arguments to payload IR ops
172 /// using `mapBlockArguments`.
174 public:
176 
177 private:
178  /// Mapping between a Value in the transform IR and the corresponding set of
179  /// operations in the payload IR.
181 
182  /// Mapping between a payload IR operation and the transform IR values it is
183  /// associated with.
186 
187  /// Mapping between a Value in the transform IR and the corresponding list of
188  /// parameters.
190 
191  /// Mapping between a Value in the transform IR and the corresponding list of
192  /// values in the payload IR. Also works for reverse mappings.
194 
195  /// Mapping between a Value in the transform IR and an error message that
196  /// should be emitted when the value is used.
197  using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
198 
199 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
200  /// Debug only: A timestamp is associated with each transform IR value, so
201  /// that invalid iterator usage can be detected more reliably.
202  using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
203 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
204 
205  /// The bidirectional mappings between transform IR values and payload IR
206  /// operations, and the mapping between transform IR values and parameters.
207  struct Mappings {
208  TransformOpMapping direct;
210  ParamMapping params;
211  ValueMapping values;
212  ValueMapping reverseValues;
213 
214 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
215  TransformIRTimestampMapping timestamps;
216  void incrementTimestamp(Value value) { ++timestamps[value]; }
217 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
218  };
219 
220  friend LogicalResult
221  applyTransforms(Operation *, TransformOpInterface,
223  bool, function_ref<void(TransformState &)>,
224  function_ref<LogicalResult(TransformState &)>);
225 
226  friend TransformState
228 
229 public:
230  const TransformOptions &getOptions() const { return options; }
231 
232  /// Returns the op at which the transformation state is rooted. This is
233  /// typically helpful for transformations that apply globally.
234  Operation *getTopLevel() const;
235 
236  /// Returns the number of extra mappings for the top-level operation.
237  size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
238 
239  /// Returns the position-th extra mapping for the top-level operation.
240  ArrayRef<MappedValue> getTopLevelMapping(size_t position) const {
241  return topLevelMappedValues[position];
242  }
243 
244  /// Returns an iterator that enumerates all ops that the given transform IR
245  /// value corresponds to. Ops may be erased while iterating; erased ops are
246  /// not enumerated. This function is helpful for transformations that apply to
247  /// a particular handle.
248  auto getPayloadOps(Value value) const {
249  ArrayRef<Operation *> view = getPayloadOpsView(value);
250 
251 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
252  // Memorize the current timestamp and make sure that it has not changed
253  // when incrementing or dereferencing the iterator returned by this
254  // function. The timestamp is incremented when the "direct" mapping is
255  // resized; this would invalidate the iterator returned by this function.
256  int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
257 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
258 
259  // When ops are replaced/erased, they are replaced with nullptr (until
260  // the data structure is compacted). Do not enumerate these ops.
261  return llvm::make_filter_range(view, [=](Operation *op) {
262 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
263  [[maybe_unused]] bool sameTimestamp =
264  currentTimestamp == this->getMapping(value).timestamps.lookup(value);
265  assert(sameTimestamp && "iterator was invalidated during iteration");
266 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
267  return op != nullptr;
268  });
269  }
270 
271  /// Returns the list of parameters that the given transform IR value
272  /// corresponds to.
273  ArrayRef<Attribute> getParams(Value value) const;
274 
275  /// Returns an iterator that enumerates all payload IR values that the given
276  /// transform IR value corresponds to.
277  auto getPayloadValues(Value handleValue) const {
278  ArrayRef<Value> view = getPayloadValuesView(handleValue);
279 
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
281  // Memorize the current timestamp and make sure that it has not changed
282  // when incrementing or dereferencing the iterator returned by this
283  // function. The timestamp is incremented when the "values" mapping is
284  // resized; this would invalidate the iterator returned by this function.
285  int64_t currentTimestamp =
286  getMapping(handleValue).timestamps.lookup(handleValue);
287  return llvm::make_filter_range(view, [=](Value v) {
288  [[maybe_unused]] bool sameTimestamp =
289  currentTimestamp ==
290  this->getMapping(handleValue).timestamps.lookup(handleValue);
291  assert(sameTimestamp && "iterator was invalidated during iteration");
292  return true;
293  });
294 #else
295  return llvm::make_range(view.begin(), view.end());
296 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
297  }
298 
299  /// Populates `handles` with all handles pointing to the given Payload IR op.
300  /// Returns success if such handles exist, failure otherwise.
301  /// If `includeOutOfScope` is set to "true", handles that are defined in
302  /// regions beyond the most recent isolated from above region are included.
303  LogicalResult getHandlesForPayloadOp(Operation *op,
304  SmallVectorImpl<Value> &handles,
305  bool includeOutOfScope = false) const;
306 
307  /// Populates `handles` with all handles pointing to the given payload IR
308  /// value. Returns success if such handles exist, failure otherwise.
309  /// If `includeOutOfScope` is set to "true", handles that are defined in
310  /// regions beyond the most recent isolated from above region are included.
311  LogicalResult getHandlesForPayloadValue(Value payloadValue,
312  SmallVectorImpl<Value> &handles,
313  bool includeOutOfScope = false) const;
314 
315  /// Applies the transformation specified by the given transform op and updates
316  /// the state accordingly.
317  DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform);
318 
319  /// Records the mapping between a block argument in the transform IR and a
320  /// list of operations in the payload IR. The arguments must be defined in
321  /// blocks of the currently processed transform IR region, typically after a
322  /// region scope is defined.
323  ///
324  /// Returns failure if the payload does not satisfy the conditions associated
325  /// with the type of the handle value.
326  LogicalResult mapBlockArguments(BlockArgument argument,
327  ArrayRef<Operation *> operations) {
328  assert(argument.getParentRegion() == regionStack.back()->region &&
329  "mapping block arguments from a region other than the active one");
330  return setPayloadOps(argument, operations);
331  }
332  LogicalResult mapBlockArgument(BlockArgument argument,
333  ArrayRef<MappedValue> values);
334  LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
336 
337  // Forward declarations to support limited visibility.
338  class RegionScope;
339 
340  /// Creates a new region scope for the given region. The region is expected to
341  /// be nested in the currently processed region.
342  // Implementation note: this method is inline but implemented outside of the
343  // class body to comply with visibility and full-declaration requirements.
344  inline RegionScope make_region_scope(Region &region);
345 
346  /// A RAII object maintaining a "stack frame" for a transform IR region. When
347  /// applying a transform IR operation that contains a region, the caller is
348  /// expected to create a RegionScope before applying the ops contained in the
349  /// region. This ensures that the mappings between values defined in the
350  /// transform IR region and payload IR operations are cleared when the region
351  /// processing ends; such values cannot be accessed outside the region.
352  class RegionScope {
353  public:
354  /// Forgets the mapping from or to values defined in the associated
355  /// transform IR region, and restores the mapping that existed before
356  /// entering this scope.
357  ~RegionScope();
358 
359  private:
360  /// Creates a new scope for mappings between values defined in the given
361  /// transform IR region and payload IR objects.
362  RegionScope(TransformState &state, Region &region)
363  : state(state), region(&region) {
364  auto res = state.mappings.insert(
365  std::make_pair(&region, std::make_unique<Mappings>()));
366  assert(res.second && "the region scope is already present");
367  (void)res;
368  state.regionStack.push_back(this);
369  }
370 
371  /// Back-reference to the transform state.
372  TransformState &state;
373 
374  /// The region this scope is associated with.
375  Region *region;
376 
377  /// The transform op within this region that is currently being applied.
378  TransformOpInterface currentTransform;
379 
381  };
382  friend class RegionScope;
383 
384  /// Base class for TransformState extensions that allow TransformState to
385  /// contain user-specified information in the state object. Clients are
386  /// expected to derive this class, add the desired fields, and make the
387  /// derived class compatible with the MLIR TypeID mechanism:
388  ///
389  /// ```c++
390  /// class MyExtension final : public TransformState::Extension {
391  /// public:
392  /// MyExtension(TransformState &state, int myData)
393  /// : Extension(state) {...}
394  /// private:
395  /// int mySupplementaryData;
396  /// };
397  /// ```
398  ///
399  /// Instances of this and derived classes are not expected to be created by
400  /// the user, instead they are directly constructed within a TransformState. A
401  /// TransformState can only contain one extension with the given TypeID.
402  /// Extensions can be obtained from a TransformState instance, and can be
403  /// removed when they are no longer required.
404  ///
405  /// ```c++
406  /// transformState.addExtension<MyExtension>(/*myData=*/42);
407  /// MyExtension *ext = transformState.getExtension<MyExtension>();
408  /// ext->doSomething();
409  /// ```
410  class Extension {
411  // Allow TransformState to allocate Extensions.
412  friend class TransformState;
413 
414  public:
415  /// Base virtual destructor.
416  // Out-of-line definition ensures symbols are emitted in a single object
417  // file.
418  virtual ~Extension();
419 
420  protected:
421  /// Constructs an extension of the given TransformState object.
422  Extension(TransformState &state) : state(state) {}
423 
424  /// Provides read-only access to the parent TransformState object.
425  const TransformState &getTransformState() const { return state; }
426 
427  /// Replaces the given payload op with another op. If the replacement op is
428  /// null, removes the association of the payload op with its handle. Returns
429  /// failure if the op is not associated with any handle.
430  ///
431  /// Note: This function does not update value handles. None of the original
432  /// op's results are allowed to be mapped to any value handle.
433  LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
434 
435  /// Replaces the given payload value with another value. If the replacement
436  /// value is null, removes the association of the payload value with its
437  /// handle. Returns failure if the value is not associated with any handle.
438  LogicalResult replacePayloadValue(Value value, Value replacement);
439 
440  private:
441  /// Back-reference to the state that is being extended.
442  TransformState &state;
443  };
444 
445  /// Adds a new Extension of the type specified as template parameter,
446  /// constructing it with the arguments provided. The extension is owned by the
447  /// TransformState. It is expected that the state does not already have an
448  /// extension of the same type. Extension constructors are expected to take
449  /// a reference to TransformState as first argument, automatically supplied
450  /// by this call.
451  template <typename Ty, typename... Args>
452  Ty &addExtension(Args &&...args) {
453  static_assert(
454  std::is_base_of<Extension, Ty>::value,
455  "only an class derived from TransformState::Extension is allowed here");
456  auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
457  auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
458  assert(result.second && "extension already added");
459  return *static_cast<Ty *>(result.first->second.get());
460  }
461 
462  /// Returns the extension of the specified type.
463  template <typename Ty>
464  Ty *getExtension() {
465  static_assert(
466  std::is_base_of<Extension, Ty>::value,
467  "only an class derived from TransformState::Extension is allowed here");
468  auto iter = extensions.find(TypeID::get<Ty>());
469  if (iter == extensions.end())
470  return nullptr;
471  return static_cast<Ty *>(iter->second.get());
472  }
473 
474  /// Removes the extension of the specified type.
475  template <typename Ty>
477  static_assert(
478  std::is_base_of<Extension, Ty>::value,
479  "only an class derived from TransformState::Extension is allowed here");
480  extensions.erase(TypeID::get<Ty>());
481  }
482 
483 private:
484  /// Identifier for storing top-level value in the `operations` mapping.
485  static constexpr Value kTopLevelValue = Value();
486 
487  /// Creates a state for transform ops living in the given region. The second
488  /// argument points to the root operation in the payload IR being transformed,
489  /// which may or may not contain the region with transform ops. Additional
490  /// options can be provided through the trailing configuration object.
491  TransformState(Region *region, Operation *payloadRoot,
492  const RaggedArray<MappedValue> &extraMappings = {},
493  const TransformOptions &options = TransformOptions());
494 
495  /// Returns the mappings frame for the region in which the value is defined.
496  /// If `allowOutOfScope` is set to "false", asserts that the value is in
497  /// scope, based on the current stack of frames.
498  const Mappings &getMapping(Value value, bool allowOutOfScope = false) const {
499  return const_cast<TransformState *>(this)->getMapping(value,
500  allowOutOfScope);
501  }
502  Mappings &getMapping(Value value, bool allowOutOfScope = false) {
503  Region *region = value.getParentRegion();
504  auto it = mappings.find(region);
505  assert(it != mappings.end() &&
506  "trying to find a mapping for a value from an unmapped region");
507 #ifndef NDEBUG
508  if (!allowOutOfScope) {
509  for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
510  if (r == region)
511  break;
512  if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
513  llvm_unreachable("trying to get mapping beyond region that is "
514  "isolated from above");
515  }
516  }
517 #endif // NDEBUG
518  return *it->second;
519  }
520 
521  /// Returns the mappings frame for the region in which the operation resides.
522  /// If `allowOutOfScope` is set to "false", asserts that the operation is in
523  /// scope, based on the current stack of frames.
524  const Mappings &getMapping(Operation *operation,
525  bool allowOutOfScope = false) const {
526  return const_cast<TransformState *>(this)->getMapping(operation,
527  allowOutOfScope);
528  }
529  Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) {
530  Region *region = operation->getParentRegion();
531  auto it = mappings.find(region);
532  assert(it != mappings.end() &&
533  "trying to find a mapping for an operation from an unmapped region");
534 #ifndef NDEBUG
535  if (!allowOutOfScope) {
536  for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
537  if (r == region)
538  break;
539  if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
540  llvm_unreachable("trying to get mapping beyond region that is "
541  "isolated from above");
542  }
543  }
544 #endif // NDEBUG
545  return *it->second;
546  }
547 
548  /// Updates the state to include the associations between op results and the
549  /// provided result of applying a transform op.
550  LogicalResult updateStateFromResults(const TransformResults &results,
551  ResultRange opResults);
552 
553  /// Returns a list of all ops that the given transform IR value corresponds
554  /// to. In case an op was erased, the returned list contains nullptr. This
555  /// function is helpful for transformations that apply to a particular handle.
556  ArrayRef<Operation *> getPayloadOpsView(Value value) const;
557 
558  /// Returns a list of payload IR values that the given transform IR value
559  /// corresponds to.
560  ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
561 
562  /// Sets the payload IR ops associated with the given transform IR value
563  /// (handle). A payload op may be associated with multiple handles as long as
564  /// at most one of them gets consumed by further transformations.
565  /// For example, a hypothetical "find function by name" may be called twice in
566  /// a row to produce two handles pointing to the same function:
567  ///
568  /// %0 = transform.find_func_by_name { name = "myfunc" }
569  /// %1 = transform.find_func_by_name { name = "myfunc" }
570  ///
571  /// which is valid by itself. However, calling a hypothetical "rewrite and
572  /// rename function" transform on both handles:
573  ///
574  /// transform.rewrite_and_rename %0 { new_name = "func" }
575  /// transform.rewrite_and_rename %1 { new_name = "func" }
576  ///
577  /// is invalid given the transformation "consumes" the handle as expressed
578  /// by side effects. Practically, a transformation consuming a handle means
579  /// that the associated payload operation may no longer exist.
580  ///
581  /// Similarly, operation handles may be invalidated and should not be used
582  /// after a transform that consumed a value handle pointing to a payload value
583  /// defined by the operation as either block argument or op result. For
584  /// example, in the following sequence, the last transform operation rewrites
585  /// the callee to not return a specified result:
586  ///
587  /// %0 = transform.find_call "myfunc"
588  /// %1 = transform.find_results_of_calling "myfunc"
589  /// transform.drop_call_result_from_signature %1[0]
590  ///
591  /// which requires the call operations to be recreated. Therefore, the handle
592  /// %0 becomes associated with a dangling pointer and should not be used.
593  ///
594  /// Returns failure if the payload does not satisfy the conditions associated
595  /// with the type of the handle value. The value is expected to have a type
596  /// implementing TransformHandleTypeInterface.
597  LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
598 
599  /// Sets the payload IR values associated with the given transform IR value
600  /// (handle). A payload value may be associated with multiple handles as long
601  /// as at most one of them is consumed by further transformations. For
602  /// example, a hypothetical "get results of calls to function with the given
603  /// name" transform may be performed twice in a row producing handles pointing
604  /// to the same values:
605  ///
606  /// %0 = transform.find_results_of_calling "myfunc"
607  /// %1 = transform.find_results_of_calling "myfunc"
608  ///
609  /// which is valid by itself. However, calling a hypothetical "erase value
610  /// producer" transform on both handles:
611  ///
612  /// transform.erase_value_producer %0
613  /// transform.erase_value_producer %1
614  ///
615  /// is invalid provided the transformation "consumes" the handle as expressed
616  /// by side effects (which themselves reflect the semantics of the transform
617  /// erasing the producer and making the handle dangling). Practically, a
618  /// transformation consuming a handle means the associated payload value may
619  /// no longer exist.
620  ///
621  /// Similarly, value handles are invalidated and should not be used after a
622  /// transform that consumed an operation handle pointing to the payload IR
623  /// operation defining the values associated with the value handle, as either block
624  /// arguments or op results, or any ancestor operation. For example,
625  ///
626  /// %0 = transform.find_call "myfunc"
627  /// %1 = transform.find_results_of_calling "myfunc"
628  /// transform.rewrite_and_rename %0 { new_name = "func" }
629  ///
630  /// makes %1 unusable after the last transformation if it consumes %0. When an
631  /// operation handle is consumed, it usually indicates that the operation was
632  /// destroyed or heavily modified, meaning that the values it defines may no
633  /// longer exist.
634  ///
635  /// Returns failure if the payload values do not satisfy the conditions
636  /// associated with the type of the handle value. The value is expected to
637  /// have a type implementing TransformValueHandleTypeInterface.
638  LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
639 
640  /// Sets the parameters associated with the given transform IR value. Returns
641  /// failure if the parameters do not satisfy the conditions associated with
642  /// the type of the value. The value is expected to have a type implementing
643  /// TransformParamTypeInterface.
644  LogicalResult setParams(Value value, ArrayRef<Param> params);
645 
646  /// Forgets the payload IR ops associated with the given transform IR value,
647  /// as well as any association between value handles and the results of said
648  /// payload IR op.
649  ///
650  /// If `allowOutOfScope` is set to "false", asserts that the handle is in
651  /// scope, based on the current stack of frames.
652  void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
653  bool allowOutOfScope = false);
654 
655  void forgetValueMapping(Value valueHandle,
656  ArrayRef<Operation *> payloadOperations);
657 
658  /// Replaces the given payload op with another op. If the replacement op is
659  /// null, removes the association of the payload op with its handle. Returns
660  /// failure if the op is not associated with any handle.
661  ///
662  /// Note: This function does not update value handles. None of the original
663  /// op's results are allowed to be mapped to any value handle.
664  LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
665 
666  /// Replaces the given payload value with another value. If the replacement
667  /// value is null, removes the association of the payload value with its
668  /// handle. Returns failure if the value is not associated with any handle.
669  LogicalResult replacePayloadValue(Value value, Value replacement);
670 
671  /// Records handle invalidation reporters into `newlyInvalidated`.
672  /// Specifically,
673  /// - `consumingHandle` is the op operand that consumes the handle,
674  /// - `potentialAncestors` is a list of ancestors of the payload operation
675  /// that the consumed handle is associated with, including itself,
676  /// - `throughValue` is the payload value the handle to which is consumed,
677  /// when it is the case, null when the operation handle is consumed
678  /// directly.
679  /// Iterates over all known operation and value handles and records reporters
680  /// for any potential future use of `consumingHandle` or any other handle that is
681  /// invalidated by its consumption, i.e., any handle pointing to any payload
682  /// IR entity (operation or value) associated with the same payload IR entity
683  /// as the consumed handle, or any nested payload IR entity. If
684  /// `potentialAncestors` is empty, records the reporter anyway. Does not
685  /// override existing reporters. This must remain a const method so it doesn't
686  /// inadvertently mutate `invalidatedHandles` too early.
687  void recordOpHandleInvalidation(OpOperand &consumingHandle,
688  ArrayRef<Operation *> potentialAncestors,
689  Value throughValue,
690  InvalidatedHandleMap &newlyInvalidated) const;
691 
692  /// Records handle invalidation reporters into `newlyInvalidated`.
693  /// Specifically,
694  /// - `consumingHandle` is the op operand that consumes the handle,
695  /// - `potentialAncestors` is a list of ancestors of the payload operation
696  /// that the consumed handle is associated with, including itself,
697  /// - `payloadOp` is the operation itself,
698  /// - `otherHandle` is another handle that may be associated with the affected
699  /// payload operations
700  /// - `throughValue` is the payload value the handle to which is consumed,
701  /// when it is the case, null when the operation handle is consumed
702  /// directly.
703  /// Looks at the payload opreations associated with `otherHandle` and if any
704  /// of these operations has an ancestor (or is itself) listed in
705  /// `potentialAncestors`, records the error message describing the use of the
706  /// invalidated handle. Does nothing if `otherHandle` already has a reporter
707  /// associated with it. This must remain a const method so it doesn't
708  /// inadvertently mutate `invalidatedHandles` too early.
709  void recordOpHandleInvalidationOne(
710  OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
711  Operation *payloadOp, Value otherHandle, Value throughValue,
712  InvalidatedHandleMap &newlyInvalidated) const;
713 
714  /// Records handle invalidation reporters into `newlyInvalidated`.
715  /// Specifically,
716  /// - `opHandle` is the op operand that consumes the handle;
717  /// - `potentialAncestors` is a list of ancestors of the payload operation
718  /// that the consumed handle is associated with, including itself;
719  /// - `payloadValue` is the value defined by the operation associated with
720  /// the consuming handle as either op result or block argument;
721  /// - `valueHandle` is another handle that may be associated with that value.
722  /// Looks at the payload values associated with `valueHandle` and if any of
723  /// these values is defined, as op result or block argument, by an operation
724  /// whose ancestor (or the operation itself) is listed in
725  /// `potentialAncestors`, records the error message describing the use of the
726  /// invalidated handle. Does nothing if `valueHandle` already has a reporter
727  /// associated with it. This must remain a const method so it doesn't
728  /// inadvertently mutate `invalidatedHandles` too early.
729  void recordValueHandleInvalidationByOpHandleOne(
730  OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
731  Value payloadValue, Value valueHandle,
732  InvalidatedHandleMap &newlyInvalidated) const;
733 
734  /// Records handle invalidation reporters into `newlyInvalidated`.
735  /// Specifically,
736  /// - `valueHandle` is the op operand that consumes the (value) handle,
737  /// - `newlyInvalidated` receives the reporters; it is kept separate from
738  /// `invalidatedHandles`, which this const method does not update.
739  /// (Unlike the op-handle variants, there is no `throughValue` parameter.)
740  /// Iterates over all known operation and value handles and records reporters
741  /// for any potential future use of `valueHandle` or any other handle that is
742  /// invalidated by its consumption, i.e., any handle pointing to any payload
743  /// IR entity (operation or value) associated with the same payload IR entity
744  /// as the consumed handle, or any nested payload IR entity. Does not override
745  /// existing reporters. This must remain a const method so it doesn't
746  /// inadvertently mutate `invalidatedHandles` too early.
747  void
748  recordValueHandleInvalidation(OpOperand &valueHandle,
749  InvalidatedHandleMap &newlyInvalidated) const;
750 
751  /// Checks that the operation does not use invalidated handles as operands.
752  /// Reports errors and returns failure if it does. Otherwise, invalidates the
753  /// handles consumed by the operation as well as any handles pointing to
754  /// payload IR operations nested in the operations associated with the
755  /// consumed handles.
756  LogicalResult
757  checkAndRecordHandleInvalidation(TransformOpInterface transform);
758 
759  /// Implementation of `checkAndRecordHandleInvalidation` that collects the
760  /// reporters into `newlyInvalidated`. This must remain a const method so it
761  /// doesn't inadvertently mutate `invalidatedHandles` too early.
762  LogicalResult checkAndRecordHandleInvalidationImpl(
763  transform::TransformOpInterface transform,
764  transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const;
765 
766  /// Removes the nullptrs that `replacePayloadOp` added to op handles.
767  void compactOpHandles();
768 
769  /// A stack of mappings between transform IR values and payload IR ops,
770  /// aggregated by the region in which the transform IR values are defined.
771  /// Each Mappings struct is held through a std::unique_ptr so that
772  /// reallocations inside the MapVector don't invalidate iterators into it when
773  /// nested transform ops are applied while iterating over the mappings.
774  llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
775 
776  /// Op handles may be temporarily mapped to nullptr to avoid invalidating
777  /// payload op iterators. This set contains all op handles with nullptrs.
778  /// These handles are "compacted" (i.e., nullptrs removed) at the end of each
779  /// transform.
780  DenseSet<Value> opHandlesToCompact;
781 
782  /// Extensions attached to the TransformState, identified by the TypeID of
783  /// their type. Only one extension of any given type is allowed.
784  DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
785 
786  /// The top-level operation that contains all payload IR, typically a module.
787  Operation *topLevel;
788 
789  /// Extra mapped values (payload operations, values or parameters) to be
790  /// associated with additional entry block arguments of the top-level
791  /// transform operation.
792  RaggedArray<MappedValue> topLevelMappedValues;
793 
794  /// Additional options controlling the transformation state behavior.
795  TransformOptions options;
796 
797  /// The mapping from invalidated handles to the error-reporting functions that
798  /// describe when the handles were invalidated. Calling such a function emits
799  /// a user-visible diagnostic with an additional note pointing to the given
800  /// location.
801  InvalidatedHandleMap invalidatedHandles;
802 
803  /// A stack of nested regions that are being processed in the transform IR.
804  /// Each region must be an ancestor of the following regions in this list.
805  /// These regions are also the keys of `mappings`.
806  SmallVector<RegionScope *> regionStack;
807 
808  /// The top-level region scope. The first (bottom) element of `regionStack`
809  /// is the top-level region scope object.
810  std::unique_ptr<RegionScope> topLevelRegionScope;
811 };
812 
813 /// Local mapping between values defined by a specific op implementing the
814 /// TransformOpInterface and the payload IR ops, values or params they map to.
816  friend class TransformState;
817 
818 public:
819  /// Indicates that the result of the transform IR op at the given position
820  /// corresponds to the given list of payload IR ops. Each result must be set
821  /// by the transformation exactly once if the transformation succeeds.
822  /// The value must have a type implementing TransformHandleTypeInterface.
823  template <typename Range>
824  void set(OpResult value, Range &&ops) {
825  int64_t position = value.getResultNumber();
826  assert(position < static_cast<int64_t>(operations.size()) &&
827  "setting results for a non-existent handle");
828  assert(operations[position].data() == nullptr && "results already set");
829  assert(params[position].data() == nullptr &&
830  "another kind of results already set");
831  assert(values[position].data() == nullptr &&
832  "another kind of results already set");
833  operations.replace(position, std::forward<Range>(ops));
834  }
835 
836  /// Convenience overload of `set` taking a braced list of payload IR ops.
837  /// Each result must be set by the transformation exactly once if the
838  /// transformation succeeds. The value must have a type implementing
839  /// TransformHandleTypeInterface.
840  void set(OpResult value, std::initializer_list<Operation *> ops) {
841  set(value, ArrayRef<Operation *>(ops));
842  }
843 
844  /// Indicates that the result of the transform IR op at the given position
845  /// corresponds to the given list of parameters. Each result must be set by
846  /// the transformation exactly once if the transformation succeeds. The
847  /// value must have a type implementing TransformParamTypeInterface.
849 
850  /// Indicates that the result of the transform IR op at the given position
851  /// corresponds to the given range of payload IR values. Each result must be
852  /// set by the transformation exactly once if the transformation
853  /// succeeds. The value must have a type implementing
854  /// TransformValueHandleTypeInterface.
855  template <typename Range>
856  void setValues(OpResult handle, Range &&values) {
857  int64_t position = handle.getResultNumber();
858  assert(position < static_cast<int64_t>(this->values.size()) &&
859  "setting values for a non-existent handle");
860  assert(this->values[position].data() == nullptr && "values already set");
861  assert(operations[position].data() == nullptr &&
862  "another kind of results already set");
863  assert(params[position].data() == nullptr &&
864  "another kind of results already set");
865  this->values.replace(position, std::forward<Range>(values));
866  }
867 
868  /// Convenience overload of `setValues` taking a braced list of payload IR
869  /// values. Each result must be set by the transformation exactly once if
870  /// the transformation succeeds. The value must have a type implementing
871  /// TransformValueHandleTypeInterface.
872  ///
873  void setValues(OpResult handle, std::initializer_list<Value> values) {
874  setValues(handle, ArrayRef<Value>(values));
875  }
876 
877  /// Indicates that the result of the transform IR op at the given position
878  /// corresponds to the given range of mapped values. All mapped values are
879  /// expected to be compatible with the type of the result, e.g., if the result
880  /// is an operation handle, all mapped values are expected to be payload
881  /// operations.
882  void setMappedValues(OpResult handle, ArrayRef<MappedValue> values);
883 
884  /// Sets the currently unset results to empty lists of the kind expected by
885  /// the corresponding results of the given `transform` op.
886  void setRemainingToEmpty(TransformOpInterface transform);
887 
888 private:
889  /// Creates an instance of TransformResults that expects mappings for
890  /// `numSegments` values, each of which may be associated with payload
891  /// operations, payload IR values or parameters.
892  explicit TransformResults(unsigned numSegments);
893 
894  /// Gets the list of operations associated with the result identified by its
895  /// number in the list of operation results. The result must have been set to
896  /// be associated with payload IR operations.
897  ArrayRef<Operation *> get(unsigned resultNumber) const;
898 
899  /// Gets the list of parameters associated with the result identified by its
900  /// number in the list of operation results. The result must have been set to
901  /// be associated with parameters.
902  ArrayRef<TransformState::Param> getParams(unsigned resultNumber) const;
903 
904  /// Gets the list of payload IR values associated with the result identified
905  /// by its number in the list of operation results. The result must have been
906  /// set to be associated with payload IR values.
907  ArrayRef<Value> getValues(unsigned resultNumber) const;
908 
909  /// Returns `true` if the result identified by its number in the list of
910  /// operation results is associated with a list of parameters, `false`
911  /// otherwise.
912  bool isParam(unsigned resultNumber) const;
913 
914  /// Returns `true` if the result identified by its number in the list of
915  /// operation results is associated with a list of payload IR values, `false`
916  /// otherwise.
917  bool isValue(unsigned resultNumber) const;
918 
919  /// Returns `true` if the result identified by its number in the list of
920  /// operation results is associated with something.
921  bool isSet(unsigned resultNumber) const;
922 
923  /// Pointers to payload IR ops that are associated with results of a transform
924  /// IR op.
925  RaggedArray<Operation *> operations;
926 
927  /// Parameters that are associated with results of the transform IR op.
928  RaggedArray<Param> params;
929 
930  /// Payload IR values that are associated with results of a transform IR op.
931  RaggedArray<Value> values;
932 };
933 
934 /// Creates an RAII object whose lifetime corresponds to the new mapping
935 /// for transform IR values defined in the given region. Values defined in
936 /// surrounding regions remain accessible.
938  return RegionScope(*this, region);
939 }
940 
941 /// A configuration object for customizing a `TrackingListener`.
943  using SkipHandleFn = std::function<bool(Value)>;
944 
945  /// An optional function that returns "true" for handles that do not have to
946  /// be updated. These are typically dead or consumed handles.
948 
949  /// If set to "true", the name of a replacement op must match the name of the
950  /// original op. If set to "false", the names of the payload ops tracked in a
951  /// handle may change as the tracking listener updates the transform state.
953 
954  /// If set to "true", cast ops (i.e., ops implementing `CastOpInterface`) are
955  /// skipped and the replacement op search continues with the operands of the
956  /// cast op. Enabled by default.
957  bool skipCastOps = true;
958 };
959 
960 /// A listener that updates a TransformState based on IR modifications. This
961 /// listener can be used during a greedy pattern rewrite to keep the transform
962 /// state up-to-date.
965 public:
966  /// Creates a new TrackingListener for usage in the specified transform op.
967  /// Optionally, a function can be specified to identify handles that do not
968  /// have to be updated.
969  TrackingListener(TransformState &state, TransformOpInterface op,
971 
972 protected:
973  /// Returns a replacement payload op for the given op, which is going to be
974  /// replaced with the given values. By default, if all values are defined by
975  /// the same op, which also has the same type as the given op, that defining
976  /// op is used as a replacement.
977  ///
978  /// A "failure" return value indicates that no replacement operation could be
979  /// found. A "nullptr" return value indicates that no replacement op is needed
980  /// (e.g., the handle is dead or was consumed) and that the payload op should
981  /// be dropped from the mapping.
982  ///
983  /// Example: A tracked "linalg.generic" with two results is replaced with two
984  /// values defined by (another) "linalg.generic". It is reasonable to assume
985  /// that the replacement "linalg.generic" represents the same "computation".
986  /// Therefore, the payload op mapping is updated to the defining op of the
987  /// replacement values.
988  ///
989  /// Counterexample: A "linalg.generic" is replaced with values defined by an
990  /// "scf.for". Without further investigation, the relationship between the
991  /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
992  /// same computation; e.g., there may be a tiled "linalg.generic" inside the
993  /// loop body that represents the original computation. Therefore, the
994  /// TrackingListener is conservative by default: it drops the mapping and
995  /// triggers the "payload replacement not found" notification. This default
996  /// behavior can be customized in `TrackingListenerConfig`.
997  ///
998  /// If no replacement op could be found according to the rules mentioned
999  /// above and `TrackingListenerConfig::skipCastOps` is set, this function
1000  /// tries to skip over cast-like ops that implement `CastOpInterface`.
1001  ///
1002  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
1003  /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
1004  /// reasonable to assume that the wrapped "linalg.generic" represents the same
1005  /// computation as the original "linalg.generic". The mapping is updated
1006  /// accordingly.
1007  ///
1008  /// Certain ops (typically also metadata-only ops) are not considered casts,
1009  /// but should be skipped nonetheless. Such ops should implement
1010  /// `FindPayloadReplacementOpInterface` to specify with which operands the
1011  /// lookup should continue.
1012  ///
1013  /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
1014  /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
1015  /// not a cast. (Implementing `CastOpInterface` would be incorrect and cause
1016  /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
1017  /// implementation, the replacement op lookup continues with the wrapped
1018  /// "linalg.generic" and the mapping is updated accordingly.
1019  ///
1020  /// Derived classes may override `findReplacementOp` to specify custom
1021  /// replacement rules.
1023  findReplacementOp(Operation *&result, Operation *op,
1024  ValueRange newValues) const;
1025 
1026  /// Notifies the listener that the pattern failed to match the given
1027  /// operation, and provides a callback to populate a diagnostic with the
1028  /// reason why the failure occurred.
1029  void
1031  function_ref<void(Diagnostic &)> reasonCallback) override;
1032 
1033  /// This function is called when a tracked payload op is dropped because no
1034  /// replacement op was found. Derived classes can override this function for
1035  /// custom error handling.
1036  virtual void
1039 
1040  /// Returns the single op that defines all given values (if any).
1041  static Operation *getCommonDefiningOp(ValueRange values);
1042 
1043  /// Returns the transform op in which this TrackingListener is used.
1044  TransformOpInterface getTransformOp() const { return transformOp; }
1045 
1046 private:
1047  friend class TransformRewriter;
1048 
1049  void notifyOperationErased(Operation *op) override;
1050 
1051  void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
1052  using Listener::notifyOperationReplaced;
1053 
1054  /// The transform op in which this TrackingListener is used.
1055  TransformOpInterface transformOp;
1056 
1057  /// The handles that are consumed by the transform op.
1058  DenseSet<Value> consumedHandles;
1059 
1060  /// Tracking listener configuration.
1062 };
1063 
1064 /// A specialized listener that keeps track of cases in which no replacement
1065 /// payload could be found. The error state of this listener must be checked
1066 /// before the end of its lifetime.
1068 public:
1070 
1071  ~ErrorCheckingTrackingListener() override;
1072 
1073  /// Checks and returns the current error state of this listener. Afterwards,
1074  /// resets the error state to "success".
1076 
1077  /// Returns "true" if this tracking listener had a failure.
1078  bool failed() const;
1079 
1080 protected:
1081  void
1083  DiagnosedSilenceableFailure &&diag) override;
1084 
1085 private:
1086  /// The error state of this listener. "Success" indicates that no error
1087  /// happened so far.
1089 
1090  /// The number of errors that have been encountered.
1091  int64_t errorCounter = 0;
1092 };
1093 
1094 /// This is a special rewriter to be used in transform op implementations,
1095 /// providing additional helper functions to update the transform state, etc.
1096 // TODO: Helper functions will be added in a subsequent change.
1098 protected:
1099  friend class TransformState;
1100 
1101  /// Creates a new TransformRewriter connected to the given tracking listener.
1102  explicit TransformRewriter(MLIRContext *ctx,
1103  ErrorCheckingTrackingListener *listener);
1104 
1105 public:
1106  /// Returns "true" if the tracking listener had failures.
1107  bool hasTrackingFailures() const;
1108 
1109  /// Silences all tracking failures that have been encountered so far.
1110  void silenceTrackingFailure();
1111 
1112  /// Notify the transform dialect interpreter that the given op has been
1113  /// replaced with another op and that the mapping between handles and payload
1114  /// ops/values should be updated. This function should be called before the
1115  /// original op is erased. It fails if the operation could not be replaced,
1116  /// e.g., because the original operation is not tracked.
1117  ///
1118  /// Note: As long as IR modifications are performed through this rewriter,
1119  /// the transform state is usually updated automatically. This function should
1120  /// be used when unsupported rewriter API is used; e.g., updating all uses of
1121  /// a tracked operation one-by-one instead of using `RewriterBase::replaceOp`.
1122  LogicalResult notifyPayloadOperationReplaced(Operation *op,
1123  Operation *replacement);
1124 
1125 private:
1126  ErrorCheckingTrackingListener *const listener;
1127 };
1128 
1129 /// This trait is supposed to be attached to Transform dialect operations that
1130 /// can be standalone top-level transforms. Such operations typically contain
1131 /// other Transform dialect operations that can be executed following some
1132 /// control flow logic specific to the current operation. The operations with
1133 /// this trait are expected to have at least one single-block region with at
1134 /// least one argument of type implementing TransformHandleTypeInterface. The
1135 /// operations are also expected to be valid without operands, in which case
1136 /// they are considered top-level, and with one or more operands, in which case
1137 /// they are considered nested. Top-level operations have the block argument of
1138 /// the entry block in the Transform IR correspond to the root operation of
1139 /// Payload IR. Nested operations have the block argument of the entry block in
1140 /// the Transform IR correspond to a list of Payload IR operations mapped to the
1141 /// first operand of the Transform IR operation. The operation must implement
1142 /// TransformOpInterface.
1143 template <typename OpTy>
1145  : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
1146 public:
1147  /// Verifies that `op` satisfies the invariants of this trait. Not expected to
1148  /// be called directly.
1149  static LogicalResult verifyTrait(Operation *op) {
1151  }
1152 
1153  /// Returns the first (expected to be the only) block of the region `region`.
1154  Block *getBodyBlock(unsigned region = 0) {
1155  return &this->getOperation()->getRegion(region).front();
1156  }
1157 
1158  /// Populates `effects` with side effects implied by this trait.
1162  this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(),
1163  *getBodyBlock(), effects);
1164  }
1165 
1166  /// Sets up the mapping between the entry block arguments of the given region
1167  /// of this op and the relevant list of Payload IR operations in the given
1168  /// state. The state is expected to be already scoped at that region.
1169  LogicalResult mapBlockArguments(TransformState &state, Region &region) {
1170  assert(region.getParentOp() == this->getOperation() &&
1171  "op comes from the wrong region");
1173  state, this->getOperation(), region);
1174  }
1175  LogicalResult mapBlockArguments(TransformState &state) {
1176  assert(
1177  this->getOperation()->getNumRegions() == 1 &&
1178  "must indicate the region to map if the operation has more than one");
1179  return mapBlockArguments(state, this->getOperation()->getRegion(0));
1180  }
1181 };
1182 
1183 class ApplyToEachResultList;
1184 
1185 /// Trait implementing the TransformOpInterface for operations applying a
1186 /// transformation to a single operation handle and producing an arbitrary
1187 /// number of handles and parameter values.
1188 /// The op must implement a method with the following signature:
1189 /// - DiagnosedSilenceableFailure applyToOne(OpTy,
1190 /// ApplyToEachResultList &results, TransformState &state)
1191 /// to perform a transformation that is applied in turn to all payload IR
1192 /// operations that correspond to the handle of the transform IR operation.
1193 /// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
1194 /// that the transformation is applied to (and NOT the class of the transform IR
1195 /// op).
1196 /// The `applyToOne` method takes an empty `results` vector that it fills with
1197 /// zero, one or multiple entries depending on the number of results expected
1198 /// by the transform op.
1199 /// The number of results must match the number of results of the transform op.
1200 /// `applyToOne` is allowed to fill the `results` with all null elements to
1201 /// signify that the transformation did not apply to the payload IR operations.
1202 /// Such null elements are filtered out from results before return.
1203 ///
1204 /// The transform op having this trait is expected to have a single operand.
1205 template <typename OpTy>
1207  : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
1208 public:
1209  /// Calls `applyToOne` for every payload operation associated with the operand
1210  /// of this transform IR op; the following cases are distinguished:
1211  /// 1. If no target payload ops are associated with the operand, fills the
1212  /// results vector with the expected number of null elements and returns
1213  /// success. This is the corner case handling that allows propagating
1214  /// the "no-op" case gracefully to improve usability.
1215  /// 2. If any `applyToOne` returns definiteFailure, the transformation is
1216  /// immediately considered definitely failed and we return.
1217  /// 3. All applications of `applyToOne` are checked to return a number of
1218  /// results expected by the transform IR op. If not, this is a definite
1219  /// failure and we return early.
1220  /// 4. If `applyToOne` produces ops, associate them with the result of this
1221  /// transform op.
1222  /// 5. If any `applyToOne` returns silenceableFailure, the transformation is
1223  /// considered silenceable.
1224  /// 6. Otherwise the transformation is considered successful.
1226  TransformResults &transformResults,
1227  TransformState &state);
1228 
1229  /// Checks that the op matches the expectations of this trait.
1230  static LogicalResult verifyTrait(Operation *op);
1231 };
1232 
1233 /// Side effect resource corresponding to the mapping between Transform IR
1234 /// values and Payload IR operations. An Allocate effect from this resource
1235 /// means creating a new mapping entry; it is always accompanied by a Write
1236 /// effect. A Read effect from this resource means accessing the mapping. A Free
1237 /// effect on this resource indicates the removal of the mapping entry,
1238 /// typically after a transformation that modifies the Payload IR operations
1239 /// associated with one of the Transform IR operation's operands. It is always
1240 /// accompanied by a Read effect. Read-after-Free and double-Free are not
1241 /// allowed (they would be problematic with "regular" memory effects too) as
1242 /// they indicate an attempt to access Payload IR operations that have been
1243 /// modified, potentially erased, by the previous transformations.
1244 // TODO: consider custom effects if these are not enabling generic passes such
1245 // as CSE/DCE to work.
1247  : public SideEffects::Resource::Base<TransformMappingResource> {
1248  StringRef getName() override { return "transform.mapping"; }
1249 };
1250 
1251 /// Side effect resource corresponding to the Payload IR itself. Only Read and
1252 /// Write effects are expected on this resource, with Write always accompanied
1253 /// by a Read (short of fully replacing the top-level Payload IR operation, one
1254 /// cannot modify the Payload IR without reading it first). This is intended
1255 /// to disallow reordering of Transform IR operations that mutate the Payload IR
1256 /// while still allowing the reordering of those that only read it.
1258  : public SideEffects::Resource::Base<PayloadIRResource> {
1259  StringRef getName() override { return "transform.payload_ir"; }
1260 };
1261 
1262 /// Populates `effects` with the memory effects indicating the operation on the
1263 /// given handle values (`consumesHandle`, `producesHandle`, `onlyReadsHandle`):
1264 /// - consumes = Read + Free,
1265 /// - produces = Allocate + Write,
1266 /// - onlyReads = Read.
1269 void producesHandle(ResultRange handles,
1275 
1276 /// Checks whether the transform op consumes the given handle.
1277 bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
1278 
1279 /// Populates `effects` with the memory effects indicating access to the
1280 /// payload IR resource (`modifiesPayload`, `onlyReadsPayload`).
1283 
1284 /// Checks whether the transform op modifies the payload.
1285 bool doesModifyPayload(transform::TransformOpInterface transform);
1286 /// Checks whether the transform op reads the payload.
1287 bool doesReadPayload(transform::TransformOpInterface transform);
1288 
1289 /// Populates `consumedArguments` with positions of `block` arguments that are
1290 /// consumed by the operations in the `block`.
1292  Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1293 
1294 /// Trait implementing the MemoryEffectOpInterface for operations that "consume"
1295 /// their operands and produce new results.
1296 template <typename OpTy>
1298  : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
1299 public:
1300  /// This op "consumes" the operands by reading and freeing them, "produces"
1301  /// the results by allocating and writing them, and reads/writes the payload
1302  /// IR in the process.
1304  consumesHandle(this->getOperation()->getOpOperands(), effects);
1305  producesHandle(this->getOperation()->getOpResults(), effects);
1306  modifiesPayload(effects);
1307  }
1308 
1309  /// Checks that the op matches the expectations of this trait.
1310  static LogicalResult verifyTrait(Operation *op) {
1311  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1312  op->emitError()
1313  << "FunctionalStyleTransformOpTrait should only be attached to ops "
1314  "that implement MemoryEffectOpInterface";
1315  }
1316  return success();
1317  }
1318 };
1319 
1320 /// Trait implementing the MemoryEffectOpInterface for operations that use their
1321 /// operands without consuming them and without modifying the Payload IR to
1322 /// potentially produce new handles.
1323 template <typename OpTy>
1325  : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
1326 public:
1327  /// Produces handles without consuming the original ones or modifying the
1328  /// IR; the payload is marked as read only if an operand is an op/value handle.
1330  onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
1331  producesHandle(this->getOperation()->getOpResults(), effects);
1332  if (llvm::any_of(this->getOperation()->getOperandTypes(), [](Type t) {
1333  return isa<TransformHandleTypeInterface,
1334  TransformValueHandleTypeInterface>(t);
1335  })) {
1336  onlyReadsPayload(effects);
1337  }
1338  }
1339 
1340  /// Checks that the op matches the expectation of this trait.
1341  static LogicalResult verifyTrait(Operation *op) {
1342  if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1343  op->emitError() << "NavigationTransformOpTrait should only be attached "
1344  "to ops that implement MemoryEffectOpInterface";
1345  }
1346  return success();
1347  }
1348 };
1349 
1350 namespace detail {
1351 /// Non-template implementation of ParamProducerTransformOpTrait::getEffects().
1354 /// Non-template implementation of ParamProducerTransformOpTrait::verifyTrait().
1356 } // namespace detail
1357 
1358 /// Trait implementing the MemoryEffectOpInterface for operations that produce
1359 /// transform dialect parameters. It marks all op results of
1360 /// TransformHandleTypeInterface as produced by the op, all operands as only
1361 /// read by the op and, if at least one of the operands is a handle to payload
1362 /// ops, the entire payload as potentially read. The op must only produce
1363 /// parameter-typed results.
1364 template <typename OpTy>
1366  : public OpTrait::TraitBase<OpTy, ParamProducerTransformOpTrait> {
1367 public:
1368  /// Populates `effects` with effect instances described in the trait
1369  /// documentation.
1372  effects);
1373  }
1374 
1375  /// Checks that the op matches the expectation of this trait, i.e., that it
1376  /// implements the MemoryEffectOpInterface and only produces parameter-typed
1377  /// results.
1378  static LogicalResult verifyTrait(Operation *op) {
1380  }
1381 };
1382 
1383 /// `TrackingListener` failures are reported only for ops that have this trait.
1384 /// The purpose of this trait is to give users more time to update their custom
1385 /// transform ops to use the provided `TransformRewriter` for all IR
1386 /// modifications. This trait will eventually be removed, and failures will be
1387 /// reported for all transform ops.
/// This is a pure marker trait: it declares no members and no verifier.
1388 template <typename OpTy>
1390  : public OpTrait::TraitBase<OpTy, ReportTrackingListenerFailuresOpTrait> {};
1391 
1392 /// A single result of applying a transform op with `TransformEachOpTrait`
1393 /// to a single payload operation.
1395 
1396 /// A list of results of applying a transform op with `TransformEachOpTrait`
1397 /// to a single payload operation, co-indexed with the results of the
/// transform op. Entries may be payload ops, values or parameter attributes.
1399 public:
  /// Creates a list pre-sized to `size` entries.
1401  explicit ApplyToEachResultList(unsigned size) : results(size) {}
1402 
1403  /// Sets the list of results to `size` null pointers.
1404  void assign(unsigned size, std::nullptr_t) { results.assign(size, nullptr); }
1405 
1406  /// Sets the list of results to the given range of values.
1407  template <typename Range>
1408  void assign(Range &&range) {
1409  // This is roughly the implementation of SmallVectorImpl::assign.
1410  // Dispatching to it with map_range and template type inference would result
1411  // in more complex code here.
1412  results.clear();
1413  results.reserve(llvm::size(range));
1414  for (auto element : range) {
1415  if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1416  Operation *>) {
1417  results.push_back(static_cast<Operation *>(element));
1418  } else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1419  Value>) {
1420  results.push_back(element.template get<Value>());
1421  } else {
1422  results.push_back(static_cast<Attribute>(element));
1423  }
1424  }
1425  }
1426 
1427  /// Appends an element to the list.
1428  // Using ApplyToEachResult that can be implicitly constructed from a Value but
1429  // not from a concrete Op that is implicitly convertible to a Value to avoid
1430  // ambiguity.
1431  void push_back(Operation *op) { results.push_back(op); }
1432  void push_back(Attribute attr) { results.push_back(attr); }
1433  void push_back(ApplyToEachResult r) { results.push_back(r); }
1434 
1435  /// Reserves space for `size` elements in the list.
1436  void reserve(unsigned size) { results.reserve(size); }
1437 
1438  /// Iterators over the list.
1439  auto begin() { return results.begin(); }
1440  auto end() { return results.end(); }
1441  auto begin() const { return results.begin(); }
1442  auto end() const { return results.end(); }
1443 
1444  /// Returns the number of elements in the list.
1445  size_t size() const { return results.size(); }
1446 
1447  /// Element access. Expects the index to be in bounds.
1448  ApplyToEachResult &operator[](size_t index) { return results[index]; }
1449  const ApplyToEachResult &operator[](size_t index) const {
1450  return results[index];
1451  }
1452 
1453 private:
1454  /// Underlying storage.
1456 };
1457 
1458 namespace detail {
1459 
1460 /// Check that the contents of `partialResult` match the number, kind (payload
1461 /// op or parameter) and nullity (either all or none) requirements of
1462 /// `transformOp`. Report errors and return failure otherwise.
1463 LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc,
1464  const ApplyToEachResultList &partialResult);
1465 
1466 /// "Transpose" the results produced by individual applications, arranging them
1467 /// per result value of the transform op, and populate `transformResults` with
1468 /// that. The number, kind and nullity of per-application results are assumed to
1469 /// have been verified.
1470 void setApplyToOneResults(Operation *transformOp,
1471  TransformResults &transformResults,
1473 
1474 /// Applies a one-to-one or a one-to-many transform to each of the given
1475 /// targets. Puts the results of each application, if any, in `results` in
1476 /// the same order as `targets`. A definite failure of any application is
1477 /// propagated; silenceable failures (including targets of the wrong op kind)
1478 /// are accumulated while the remaining targets are still processed.
1479 /// `TransformOpTy::applyToOne` must be callable as:
1480 /// - DiagnosedSilenceableFailure(TransformRewriter &rewriter, OpTy target,
1481 ///   ApplyToEachResultList &results, TransformState &state)
1482 /// where OpTy is either
1483 /// - Operation *, in which case the transform is always applied;
1484 /// - a concrete Op class, in which case a check is performed whether
/// `targets` contains operations of the same class and a silenceable failure
/// is reported if it does not.
1485 template <typename TransformOpTy, typename Range>
1487  TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets,
1489  using OpTy = typename llvm::function_traits<
1490  decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1491  static_assert(std::is_convertible<OpTy, Operation *>::value,
1492  "expected transform function to take an operation");
  // Restore the rewriter's insertion point on exit; it is moved to each
  // target below.
1493  OpBuilder::InsertionGuard g(rewriter);
1494 
1495  SmallVector<Diagnostic> silenceableStack;
1496  unsigned expectedNumResults = transformOp->getNumResults();
1497  for (Operation *target : targets) {
1498  auto specificOp = dyn_cast<OpTy>(target);
1499  if (!specificOp) {
1500  Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
1501  diag << "transform applied to the wrong op kind";
1502  diag.attachNote(target->getLoc()) << "when applied to this op";
1503  silenceableStack.push_back(std::move(diag));
1504  continue;
1505  }
1506 
1507  ApplyToEachResultList partialResults;
1508  partialResults.reserve(expectedNumResults);
1509  Location specificOpLoc = specificOp->getLoc();
1510  rewriter.setInsertionPoint(specificOp);
1512  transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1513  if (res.isDefiniteFailure())
1515 
  // Silenceable failures are collected; the remaining targets are still
  // processed.
1516  if (res.isSilenceableFailure()) {
1517  res.takeDiagnostics(silenceableStack);
1518  continue;
1519  }
1520 
  // Validate the number, kind and nullity of this target's results before
  // recording them.
1521  if (failed(detail::checkApplyToOne(transformOp, specificOpLoc,
1522  partialResults))) {
1524  }
1525  results.push_back(std::move(partialResults));
1526  }
1527  if (!silenceableStack.empty()) {
1529  std::move(silenceableStack));
1530  }
1532 }
1533 
1534 /// Reports an error and returns failure if `targets` contains an ancestor
1535 /// operation before its descendant (or a copy of itself). Implementation detail
1536 /// for expensive checks during `TransformEachOpTrait::apply`.
1537 LogicalResult checkNestedConsumption(Location loc,
1538  ArrayRef<Operation *> targets);
1539 
1540 } // namespace detail
1541 } // namespace transform
1542 } // namespace mlir
1543 
/// Calls applyToOne for every payload operation associated with the single
/// operand of this transform IR op and sets the transform op's results from
/// the per-target results.
1544 template <typename OpTy>
1547  TransformRewriter &rewriter, TransformResults &transformResults,
1548  TransformState &state) {
1549  Value handle = this->getOperation()->getOperand(0);
1550  auto targets = state.getPayloadOps(handle);
1551 
1552  // If the operand is consumed, check if it is associated with operations that
1553  // may be erased before their nested operations are.
1554  if (state.getOptions().getExpensiveChecksEnabled() &&
1555  isHandleConsumed(handle, cast<transform::TransformOpInterface>(
1556  this->getOperation())) &&
1557  failed(detail::checkNestedConsumption(this->getOperation()->getLoc(),
1558  llvm::to_vector(targets)))) {
1559  return DiagnosedSilenceableFailure::definiteFailure();
1560  }
1561 
1562  // Step 1. Handle the corner case where no target is specified.
1563  // This is typically the case when the matcher fails to apply and we need to
1564  // propagate gracefully.
1565  // In this case, each result is set to an empty list of the kind its type expects.
1566  if (std::empty(targets)) {
1567  SmallVector<Operation *> emptyPayload;
1568  SmallVector<Attribute> emptyParams;
1569  for (OpResult r : this->getOperation()->getResults()) {
1570  if (isa<TransformParamTypeInterface>(r.getType()))
1571  transformResults.setParams(r, emptyParams);
1572  else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1573  transformResults.setValues(r, ValueRange());
1574  else
1575  transformResults.set(r, emptyPayload);
1576  }
1577  return DiagnosedSilenceableFailure::success();
1578  }
1579 
1580  // Step 2. Call applyToOne on each target and record newly produced ops in its
1581  // corresponding results entry.
1584  cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1585 
1586  // Step 3. Propagate the definite failure if any and bail out.
1587  if (result.isDefiniteFailure())
1588  return result;
1589 
1590  // Step 4. "Transpose" the results produced by individual applications,
1591  // arranging them per result value of the transform op. The number, kind and
1592  // nullity of per-application results have been verified per target (see
1593  // detail::checkApplyToOne).
1594  detail::setApplyToOneResults(this->getOperation(), transformResults, results);
1595 
1596  // Step 5. ApplyToOne may have returned silenceableFailure, propagate it.
1597  return result;
1598 }
1599 
/// Checks the expectations of TransformEachOpTrait: at compile time, that the
/// op has exactly one operand (apply() reads operand #0 as its target handle);
/// during verification, that the op implements TransformOpInterface.
1600 template <typename OpTy>
1601 llvm::LogicalResult
1603  static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1604  "expected single-operand op");
1605  if (!op->getName().getInterface<TransformOpInterface>()) {
1606  return op->emitError() << "TransformEachOpTrait should only be attached to "
1607  "ops that implement TransformOpInterface";
1608  }
1609 
1610  return success();
1611 }
1612 
1613 #endif // MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Helper class for implementing traits.
Definition: OpDefinition.h:373
Operation * getOperation()
Return the ultimate Operation being worked on.
Definition: OpDefinition.h:376
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
A 2D array where each row may have different length.
Definition: RaggedArray.h:18
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition: Region.cpp:45
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
Block & front()
Definition: Region.h:65
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
This base class is used for derived effects that are non-parametric.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
A list of results of applying a transform op with TransformEachOpTrait to a single payload operation,...
auto begin()
Iterators over the list.
const ApplyToEachResult & operator[](size_t index) const
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
ApplyToEachResult & operator[](size_t index)
Element access. Expects the index to be in bounds.
void push_back(Operation *op)
Appends an element to the list.
void assign(Range &&range)
Sets the list of results to the given range of values.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
Trait implementing the MemoryEffectOpInterface for operations that "consume" their operands and produ...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op "consumes" the operands by reading and freeing them, "produces" the results by allocating and...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
Trait implementing the MemoryEffectOpInterface for operations that use their operands without consumi...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op produces handles to the Payload IR without consuming the original handles and without modifyi...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait.
Trait implementing the MemoryEffectOpInterface for operations that produce transform dialect paramet...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait, i.e., that it implements the MemoryEffectsO...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with effect instances described in the trait documentation.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
LogicalResult mapBlockArguments(TransformState &state)
static LogicalResult verifyTrait(Operation *op)
Verifies that op satisfies the invariants of this trait.
void getPotentialTopLevelEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by this trait.
LogicalResult mapBlockArguments(TransformState &state, Region &region)
Sets up the mapping between the entry block of the given region of this op and the relevant list of P...
Block * getBodyBlock(unsigned region=0)
Returns the single block of the given region.
TrackingListener failures are reported only for ops that have this trait.
A listener that updates a TransformState based on IR modifications.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TransformOpInterface getTransformOp() const
Return the transform op in which this TrackingListener is used.
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag)
This function is called when a tracked payload op is dropped because no replacement op was found.
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Trait implementing the TransformOpInterface for operations applying a transformation to a single oper...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state)
Calls applyToOne for every payload operation associated with the operand of this transform IR op,...
Options controlling the application of transform operations by the TransformState.
TransformOptions & enableExpensiveChecks(bool enable=true)
Requests computationally expensive checks of the transform and payload IR well-formedness to be perfo...
TransformOptions & enableEnforceSingleToplevelTransformOp(bool enable=true)
TransformOptions & operator=(const TransformOptions &)=default
TransformOptions(const TransformOptions &)=default
bool getExpensiveChecksEnabled() const
Returns true if the expensive checks are requested.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, std::initializer_list< Operation * > ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setValues(OpResult handle, std::initializer_list< Value > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
Base class for TransformState extensions that allow TransformState to contain user-specified informat...
Extension(TransformState &state)
Constructs an extension of the given TransformState object.
const TransformState & getTransformState() const
Provides read-only access to the parent TransformState object.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
const TransformOptions & getOptions() const
friend LogicalResult applyTransforms(Operation *, TransformOpInterface, const RaggedArray< MappedValue > &, const TransformOptions &, bool, function_ref< void(TransformState &)>, function_ref< LogicalResult(TransformState &)>)
Entry point to the Transform dialect infrastructure.
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
size_t getNumTopLevelMappings() const
Returns the number of extra mappings for the top-level operation.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
Ty * getExtension()
Returns the extension of the specified type.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void removeExtension()
Removes the extension of the specified type.
ArrayRef< MappedValue > getTopLevelMapping(size_t position) const
Returns the position-th extra mapping for the top-level operation.
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the t...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant ...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult match the number, kind (payload op or parameter) and nulli...
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verify().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
DiagnosedSilenceableFailure applyTransformToEach(TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, SmallVectorImpl< ApplyToEachResultList > &results, TransformState &state)
Applies a one-to-one or a one-to-many transform to each of the given targets.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Side effect resource corresponding to the Payload IR itself.
StringRef getName() override
Return a string name of the resource.
A configuration object for customizing a TrackingListener.
bool requireMatchingReplacementOpName
If set to "true", the name of a replacement op must match the name of the original op.
bool skipCastOps
If set to "true", cast ops (that implement the CastOpInterface) are skipped and the replacement op se...
SkipHandleFn skipHandleFn
An optional function that returns "true" for handles that do not have to be updated.
std::function< bool(Value)> SkipHandleFn
Side effect resource corresponding to the mapping between Transform IR values and Payload IR operatio...
StringRef getName() override
Return a string name of the resource.