MLIR 23.0.0git
TransformInterfaces.h
Go to the documentation of this file.
1//===- TransformInterfaces.h - Transform Dialect Interfaces -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10#define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
11
18
19#include "mlir/Dialect/Transform/Interfaces/TransformAttrInterfaces.h.inc"
20#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
21
22namespace mlir {
23namespace transform {
24
25class TransformOpInterface;
28class TransformState;
29
32
33namespace detail {
34/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
35/// to either the list of operations associated with its operand or the root of
36/// the payload IR, depending on what is available in the context.
37LogicalResult
39 Operation *op, Region &region);
40
41/// Verification hook for PossibleTopLevelTransformOpTrait.
43
44/// Populates `effects` with side effects implied by
45/// PossibleTopLevelTransformOpTrait for the given operation. The operation may
46/// have an optional `root` operand, indicating it is not in fact top-level. It
47/// is also expected to have a single-block body.
49 Operation *operation, Value root, Block &body,
51
52/// Verification hook for TransformOpInterface.
53LogicalResult verifyTransformOpInterface(Operation *op);
54
55/// Appends the entities associated with the given transform values in `state`
56/// to the pre-existing list of mappings. The array of mappings must have as
57/// many elements as values. If `flatten` is set, multiple values may be
58/// associated with each transform value, and this always succeeds. Otherwise,
59/// checks that each value has exactly one mapping associated and returns
60/// failure otherwise.
61LogicalResult appendValueMappings(
63 ValueRange values, const transform::TransformState &state,
64 bool flatten = true);
65
66/// Populates `mappings` with mapped values associated with the given transform
67/// IR values in the given `state`.
70 ValueRange values, const transform::TransformState &state);
71
72/// Populates `results` with payload associations that match exactly those of
73/// the operands to `block`'s terminator.
76
77/// Make a dummy transform state for testing purposes. This MUST NOT be used
78/// outside of test cases.
80 Operation *payloadRoot);
81
82/// Returns all operands that are handles and being consumed by the given op.
84getConsumedHandleOpOperands(transform::TransformOpInterface transformOp);
85
86/// Checks that the given payload operations satisfy the normal form
87/// constraints, reports the first encountered definite failure or the first
88/// encountered silenceable failure if there were no definite failures. Normal
89/// forms are checked in order, so trailing normal forms may assume earlier
90/// normal forms did not produce definite failures.
94
95/// Checks that the given list does not contain duplicate normal forms of the
96/// same class.
97llvm::LogicalResult verifyNormalFormList(
100} // namespace detail
101} // namespace transform
102} // namespace mlir
103
104#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
105
106namespace mlir {
107namespace transform {
108
109/// Options controlling the application of transform operations by the
110/// TransformState.
112public:
113 TransformOptions() = default;
116
117 /// Requests computationally expensive checks of the transform and payload IR
118 /// well-formedness to be performed before each transformation. In particular,
119 /// these ensure that the handles still point to valid operations when used.
121 expensiveChecksEnabled = enable;
122 return *this;
123 }
124
125 // Ensures that only a single top-level transform op is present in the IR.
127 enforceSingleToplevelTransformOp = enable;
128 return *this;
129 }
130
131 /// Returns true if the expensive checks are requested.
132 bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
133
134 // Returns true if enforcing a single top-level transform op is requested.
136 return enforceSingleToplevelTransformOp;
137 }
138
139private:
140 bool expensiveChecksEnabled = true;
141 bool enforceSingleToplevelTransformOp = true;
142};
143
144/// Entry point to the Transform dialect infrastructure. Applies the
145/// transformation specified by `transform` to payload IR contained in
146/// `payloadRoot`. The `transform` operation may contain other operations that
147/// will be executed following the internal logic of the operation. It must
148/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
149/// This function internally keeps track of the transformation state.
150LogicalResult applyTransforms(
151 Operation *payloadRoot, TransformOpInterface transform,
152 const RaggedArray<MappedValue> &extraMapping = {},
153 const TransformOptions &options = TransformOptions(),
154 bool enforceToplevelTransformOp = true,
155 function_ref<void(TransformState &)> stateInitializer = nullptr,
156 function_ref<LogicalResult(TransformState &)> stateExporter = nullptr);
157
158/// The state maintained across applications of various ops implementing the
159/// TransformOpInterface. The operations implementing this interface and the
160/// surrounding structure are referred to as transform IR. The operations to
161/// which transformations apply are referred to as payload IR. Transform IR
162/// operates on values that can be associated either with a list of payload IR
163/// operations (such values are referred to as handles) or with a list of
164/// parameters represented as attributes. The state thus contains the mapping
165/// between values defined in the transform IR ops and either payload IR ops or
166/// parameters. For payload ops, the mapping is many-to-many and the reverse
167/// mapping is also stored. The "expensive-checks" option can be passed to the
168 /// constructor to check, at transformation execution time, that transform IR values
169 /// used as operands by a transform IR operation are not associated with dangling
170 /// pointers to payload IR operations that are known to have been erased by
171 /// previous transformations through the same or a different transform IR value.
172///
173/// A reference to this class is passed as an argument to "apply" methods of the
174/// transform op interface. Thus the "apply" method can call either
175/// `state.getPayloadOps( getSomeOperand() )` to obtain the list of operations
176/// or `state.getParams( getSomeOperand() )` to obtain the list of parameters
177/// associated with its operand. The method is expected to populate the
178/// `TransformResults` class instance in order to update the mapping. The
179/// `applyTransform` method takes care of propagating the state of
180/// `TransformResults` into the instance of this class.
181///
182/// When applying transform IR operations with regions, the client is expected
183/// to create a `RegionScope` RAII object to create a new "stack frame" for
184/// values defined inside the region. The mappings from and to these values will
185/// be automatically dropped when the object goes out of scope, typically at the
186/// end of the `apply` function of the parent operation. If a region contains
187/// blocks with arguments, the client can map those arguments to payload IR ops
188/// using `mapBlockArguments`.
189class TransformState {
190public:
192
193private:
194 /// Mapping between a Value in the transform IR and the corresponding set of
195 /// operations in the payload IR.
196 using TransformOpMapping = DenseMap<Value, SmallVector<Operation *, 2>>;
197
198 /// Mapping between a payload IR operation and the transform IR values it is
199 /// associated with.
200 using TransformOpReverseMapping =
202
203 /// Mapping between a Value in the transform IR and the corresponding list of
204 /// parameters.
205 using ParamMapping = DenseMap<Value, SmallVector<Param>>;
206
207 /// Mapping between a Value in the transform IR and the corresponding list of
208 /// values in the payload IR. Also works for reverse mappings.
209 using ValueMapping = DenseMap<Value, SmallVector<Value>>;
210
211 /// Mapping between a Value in the transform IR and an error message that
212 /// should be emitted when the value is used.
213 using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
214
215#if LLVM_ENABLE_ABI_BREAKING_CHECKS
216 /// Debug only: A timestamp is associated with each transform IR value, so
217 /// that invalid iterator usage can be detected more reliably.
218 using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
219#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
220
221 /// The bidirectional mappings between transform IR values and payload IR
222 /// operations, and the mapping between transform IR values and parameters.
223 struct Mappings {
224 TransformOpMapping direct;
225 TransformOpReverseMapping reverse;
226 ParamMapping params;
227 ValueMapping values;
228 ValueMapping reverseValues;
229
230#if LLVM_ENABLE_ABI_BREAKING_CHECKS
231 TransformIRTimestampMapping timestamps;
232 void incrementTimestamp(Value value) { ++timestamps[value]; }
233#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
234 };
235
236 friend LogicalResult
237 applyTransforms(Operation *, TransformOpInterface,
239 bool, function_ref<void(TransformState &)>,
240 function_ref<LogicalResult(TransformState &)>);
241
242 friend TransformState
244
245public:
246 const TransformOptions &getOptions() const { return options; }
247
248 /// Returns the op at which the transformation state is rooted. This is
249 /// typically helpful for transformations that apply globally.
250 Operation *getTopLevel() const;
251
252 /// Returns the number of extra mappings for the top-level operation.
253 size_t getNumTopLevelMappings() const { return topLevelMappedValues.size(); }
254
255 /// Returns the position-th extra mapping for the top-level operation.
257 return topLevelMappedValues[position];
258 }
259
260 /// Returns an iterator that enumerates all ops that the given transform IR
261 /// value corresponds to. Ops may be erased while iterating; erased ops are
262 /// not enumerated. This function is helpful for transformations that apply to
263 /// a particular handle.
264 auto getPayloadOps(Value value) const {
265 ArrayRef<Operation *> view = getPayloadOpsView(value);
266
267#if LLVM_ENABLE_ABI_BREAKING_CHECKS
268 // Memorize the current timestamp and make sure that it has not changed
269 // when incrementing or dereferencing the iterator returned by this
270 // function. The timestamp is incremented when the "direct" mapping is
271 // resized; this would invalidate the iterator returned by this function.
272 int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
273#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
274
275 // When ops are replaced/erased, they are replaced with nullptr (until
276 // the data structure is compacted). Do not enumerate these ops.
277 return llvm::make_filter_range(view, [=](Operation *op) {
278#if LLVM_ENABLE_ABI_BREAKING_CHECKS
279 [[maybe_unused]] bool sameTimestamp =
280 currentTimestamp == this->getMapping(value).timestamps.lookup(value);
281 assert(sameTimestamp && "iterator was invalidated during iteration");
282#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
283 return op != nullptr;
284 });
285 }
286
287 /// Returns the list of parameters that the given transform IR value
288 /// corresponds to.
290
291 /// Returns an iterator that enumerates all payload IR values that the given
292 /// transform IR value corresponds to.
293 auto getPayloadValues(Value handleValue) const {
294 ArrayRef<Value> view = getPayloadValuesView(handleValue);
295
296#if LLVM_ENABLE_ABI_BREAKING_CHECKS
297 // Memorize the current timestamp and make sure that it has not changed
298 // when incrementing or dereferencing the iterator returned by this
299 // function. The timestamp is incremented when the "values" mapping is
300 // resized; this would invalidate the iterator returned by this function.
301 int64_t currentTimestamp =
302 getMapping(handleValue).timestamps.lookup(handleValue);
303 return llvm::make_filter_range(view, [=](Value v) {
304 [[maybe_unused]] bool sameTimestamp =
305 currentTimestamp ==
306 this->getMapping(handleValue).timestamps.lookup(handleValue);
307 assert(sameTimestamp && "iterator was invalidated during iteration");
308 return true;
309 });
310#else
311 return llvm::make_range(view.begin(), view.end());
312#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
313 }
314
315 /// Populates `handles` with all handles pointing to the given Payload IR op.
316 /// Returns success if such handles exist, failure otherwise.
317 /// If `includeOutOfScope` is set to "true", handles that are defined in
318 /// regions beyond the most recent isolated from above region are included.
319 LogicalResult getHandlesForPayloadOp(Operation *op,
320 SmallVectorImpl<Value> &handles,
321 bool includeOutOfScope = false) const;
322
323 /// Populates `handles` with all handles pointing to the given payload IR
324 /// value. Returns success if such handles exist, failure otherwise.
325 /// If `includeOutOfScope` is set to "true", handles that are defined in
326 /// regions beyond the most recent isolated from above region are included.
327 LogicalResult getHandlesForPayloadValue(Value payloadValue,
328 SmallVectorImpl<Value> &handles,
329 bool includeOutOfScope = false) const;
330
331 /// Applies the transformation specified by the given transform op and updates
332 /// the state accordingly.
334
335 /// Records the mapping between a block argument in the transform IR and a
336 /// list of operations in the payload IR. The arguments must be defined in
337 /// blocks of the currently processed transform IR region, typically after a
338 /// region scope is defined.
339 ///
340 /// Returns failure if the payload does not satisfy the conditions associated
341 /// with the type of the handle value.
342 LogicalResult mapBlockArguments(BlockArgument argument,
343 ArrayRef<Operation *> operations) {
344 assert(argument.getParentRegion() == regionStack.back()->region &&
345 "mapping block arguments from a region other than the active one");
346 return setPayloadOps(argument, operations);
347 }
348 LogicalResult mapBlockArgument(BlockArgument argument,
349 ArrayRef<MappedValue> values);
350 LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
352
353 // Forward declarations to support limited visibility.
354 class RegionScope;
355
356 /// Creates a new region scope for the given region. The region is expected to
357 /// be nested in the currently processed region.
358 // Implementation note: this method is inline but implemented outside of the
359 // class body to comply with visibility and full-declaration requirements.
360 inline RegionScope make_region_scope(Region &region);
361
362 /// A RAII object maintaining a "stack frame" for a transform IR region. When
363 /// applying a transform IR operation that contains a region, the caller is
364 /// expected to create a RegionScope before applying the ops contained in the
365 /// region. This ensures that the mappings between values defined in the
366 /// transform IR region and payload IR operations are cleared when the region
367 /// processing ends; such values cannot be accessed outside the region.
368 class RegionScope {
369 public:
370 /// Forgets the mapping from or to values defined in the associated
371 /// transform IR region, and restores the mapping that existed before
372 /// entering this scope.
373 ~RegionScope();
374
375 private:
376 /// Creates a new scope for mappings between values defined in the given
377 /// transform IR region and payload IR objects.
378 RegionScope(TransformState &state, Region &region)
379 : state(state), region(&region) {
380 auto res = state.mappings.insert(
381 std::make_pair(&region, std::make_unique<Mappings>()));
382 assert(res.second && "the region scope is already present");
383 (void)res;
384 state.regionStack.push_back(this);
385 }
386
387 /// Back-reference to the transform state.
388 TransformState &state;
389
390 /// The region this scope is associated with.
391 Region *region;
392
393 /// The transform op within this region that is currently being applied.
394 TransformOpInterface currentTransform;
395
397 };
398 friend class RegionScope;
399
400 /// Base class for TransformState extensions that allow TransformState to
401 /// contain user-specified information in the state object. Clients are
402 /// expected to derive this class, add the desired fields, and make the
403 /// derived class compatible with the MLIR TypeID mechanism:
404 ///
405 /// ```mlir
406 /// class MyExtension final : public TransformState::Extension {
407 /// public:
408 /// MyExtension(TranfsormState &state, int myData)
409 /// : Extension(state) {...}
410 /// private:
411 /// int mySupplementaryData;
412 /// };
413 /// ```
414 ///
415 /// Instances of this and derived classes are not expected to be created by
416 /// the user, instead they are directly constructed within a TransformState. A
417 /// TransformState can only contain one extension with the given TypeID.
418 /// Extensions can be obtained from a TransformState instance, and can be
419 /// removed when they are no longer required.
420 ///
421 /// ```mlir
422 /// transformState.addExtension<MyExtension>(/*myData=*/42);
423 /// MyExtension *ext = transformState.getExtension<MyExtension>();
424 /// ext->doSomething();
425 /// ```
426 class Extension {
427 // Allow TransformState to allocate Extensions.
428 friend class TransformState;
429
430 public:
431 /// Base virtual destructor.
432 // Out-of-line definition ensures symbols are emitted in a single object
433 // file.
434 virtual ~Extension();
435
436 protected:
437 /// Constructs an extension of the given TransformState object.
438 Extension(TransformState &state) : state(state) {}
439
440 /// Provides read-only access to the parent TransformState object.
441 const TransformState &getTransformState() const { return state; }
442
443 /// Replaces the given payload op with another op. If the replacement op is
444 /// null, removes the association of the payload op with its handle. Returns
445 /// failure if the op is not associated with any handle.
446 ///
447 /// Note: This function does not update value handles. None of the original
448 /// op's results are allowed to be mapped to any value handle.
449 LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
450
451 /// Replaces the given payload value with another value. If the replacement
452 /// value is null, removes the association of the payload value with its
453 /// handle. Returns failure if the value is not associated with any handle.
454 LogicalResult replacePayloadValue(Value value, Value replacement);
455
456 private:
457 /// Back-reference to the state that is being extended.
458 TransformState &state;
459 };
460
461 /// Adds a new Extension of the type specified as template parameter,
462 /// constructing it with the arguments provided. The extension is owned by the
463 /// TransformState. It is expected that the state does not already have an
464 /// extension of the same type. Extension constructors are expected to take
465 /// a reference to TransformState as first argument, automatically supplied
466 /// by this call.
467 template <typename Ty, typename... Args>
468 Ty &addExtension(Args &&...args) {
469 static_assert(
470 std::is_base_of<Extension, Ty>::value,
471 "only an class derived from TransformState::Extension is allowed here");
472 auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
473 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
474 assert(result.second && "extension already added");
475 return *static_cast<Ty *>(result.first->second.get());
476 }
477
478 /// Returns the extension of the specified type.
479 template <typename Ty>
481 static_assert(
482 std::is_base_of<Extension, Ty>::value,
483 "only an class derived from TransformState::Extension is allowed here");
484 auto iter = extensions.find(TypeID::get<Ty>());
485 if (iter == extensions.end())
486 return nullptr;
487 return static_cast<Ty *>(iter->second.get());
488 }
489
490 /// Removes the extension of the specified type.
491 template <typename Ty>
493 static_assert(
494 std::is_base_of<Extension, Ty>::value,
495 "only an class derived from TransformState::Extension is allowed here");
496 extensions.erase(TypeID::get<Ty>());
497 }
498
499private:
500 /// Identifier for storing top-level value in the `operations` mapping.
501 static constexpr Value kTopLevelValue = Value();
502
503 /// Creates a state for transform ops living in the given region. The second
504 /// argument points to the root operation in the payload IR being transformed,
505 /// which may or may not contain the region with transform ops. Additional
506 /// options can be provided through the trailing configuration object.
507 TransformState(Region *region, Operation *payloadRoot,
508 const RaggedArray<MappedValue> &extraMappings = {},
509 const TransformOptions &options = TransformOptions());
510
511 /// Returns the mappings frame for the region in which the value is defined.
512 /// If `allowOutOfScope` is set to "false", asserts that the value is in
513 /// scope, based on the current stack of frames.
514 const Mappings &getMapping(Value value, bool allowOutOfScope = false) const {
515 return const_cast<TransformState *>(this)->getMapping(value,
516 allowOutOfScope);
517 }
518 Mappings &getMapping(Value value, bool allowOutOfScope = false) {
519 Region *region = value.getParentRegion();
520 auto it = mappings.find(region);
521 assert(it != mappings.end() &&
522 "trying to find a mapping for a value from an unmapped region");
523#ifndef NDEBUG
524 if (!allowOutOfScope) {
525 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
526 if (r == region)
527 break;
528 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
529 llvm_unreachable("trying to get mapping beyond region that is "
530 "isolated from above");
531 }
532 }
533#endif // NDEBUG
534 return *it->second;
535 }
536
537 /// Returns the mappings frame for the region in which the operation resides.
538 /// If `allowOutOfScope` is set to "false", asserts that the operation is in
539 /// scope, based on the current stack of frames.
540 const Mappings &getMapping(Operation *operation,
541 bool allowOutOfScope = false) const {
542 return const_cast<TransformState *>(this)->getMapping(operation,
543 allowOutOfScope);
544 }
545 Mappings &getMapping(Operation *operation, bool allowOutOfScope = false) {
546 Region *region = operation->getParentRegion();
547 auto it = mappings.find(region);
548 assert(it != mappings.end() &&
549 "trying to find a mapping for an operation from an unmapped region");
550#ifndef NDEBUG
551 if (!allowOutOfScope) {
552 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
553 if (r == region)
554 break;
555 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
556 llvm_unreachable("trying to get mapping beyond region that is "
557 "isolated from above");
558 }
559 }
560#endif // NDEBUG
561 return *it->second;
562 }
563
564 /// Updates the state to include the associations between op results and the
565 /// provided result of applying a transform op.
566 LogicalResult updateStateFromResults(const TransformResults &results,
567 ResultRange opResults);
568
569 /// Returns a list of all ops that the given transform IR value corresponds
570 /// to. In case an op was erased, the returned list contains nullptr. This
571 /// function is helpful for transformations that apply to a particular handle.
572 ArrayRef<Operation *> getPayloadOpsView(Value value) const;
573
574 /// Returns a list of payload IR values that the given transform IR value
575 /// corresponds to.
576 ArrayRef<Value> getPayloadValuesView(Value handleValue) const;
577
578 /// Sets the payload IR ops associated with the given transform IR value
579 /// (handle). A payload op may be associated with multiple handles as long as
580 /// at most one of them gets consumed by further transformations.
581 /// For example, a hypothetical "find function by name" may be called twice in
582 /// a row to produce two handles pointing to the same function:
583 ///
584 /// %0 = transform.find_func_by_name { name = "myfunc" }
585 /// %1 = transform.find_func_by_name { name = "myfunc" }
586 ///
587 /// which is valid by itself. However, calling a hypothetical "rewrite and
588 /// rename function" transform on both handles:
589 ///
590 /// transform.rewrite_and_rename %0 { new_name = "func" }
591 /// transform.rewrite_and_rename %1 { new_name = "func" }
592 ///
593 /// is invalid given the transformation "consumes" the handle as expressed
594 /// by side effects. Practically, a transformation consuming a handle means
595 /// that the associated payload operation may no longer exist.
596 ///
597 /// Similarly, operation handles may be invalidated and should not be used
598 /// after a transform that consumed a value handle pointing to a payload value
599 /// defined by the operation as either block argument or op result. For
600 /// example, in the following sequence, the last transform operation rewrites
601 /// the callee to not return a specified result:
602 ///
603 /// %0 = transform.find_call "myfunc"
604 /// %1 = transform.find_results_of_calling "myfunc"
605 /// transform.drop_call_result_from_signature %1[0]
606 ///
607 /// which requires the call operations to be recreated. Therefore, the handle
608 /// %0 becomes associated with a dangling pointer and should not be used.
609 ///
610 /// Returns failure if the payload does not satisfy the conditions associated
611 /// with the type of the handle value. The value is expected to have a type
612 /// implementing TransformHandleTypeInterface.
613 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
614
615 /// Sets the payload IR values associated with the given transform IR value
616 /// (handle). A payload value may be associated with multiple handles as long
617 /// as at most one of them is consumed by further transformations. For
618 /// example, a hypothetical "get results of calls to function with the given
619 /// name" transform may be performed twice in a row producing handles pointing
620 /// to the same values:
621 ///
622 /// %0 = transform.find_results_of_calling "myfunc"
623 /// %1 = transform.find_results_of_calling "myfunc"
624 ///
625 /// which is valid by itself. However, calling a hypothetical "erase value
626 /// producer" transform on both handles:
627 ///
628 /// transform.erase_value_produce %0
629 /// transform.erase_value_produce %1
630 ///
631 /// is invalid provided the transformation "consumes" the handle as expressed
632 /// by side effects (which themselves reflect the semantics of the transform
633 /// erasing the producer and making the handle dangling). Practically, a
634 /// transformation consuming a handle means the associated payload value may
635 /// no longer exist.
636 ///
637 /// Similarly, value handles are invalidated and should not be used after a
638 /// transform that consumed an operation handle pointing to the payload IR
639 /// operation defining the values associated with the value handle, as either
640 /// block arguments or op results, or any ancestor operation. For example,
641 ///
642 /// %0 = transform.find_call "myfunc"
643 /// %1 = transform.find_results_of_calling "myfunc"
644 /// transform.rewrite_and_rename %0 { new_name = "func" }
645 ///
646 /// makes %1 unusable after the last transformation if it consumes %0. When an
647 /// operation handle is consumed, it usually indicates that the operation was
648 /// destroyed or heavily modified, meaning that the values it defines may no
649 /// longer exist.
650 ///
651 /// Returns failure if the payload values do not satisfy the conditions
652 /// associated with the type of the handle value. The value is expected to
653 /// have a type implementing TransformValueHandleTypeInterface.
654 LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
655
656 /// Sets the parameters associated with the given transform IR value. Returns
657 /// failure if the parameters do not satisfy the conditions associated with
658 /// the type of the value. The value is expected to have a type implementing
659 /// TransformParamTypeInterface.
660 LogicalResult setParams(Value value, ArrayRef<Param> params);
661
662 /// Forgets the payload IR ops associated with the given transform IR value,
663 /// as well as any association between value handles and the results of said
664 /// payload IR op.
665 ///
666 /// If `allowOutOfScope` is set to "false", asserts that the handle is in
667 /// scope, based on the current stack of frames.
668 void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
669 bool allowOutOfScope = false);
670
671 void forgetValueMapping(Value valueHandle,
672 ArrayRef<Operation *> payloadOperations);
673
674 /// Replaces the given payload op with another op. If the replacement op is
675 /// null, removes the association of the payload op with its handle. Returns
676 /// failure if the op is not associated with any handle.
677 ///
678 /// Note: This function does not update value handles. None of the original
679 /// op's results are allowed to be mapped to any value handle.
680 LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
681
682 /// Replaces the given payload value with another value. If the replacement
683 /// value is null, removes the association of the payload value with its
684 /// handle. Returns failure if the value is not associated with any handle.
685 LogicalResult replacePayloadValue(Value value, Value replacement);
686
687 /// Records handle invalidation reporters into `newlyInvalidated`.
688 /// Specifically,
689 /// - `handle` is the op operand that consumes the handle,
690 /// - `potentialAncestors` is a list of ancestors of the payload operation
691 /// that the consumed handle is associated with, including itself,
692 /// - `throughValue` is the payload value the handle to which is consumed,
693 /// when it is the case, null when the operation handle is consumed
694 /// directly.
695 /// Iterates over all known operation and value handles and records reporters
696 /// for any potential future use of `handle` or any other handle that is
697 /// invalidated by its consumption, i.e., any handle pointing to any payload
698 /// IR entity (operation or value) associated with the same payload IR entity
699 /// as the consumed handle, or any nested payload IR entity. If
700 /// `potentialAncestors` is empty, records the reporter anyway. Does not
701 /// override existing reporters. This must remain a const method so it doesn't
702 /// inadvertently mutate `invalidatedHandles` too early.
703 void recordOpHandleInvalidation(OpOperand &consumingHandle,
704 ArrayRef<Operation *> potentialAncestors,
705 Value throughValue,
706 InvalidatedHandleMap &newlyInvalidated) const;
707
708 /// Records handle invalidation reporters into `newlyInvalidated`.
709 /// Specifically,
710 /// - `consumingHandle` is the op operand that consumes the handle,
711 /// - `potentialAncestors` is a list of ancestors of the payload operation
712 /// that the consumed handle is associated with, including itself,
713 /// - `payloadOp` is the operation itself,
714 /// - `otherHandle` is another handle that may be associated with the
715 /// affected payload operations,
716 /// - `throughValue` is the payload value the handle to which is consumed,
717 /// when it is the case, null when the operation handle is consumed
718 /// directly.
719 /// Looks at the payload operations associated with `otherHandle` and if any
720 /// of these operations has an ancestor (or is itself) listed in
721 /// `potentialAncestors`, records the error message describing the use of the
722 /// invalidated handle. Does nothing if `otherHandle` already has a reporter
723 /// associated with it. This must remain a const method so it doesn't
724 /// inadvertently mutate `invalidatedHandles` too early.
725 void recordOpHandleInvalidationOne(
726 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
727 Operation *payloadOp, Value otherHandle, Value throughValue,
728 InvalidatedHandleMap &newlyInvalidated) const;
729
730 /// Records handle invalidation reporters into `newlyInvalidated`.
731 /// Specifically,
732 /// - `opHandle` is the op operand that consumes the handle;
733 /// - `potentialAncestors` is a list of ancestors of the payload operation
734 /// that the consumed handle is associated with, including itself;
735 /// - `payloadValue` is the value defined by the operation associated with
736 /// the consuming handle as either op result or block argument;
///   - `valueHandle` is another handle that may be associated with the
///     payload value.
738 /// Looks at the payload values associated with `valueHandle` and if any of
739 /// these values is defined, as op result or block argument, by an operation
740 /// whose ancestor (or the operation itself) is listed in
741 /// `potentialAncestors`, records the error message describing the use of the
742 /// invalidated handle. Does nothing if `valueHandle` already has a reporter
743 /// associated with it. This must remain a const method so it doesn't
744 /// inadvertently mutate `invalidatedHandles` too early.
745 void recordValueHandleInvalidationByOpHandleOne(
746 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
747 Value payloadValue, Value valueHandle,
748 InvalidatedHandleMap &newlyInvalidated) const;
749
/// Records handle invalidation reporters into `newlyInvalidated`.
/// Specifically, `valueHandle` is the op operand that consumes the value
/// handle.
/// Iterates over all known operation and value handles and records reporters
/// for any potential future use of `valueHandle` or any other handle that is
/// invalidated by its consumption, i.e., any handle pointing to any payload
/// IR entity (operation or value) associated with the same payload IR entity
/// as the consumed handle, or any nested payload IR entity. Does not override
/// existing reporters. This must remain a const method so it doesn't
/// inadvertently mutate `invalidatedHandles` too early.
763 void
764 recordValueHandleInvalidation(OpOperand &valueHandle,
765 InvalidatedHandleMap &newlyInvalidated) const;
766
767 /// Checks that the operation does not use invalidated handles as operands.
768 /// Reports errors and returns failure if it does. Otherwise, invalidates the
769 /// handles consumed by the operation as well as any handles pointing to
770 /// payload IR operations nested in the operations associated with the
771 /// consumed handles.
772 LogicalResult
773 checkAndRecordHandleInvalidation(TransformOpInterface transform);
774
775 /// Implementation of the checkAndRecordHandleInvalidation. This must remain a
776 /// const method so it doesn't inadvertently mutate `invalidatedHandles` too
777 /// early.
778 LogicalResult checkAndRecordHandleInvalidationImpl(
779 transform::TransformOpInterface transform,
780 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const;
781
782 /// Remove all nullptrs from op handles that were added by `replacePayloadOp`.
783 void compactOpHandles();
784
785 /// A stack of mappings between transform IR values and payload IR ops,
786 /// aggregated by the region in which the transform IR values are defined.
787 /// We use a pointer to the Mappings struct so that reallocations inside
788 /// MapVector don't invalidate iterators when we apply nested transform ops
789 /// while also iterating over the mappings.
790 llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
791
792 /// Op handles may be temporarily mapped to nullptr to avoid invalidating
793 /// payload op iterators. This set contains all op handles with nullptrs.
794 /// These handles are "compacted" (i.e., nullptrs removed) at the end of each
795 /// transform.
796 DenseSet<Value> opHandlesToCompact;
797
798 /// Extensions attached to the TransformState, identified by the TypeID of
799 /// their type. Only one extension of any given type is allowed.
801
802 /// The top-level operation that contains all payload IR, typically a module.
803 Operation *topLevel;
804
805 /// Extra mapped values (payload operations, values or parameters) to be
806 /// associated with additional entry block arguments of the top-level
807 /// transform operation.
808 RaggedArray<MappedValue> topLevelMappedValues;
809
810 /// Additional options controlling the transformation state behavior.
811 TransformOptions options;
812
813 /// The mapping from invalidated handles to the error-reporting functions that
814 /// describe when the handles were invalidated. Calling such a function emits
815 /// a user-visible diagnostic with an additional note pointing to the given
816 /// location.
817 InvalidatedHandleMap invalidatedHandles;
818
819 /// A stack of nested regions that are being processed in the transform IR.
820 /// Each region must be an ancestor of the following regions in this list.
821 /// These are also the keys for "mappings".
822 SmallVector<RegionScope *> regionStack;
823
824 /// The top-level region scope. The first (bottom) element of `regionStack`
825 /// is the top-level region scope object.
826 std::unique_ptr<RegionScope> topLevelRegionScope;
827};
828
829/// Local mapping between values defined by a specific op implementing the
830/// TransformOpInterface and the payload IR ops they correspond to.
831class TransformResults {
832 friend class TransformState;
833
834public:
835 /// Indicates that the result of the transform IR op at the given position
836 /// corresponds to the given list of payload IR ops. Each result must be set
837 /// by the transformation exactly once in case of transformation succeeding.
838 /// The value must have a type implementing TransformHandleTypeInterface.
839 template <typename Range>
840 void set(OpResult value, Range &&ops) {
841 int64_t position = value.getResultNumber();
842 assert(position < static_cast<int64_t>(operations.size()) &&
843 "setting results for a non-existent handle");
844 assert(operations[position].data() == nullptr && "results already set");
845 assert(params[position].data() == nullptr &&
846 "another kind of results already set");
847 assert(values[position].data() == nullptr &&
848 "another kind of results already set");
849 operations.replace(position, std::forward<Range>(ops));
850 }
851
852 /// Indicates that the result of the transform IR op at the given position
853 /// corresponds to the given list of payload IR ops. Each result must be set
854 /// by the transformation exactly once in case of transformation succeeding.
855 /// The value must have a type implementing TransformHandleTypeInterface.
856 void set(OpResult value, std::initializer_list<Operation *> ops) {
857 set(value, ArrayRef<Operation *>(ops));
858 }
859
860 /// Indicates that the result of the transform IR op at the given position
861 /// corresponds to the given list of parameters. Each result must be set by
862 /// the transformation exactly once in case of transformation succeeding. The
863 /// value must have a type implementing TransformParamTypeInterface.
865
866 /// Indicates that the result of the transform IR op at the given position
867 /// corresponds to the given range of payload IR values. Each result must be
868 /// set by the transformation exactly once in case of transformation
869 /// succeeding. The value must have a type implementing
870 /// TransformValueHandleTypeInterface.
871 template <typename Range>
872 void setValues(OpResult handle, Range &&values) {
873 int64_t position = handle.getResultNumber();
874 assert(position < static_cast<int64_t>(this->values.size()) &&
875 "setting values for a non-existent handle");
876 assert(this->values[position].data() == nullptr && "values already set");
877 assert(operations[position].data() == nullptr &&
878 "another kind of results already set");
879 assert(params[position].data() == nullptr &&
880 "another kind of results already set");
881 this->values.replace(position, std::forward<Range>(values));
882 }
883
884 /// Indicates that the result of the transform IR op at the given position
885 /// corresponds to the given range of payload IR values. Each result must be
886 /// set by the transformation exactly once in case of transformation
887 /// succeeding. The value must have a type implementing
888 /// TransformValueHandleTypeInterface.
889 void setValues(OpResult handle, std::initializer_list<Value> values) {
890 setValues(handle, ArrayRef<Value>(values));
891 }
892
893 /// Indicates that the result of the transform IR op at the given position
894 /// corresponds to the given range of mapped values. All mapped values are
895 /// expected to be compatible with the type of the result, e.g., if the result
896 /// is an operation handle, all mapped values are expected to be payload
897 /// operations.
899
900 /// Sets the currently unset results to empty lists of the kind expected by
901 /// the corresponding results of the given `transform` op.
902 void setRemainingToEmpty(TransformOpInterface transform);
903
904private:
905 /// Creates an instance of TransformResults that expects mappings for
906 /// `numSegments` values, which may be associated with payload operations or
907 /// parameters.
908 explicit TransformResults(unsigned numSegments);
909
910 /// Gets the list of operations associated with the result identified by its
911 /// number in the list of operation results. The result must have been set to
912 /// be associated with payload IR operations.
913 ArrayRef<Operation *> get(unsigned resultNumber) const;
914
915 /// Gets the list of parameters associated with the result identified by its
916 /// number in the list of operation results. The result must have been set to
917 /// be associated with parameters.
918 ArrayRef<TransformState::Param> getParams(unsigned resultNumber) const;
919
920 /// Gets the list of payload IR values associated with the result identified
921 /// by its number in the list of operation results. The result must have been
922 /// set to be associated with payload IR values.
923 ArrayRef<Value> getValues(unsigned resultNumber) const;
924
925 /// Returns `true` if the result identified by its number in the list of
926 /// operation results is associated with a list of parameters, `false`
927 /// otherwise.
928 bool isParam(unsigned resultNumber) const;
929
/// Returns `true` if the result identified by its number in the list of
/// operation results is associated with a list of payload IR values, `false`
/// otherwise.
933 bool isValue(unsigned resultNumber) const;
934
935 /// Returns `true` if the result identified by its number in the list of
936 /// operation results is associated with something.
937 bool isSet(unsigned resultNumber) const;
938
939 /// Pointers to payload IR ops that are associated with results of a transform
940 /// IR op.
941 RaggedArray<Operation *> operations;
942
943 /// Parameters that are associated with results of the transform IR op.
944 RaggedArray<Param> params;
945
946 /// Payload IR values that are associated with results of a transform IR op.
947 RaggedArray<Value> values;
948};
949
950/// Creates a RAII object the lifetime of which corresponds to the new mapping
951/// for transform IR values defined in the given region. Values defined in
952/// surrounding regions remain accessible.
956
957/// A configuration object for customizing a `TrackingListener`.
959 using SkipHandleFn = std::function<bool(Value)>;
960
961 /// An optional function that returns "true" for handles that do not have to
962 /// be updated. These are typically dead or consumed handles.
964
965 /// If set to "true", the name of a replacement op must match the name of the
966 /// original op. If set to "false", the names of the payload ops tracked in a
967 /// handle may change as the tracking listener updates the transform state.
969
970 /// If set to "true", cast ops (that implement the CastOpInterface) are
971 /// skipped and the replacement op search continues with the operands of the
972 /// cast op.
973 bool skipCastOps = true;
974};
975
976/// A listener that updates a TransformState based on IR modifications. This
977/// listener can be used during a greedy pattern rewrite to keep the transform
978/// state up-to-date.
981public:
/// Create a new TrackingListener for usage in the specified transform op.
/// Optionally, a function can be specified to identify handles that do not
/// have to be updated.
985 TrackingListener(TransformState &state, TransformOpInterface op,
987
988protected:
989 /// Return a replacement payload op for the given op, which is going to be
990 /// replaced with the given values. By default, if all values are defined by
991 /// the same op, which also has the same type as the given op, that defining
992 /// op is used as a replacement.
993 ///
994 /// A "failure" return value indicates that no replacement operation could be
995 /// found. A "nullptr" return value indicates that no replacement op is needed
996 /// (e.g., handle is dead or was consumed) and that the payload op should
997 /// be dropped from the mapping.
998 ///
999 /// Example: A tracked "linalg.generic" with two results is replaced with two
1000 /// values defined by (another) "linalg.generic". It is reasonable to assume
1001 /// that the replacement "linalg.generic" represents the same "computation".
1002 /// Therefore, the payload op mapping is updated to the defining op of the
1003 /// replacement values.
1004 ///
1005 /// Counter Example: A "linalg.generic" is replaced with values defined by an
1006 /// "scf.for". Without further investigation, the relationship between the
1007 /// "linalg.generic" and the "scf.for" is unclear. They may not represent the
1008 /// same computation; e.g., there may be tiled "linalg.generic" inside the
1009 /// loop body that represents the original computation. Therefore, the
1010 /// TrackingListener is conservative by default: it drops the mapping and
1011 /// triggers the "payload replacement not found" notification. This default
1012 /// behavior can be customized in `TrackingListenerConfig`.
1013 ///
1014 /// If no replacement op could be found according to the rules mentioned
1015 /// above, this function tries to skip over cast-like ops that implement
1016 /// `CastOpInterface`.
1017 ///
1018 /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
1019 /// wrapped in a "tensor.cast". A cast is a metadata-only operation and it is
1020 /// reasonable to assume that the wrapped "linalg.generic" represents the same
1021 /// computation as the original "linalg.generic". The mapping is updated
1022 /// accordingly.
1023 ///
1024 /// Certain ops (typically also metadata-only ops) are not considered casts,
1025 /// but should be skipped nonetheless. Such ops should implement
1026 /// `FindPayloadReplacementOpInterface` to specify with which operands the
1027 /// lookup should continue.
1028 ///
1029 /// Example: A tracked "linalg.generic" is replaced with "linalg.generic",
1030 /// wrapped in a "tensor.reshape". A reshape is a metadata-only operation but
1031 /// not cast. (Implementing `CastOpInterface` would be incorrect and cause
1032 /// invalid foldings.) However, due to its `FindPayloadReplacementOpInterface`
1033 /// implementation, the replacement op lookup continues with the wrapped
1034 /// "linalg.generic" and the mapping is updated accordingly.
1035 ///
1036 /// Derived classes may override `findReplacementOp` to specify custom
1037 /// replacement rules.
1040 ValueRange newValues) const;
1041
1042 /// Notify the listener that the pattern failed to match the given operation,
1043 /// and provide a callback to populate a diagnostic with the reason why the
1044 /// failure occurred.
1045 void
1047 function_ref<void(Diagnostic &)> reasonCallback) override;
1048
1049 /// This function is called when a tracked payload op is dropped because no
1050 /// replacement op was found. Derived classes can implement this function for
1051 /// custom error handling.
1052 virtual void
1055
1056 /// Return the single op that defines all given values (if any).
1058
1059 /// Return the transform op in which this TrackingListener is used.
1060 TransformOpInterface getTransformOp() const { return transformOp; }
1061
1062private:
1063 friend class TransformRewriter;
1064
1065 void notifyOperationErased(Operation *op) override;
1066
1067 void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
1068 using Listener::notifyOperationReplaced;
1069
1070 /// The transform op in which this TrackingListener is used.
1071 TransformOpInterface transformOp;
1072
1073 /// The handles that are consumed by the transform op.
1074 DenseSet<Value> consumedHandles;
1075
1076 /// Tracking listener configuration.
1078};
1079
1080/// A specialized listener that keeps track of cases in which no replacement
1081/// payload could be found. The error state of this listener must be checked
1082/// before the end of its lifetime.
1084public:
1086
1088
1089 /// Check and return the current error state of this listener. Afterwards,
1090 /// resets the error state to "success".
1092
1093 /// Return the latest match notification message. Returns an empty string
1094 /// when no error message was captured.
1095 std::string getLatestMatchFailureMessage();
1096
1097 /// Return "true" if this tracking listener had a failure.
1098 bool failed() const;
1099
1100protected:
1101 void
1103 function_ref<void(Diagnostic &)> reasonCallback) override;
1104
1105 void
1108
1109private:
1110 /// The error state of this listener. "Success" indicates that no error
1111 /// happened so far.
1113
1114 /// The number of errors that have been encountered.
1115 int64_t errorCounter = 0;
1116
1117 /// Latest message from match failure notification.
1118 std::optional<Diagnostic> matchFailure;
1119};
1120
1121/// This is a special rewriter to be used in transform op implementations,
1122/// providing additional helper functions to update the transform state, etc.
1123// TODO: Helper functions will be added in a subsequent change.
1125protected:
1126 friend class TransformState;
1127
1128 /// Create a new TransformRewriter.
1129 explicit TransformRewriter(MLIRContext *ctx,
1131
1132public:
1133 /// Return "true" if the tracking listener had failures.
1134 bool hasTrackingFailures() const;
1135
1136 /// Silence all tracking failures that have been encountered so far.
1138
1139 /// Notify the transform dialect interpreter that the given op has been
1140 /// replaced with another op and that the mapping between handles and payload
1141 /// ops/values should be updated. This function should be called before the
1142 /// original op is erased. It fails if the operation could not be replaced,
1143 /// e.g., because the original operation is not tracked.
1144 ///
1145 /// Note: As long as IR modifications are performed through this rewriter,
1146 /// the transform state is usually updated automatically. This function should
1147 /// be used when unsupported rewriter API is used; e.g., updating all uses of
1148 /// a tracked operation one-by-one instead of using `RewriterBase::replaceOp`.
1149 LogicalResult notifyPayloadOperationReplaced(Operation *op,
1151
1152private:
1153 ErrorCheckingTrackingListener *const listener;
1154};
1155
1156/// This trait is supposed to be attached to Transform dialect operations that
1157/// can be standalone top-level transforms. Such operations typically contain
1158/// other Transform dialect operations that can be executed following some
1159/// control flow logic specific to the current operation. The operations with
1160/// this trait are expected to have at least one single-block region with at
1161/// least one argument of type implementing TransformHandleTypeInterface. The
1162/// operations are also expected to be valid without operands, in which case
1163/// they are considered top-level, and with one or more arguments, in which case
1164/// they are considered nested. Top-level operations have the block argument of
1165/// the entry block in the Transform IR correspond to the root operation of
1166/// Payload IR. Nested operations have the block argument of the entry block in
1167/// the Transform IR correspond to a list of Payload IR operations mapped to the
1168/// first operand of the Transform IR operation. The operation must implement
1169/// TransformOpInterface.
1170template <typename OpTy>
1172 : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
1173public:
1174 /// Verifies that `op` satisfies the invariants of this trait. Not expected to
1175 /// be called directly.
1176 static LogicalResult verifyTrait(Operation *op) {
1178 }
1179
1180 /// Returns the single block of the given region.
1181 Block *getBodyBlock(unsigned region = 0) {
1182 return &this->getOperation()->getRegion(region).front();
1183 }
1184
1185 /// Populates `effects` with side effects implied by this trait.
1188 Region &region = this->getOperation()->getRegion(0);
1189 if (region.empty())
1190 return;
1192 this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(),
1193 region.front(), effects);
1194 }
1195
1196 /// Sets up the mapping between the entry block of the given region of this op
1197 /// and the relevant list of Payload IR operations in the given state. The
1198 /// state is expected to be already scoped at the region of this operation.
1199 LogicalResult mapBlockArguments(TransformState &state, Region &region) {
1200 assert(region.getParentOp() == this->getOperation() &&
1201 "op comes from the wrong region");
1203 state, this->getOperation(), region);
1204 }
1205 LogicalResult mapBlockArguments(TransformState &state) {
1206 assert(
1207 this->getOperation()->getNumRegions() == 1 &&
1208 "must indicate the region to map if the operation has more than one");
1209 return mapBlockArguments(state, this->getOperation()->getRegion(0));
1210 }
1211};
1212
1213class ApplyToEachResultList;
1214
1215/// Trait implementing the TransformOpInterface for operations applying a
1216/// transformation to a single operation handle and producing an arbitrary
1217/// number of handles and parameter values.
1218/// The op must implement a method with the following signature:
1219/// - DiagnosedSilenceableFailure applyToOne(OpTy,
1220/// ApplyToEachResultList &results, TransformState &state)
1221/// to perform a transformation that is applied in turn to all payload IR
1222/// operations that correspond to the handle of the transform IR operation.
1223/// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class
1224/// that the transformation is applied to (and NOT the class of the transform IR
1225/// op).
1226/// The `applyToOne` method takes an empty `results` vector that it fills with
1227/// zero, one or multiple operations depending on the number of results expected
1228/// by the transform op.
1229/// The number of results must match the number of results of the transform op.
1230/// `applyToOne` is allowed to fill the `results` with all null elements to
1231/// signify that the transformation did not apply to the payload IR operations.
1232/// Such null elements are filtered out from results before return.
1233///
1234/// The transform op having this trait is expected to have a single operand.
1235template <typename OpTy>
1237 : public OpTrait::TraitBase<OpTy, TransformEachOpTrait> {
1238public:
/// Calls `applyToOne` for every payload operation associated with the operand
/// of this transform IR op. The following cases are distinguished:
///   1. If no target payload ops are associated with the operand, fills the
///      results vector with the expected number of null elements and returns
///      success. This corner case allows propagating the "no-op" case
///      gracefully to improve usability.
///   2. If any `applyToOne` returns definiteFailure, the transformation is
///      immediately considered definitely failed and we return.
///   3. All applications of `applyToOne` are checked to return the number of
///      results expected by the transform IR op. If not, this is a definite
///      failure and we return early.
///   4. If `applyToOne` produces ops, associates them with the results of this
///      transform op.
///   5. If any `applyToOne` returns silenceableFailure, the transformation is
///      considered silenceable.
///   6. Otherwise the transformation is considered successful.
1256 TransformResults &transformResults,
1257 TransformState &state);
1258
1259 /// Checks that the op matches the expectations of this trait.
1260 static LogicalResult verifyTrait(Operation *op);
1261};
1262
1263/// Side effect resource corresponding to the mapping between Transform IR
1264/// values and Payload IR operations. An Allocate effect from this resource
1265/// means creating a new mapping entry, it is always accompanied by a Write
1266/// effect. A Read effect from this resource means accessing the mapping. A Free
1267/// effect on this resource indicates the removal of the mapping entry,
1268/// typically after a transformation that modifies the Payload IR operations
1269/// associated with one of the Transform IR operation's operands. It is always
1270/// accompanied by a Read effect. Read-after-Free and double-Free are not
1271/// allowed (they would be problematic with "regular" memory effects too) as
1272/// they indicate an attempt to access Payload IR operations that have been
1273/// modified, potentially erased, by the previous transformations.
1274// TODO: consider custom effects if these are not enabling generic passes such
1275// as CSE/DCE to work.
1277 : public SideEffects::Resource::Base<TransformMappingResource> {
1278 StringRef getName() const override { return "transform.mapping"; }
1279};
1280
1281/// Side effect resource corresponding to the Payload IR itself. Only Read and
1282/// Write effects are expected on this resource, with Write always accompanied
1283/// by a Read (short of fully replacing the top-level Payload IR operation, one
1284/// cannot modify the Payload IR without reading it first). This is intended
1285/// to disallow reordering of Transform IR operations that mutate the Payload IR
1286/// while still allowing the reordering of those that only access it.
1288 : public SideEffects::Resource::Base<PayloadIRResource> {
1289 StringRef getName() const override { return "transform.payload_ir"; }
1290};
1291
1292/// Populates `effects` with the memory effects indicating the operation on the
1293/// given handle value:
1294/// - consumes = Read + Free,
1295/// - produces = Allocate + Write,
1296/// - onlyReads = Read.
1299void producesHandle(ResultRange handles,
1305
1306/// Checks whether the transform op consumes the given handle.
1307bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
1308
1309/// Populates `effects` with the memory effects indicating the access to payload
1310/// IR resource.
1313
1314/// Checks whether the transform op modifies the payload.
1315bool doesModifyPayload(transform::TransformOpInterface transform);
1316/// Checks whether the transform op reads the payload.
1317bool doesReadPayload(transform::TransformOpInterface transform);
1318
1319/// Populates `consumedArguments` with positions of `block` arguments that are
1320/// consumed by the operations in the `block`.
1322 Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1323
1324/// Trait implementing the MemoryEffectOpInterface for operations that "consume"
1325/// their operands and produce new results.
1326template <typename OpTy>
1328 : public OpTrait::TraitBase<OpTy, FunctionalStyleTransformOpTrait> {
1329public:
/// This op "consumes" the operands by reading and freeing them, "produces"
/// the results by allocating and writing them, and reads/writes the payload
/// IR in the process.
1334 consumesHandle(this->getOperation()->getOpOperands(), effects);
1335 producesHandle(this->getOperation()->getOpResults(), effects);
1336 modifiesPayload(effects);
1337 }
1338
1339 /// Checks that the op matches the expectations of this trait.
1340 static LogicalResult verifyTrait(Operation *op) {
1341 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1342 op->emitError()
1343 << "FunctionalStyleTransformOpTrait should only be attached to ops "
1344 "that implement MemoryEffectOpInterface";
1345 }
1346 return success();
1347 }
1348};
1349
1350/// Trait implementing the MemoryEffectOpInterface for operations that use their
1351/// operands without consuming and without modifying the Payload IR to
1352/// potentially produce new handles.
1353template <typename OpTy>
1355 : public OpTrait::TraitBase<OpTy, NavigationTransformOpTrait> {
1356public:
1357 /// This op produces handles to the Payload IR without consuming the original
1358 /// handles and without modifying the IR itself.
1360 onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
1361 producesHandle(this->getOperation()->getOpResults(), effects);
1362 if (llvm::any_of(this->getOperation()->getOperandTypes(), [](Type t) {
1363 return isa<TransformHandleTypeInterface,
1364 TransformValueHandleTypeInterface>(t);
1365 })) {
1366 onlyReadsPayload(effects);
1367 }
1368 }
1369
1370 /// Checks that the op matches the expectation of this trait.
1371 static LogicalResult verifyTrait(Operation *op) {
1372 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1373 op->emitError() << "NavigationTransformOpTrait should only be attached "
1374 "to ops that implement MemoryEffectOpInterface";
1375 }
1376 return success();
1377 }
1378};
1379
1380namespace detail {
1381/// Non-template implementation of ParamProducerTransformOpTrait::getEffects().
1384/// Non-template implementation of ParamProducerTransformOpTrait::verify().
1386} // namespace detail
1387
1388/// Trait implementing the MemoryEffectsOpInterface for operations that produce
1389/// transform dialect parameters. It marks all op results of
1390/// TransformHandleTypeInterface as produced by the op, all operands as only
1391/// read by the op and, if at least one of the operand is a handle to payload
1392/// ops, the entire payload as potentially read. The op must only produce
1393/// parameter-typed results.
1394template <typename OpTy>
1396 : public OpTrait::TraitBase<OpTy, ParamProducerTransformOpTrait> {
1397public:
1398 /// Populates `effects` with effect instances described in the trait
1399 /// documentation.
1404
/// Checks that the op matches the expectation of this trait, i.e., that it
/// implements the MemoryEffectOpInterface and only produces parameter-typed
/// results.
1408 static LogicalResult verifyTrait(Operation *op) {
1410 }
1411};
1412
1413/// `TrackingListener` failures are reported only for ops that have this trait.
1414/// The purpose of this trait is to give users more time to update their custom
1415/// transform ops to use the provided `TransformRewriter` for all IR
1416/// modifications. This trait will eventually be removed, and failures will be
1417/// reported for all transform ops.
1418template <typename OpTy>
1420 : public OpTrait::TraitBase<OpTy, ReportTrackingListenerFailuresOpTrait> {};
1421
1422/// A single result of applying a transform op with `TransformEachOpTrait` to a
1423/// single payload operation.
// NOTE(review): the declaration itself is missing from this copy. The
// generated index lists `MappedValue ApplyToEachResult`, i.e. presumably
// `using ApplyToEachResult = MappedValue;` (a payload op, a parameter or a
// payload value).
1425
1426/// A list of results of applying a transform op with `TransformEachOpTrait` to
1427/// a single payload operation, co-indexed with the results of the transform op.
// NOTE(review): the class-head line (presumably `class ApplyToEachResultList {`)
// is missing from this copy, as is the line between `public:` and the sized
// constructor below; restore them from the upstream header.
1429public:
  /// Creates a list with `size` entries, value-initialized by the underlying
  /// storage (presumably null results). `explicit` so that an integer never
  /// converts implicitly into a result list.
1431 explicit ApplyToEachResultList(unsigned size) : results(size) {}
1432
1433 /// Sets the list of results to `size` null pointers.
1434 void assign(unsigned size, std::nullptr_t) { results.assign(size, nullptr); }
1435
1436 /// Sets the list of results to the given range of values.
1437 template <typename Range>
1438 void assign(Range &&range) {
1439 // This is roughly the implementation of SmallVectorImpl::assign.
1440 // Dispatching to it with map_range and template type inference would result
1441 // in more complex code here.
1442 results.clear();
1443 results.reserve(llvm::size(range));
1444 for (auto element : range) {
1445 if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1446 Operation *>) {
1447 results.push_back(static_cast<Operation *>(element));
1448 } else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1449 Value>) {
1450 results.push_back(element.template get<Value>());
1451 } else {
1452 results.push_back(static_cast<Attribute>(element));
1453 }
1454 }
1455 }
1456
1457 /// Appends an element to the list.
1458 // Values are accepted through the `ApplyToEachResult` overload, which can be
1459 // implicitly constructed from a `Value` but not from a concrete op class.
1460 // Concrete op classes convert implicitly to `Operation *`, and some also to
 // `Value`; a separate `Value` overload would make `push_back` ambiguous for
 // those ops.
1461 void push_back(Operation *op) { results.push_back(op); }
1462 void push_back(Attribute attr) { results.push_back(attr); }
1463 void push_back(ApplyToEachResult r) { results.push_back(r); }
1464
1465 /// Reserves space for `size` elements in the list.
1466 void reserve(unsigned size) { results.reserve(size); }
1467
1468 /// Iterators over the list.
 // Both mutable and const overloads forward to the underlying storage.
1469 auto begin() { return results.begin(); }
1470 auto end() { return results.end(); }
1471 auto begin() const { return results.begin(); }
1472 auto end() const { return results.end(); }
1473
1474 /// Returns the number of elements in the list.
1475 size_t size() const { return results.size(); }
1476
1477 /// Element access. Expects the index to be in bounds.
1478 ApplyToEachResult &operator[](size_t index) { return results[index]; }
1479 const ApplyToEachResult &operator[](size_t index) const {
1480 return results[index];
1481 }
1482
1483private:
1484 /// Underlying storage.
  // NOTE(review): the storage member declaration is missing from this copy.
  // The methods above use it as `results`; its type is presumably a
  // SmallVector of ApplyToEachResult (assign() mirrors SmallVectorImpl::assign)
  // — TODO confirm against upstream.
1486};
1487
1488namespace detail {
1489
1490/// Checks that the contents of `partialResult` match the number, kind (payload
1491/// op or parameter) and nullity (either all or none) requirements of
1492/// `transformOp`. Reports errors and returns failure otherwise.
/// `payloadOpLoc` is the location of the payload op that produced
/// `partialResult` (applyTransformToEach passes its current target's location).
1493LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc,
1494 const ApplyToEachResultList &partialResult);
1495
1496/// "Transpose" the results produced by individual applications, arranging them
1497/// per result value of the transform op, and populate `transformResults` with
1498/// that. The number, kind and nullity of per-application results are assumed to
1499/// have been verified.
/// In TransformEachOpTrait::apply, the per-application lists come from
/// applyTransformToEach, which validates each one with checkApplyToOne.
1500void setApplyToOneResults(Operation *transformOp,
1501 TransformResults &transformResults,
// NOTE(review): the final parameter line is missing from this copy, leaving the
// declaration unterminated. The generated index lists it as
// `ArrayRef<ApplyToEachResultList> results);` — restore it.
1503
1504/// Applies a one-to-one or a one-to-many transform to each of the given
1505/// targets. Puts the results of successful applications in `results`, in
1506/// target order. A definite failure of any application ends the iteration
/// (presumably by returning it; that line is missing below). Silenceable
/// failures, including targets of the wrong op kind, do not stop the
/// iteration: their diagnostics are collected, the remaining targets are still
/// processed, and the collected diagnostics are reported together at the end.
1507/// The transform op must provide an `applyToOne` member with the signature:
1508/// - DiagnosedSilenceableFailure applyToOne(TransformRewriter &rewriter,
1509/// OpTy target, ApplyToEachResultList &results, TransformState &state)
1510/// where OpTy (its second parameter) is either
1511/// - Operation *, in which case the transform is always applied;
1512/// - a concrete Op class, in which case a check is performed whether
1513/// `targets` contains operations of the same class and a silenceable failure
1514/// is reported if it does not.
/// Each application runs with the rewriter's insertion point set right before
/// its target; the caller's insertion point is restored on return. The results
/// of each successful application are validated with detail::checkApplyToOne.
1515template <typename TransformOpTy, typename Range>
// NOTE(review): the line naming this function and the line with its last
// parameters are missing from this copy. The generated index lists the full
// signature as
//   DiagnosedSilenceableFailure applyTransformToEach(
//       TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets,
//       SmallVectorImpl<ApplyToEachResultList> &results, TransformState &state)
1517 TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets,
1519 using OpTy = typename llvm::function_traits<
1520 decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1521 static_assert(std::is_convertible<OpTy, Operation *>::value,
1522 "expected transform function to take an operation");
  // Restores the rewriter's insertion point when this function returns.
1523 OpBuilder::InsertionGuard g(rewriter);
1524
  // Diagnostics of all silenceable failures, reported together at the end.
1525 SmallVector<Diagnostic> silenceableStack;
1526 unsigned expectedNumResults = transformOp->getNumResults();
1527 for (Operation *target : targets) {
1528 auto specificOp = dyn_cast<OpTy>(target);
1529 if (!specificOp) {
    // Wrong op kind: record a silenceable error and move on to the next target.
1530 Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
1531 diag << "transform applied to the wrong op kind";
1532 diag.attachNote(target->getLoc()) << "when applied to this op";
1533 silenceableStack.push_back(std::move(diag));
1534 continue;
1535 }
1536
    // One entry is expected per result of the transform op.
1537 ApplyToEachResultList partialResults;
1538 partialResults.reserve(expectedNumResults);
1539 Location specificOpLoc = specificOp->getLoc();
    // NOTE(review): the line after the next one, which presumably starts with
    // `DiagnosedSilenceableFailure res =`, is missing from this copy.
1540 rewriter.setInsertionPoint(specificOp);
1542 transformOp.applyToOne(rewriter, specificOp, partialResults, state);
    // NOTE(review): the statement guarded by this `if` is missing from this
    // copy; presumably `return res;`.
1543 if (res.isDefiniteFailure())
1545
    // Silenceable failure: keep its diagnostics and skip this target's results.
1546 if (res.isSilenceableFailure()) {
1547 res.takeDiagnostics(silenceableStack);
1548 continue;
1549 }
1550
    // NOTE(review): the body of this `if` is missing from this copy; presumably
    // `return DiagnosedSilenceableFailure::definiteFailure();` — TODO confirm.
1551 if (failed(detail::checkApplyToOne(transformOp, specificOpLoc,
1552 partialResults))) {
1554 }
1555 results.push_back(std::move(partialResults));
1556 }
  // NOTE(review): the first line of the return statement below is missing from
  // this copy; presumably `return DiagnosedSilenceableFailure::silenceableFailure(`.
1557 if (!silenceableStack.empty()) {
1559 std::move(silenceableStack));
1560 }
  // NOTE(review): the final statement is missing from this copy; presumably
  // `return DiagnosedSilenceableFailure::success();`.
1562}
1563
1564/// Reports an error and returns failure if `targets` contains an ancestor
1565/// operation before its descendant (or a copy of itself). Implementation detail
1566/// for expensive checks during `TransformEachOpTrait::apply`.
/// TransformEachOpTrait::apply runs this check only when expensive checks are
/// enabled and its operand handle is consumed. It passes the transform op's
/// location as `loc`. The concern is payload ops being erased before the ops
/// nested in them.
1567LogicalResult checkNestedConsumption(Location loc,
1568 ArrayRef<Operation *> targets);
1569
1570} // namespace detail
1571} // namespace transform
1572} // namespace mlir
1573
/// TransformEachOpTrait's implementation of `apply`: runs applyToOne on every
/// payload op associated with the op's single operand, then maps the
/// collected per-target results onto the op's results.
1574template <typename OpTy>
// NOTE(review): the line(s) naming this definition are missing from this copy;
// presumably `DiagnosedSilenceableFailure
// mlir::transform::TransformEachOpTrait<OpTy>::apply(`. The generated index
// lists `DiagnosedSilenceableFailure apply(transform::TransformRewriter
// &rewriter, TransformResults &transformResults, TransformState &state)`.
1577 TransformRewriter &rewriter, TransformResults &transformResults,
1578 TransformState &state) {
1579 Value handle = this->getOperation()->getOperand(0);
1580 auto targets = state.getPayloadOps(handle);
1581
1582 // If the operand is consumed, check if it is associated with operations that
1583 // may be erased before their nested operations are.
1584 if (state.getOptions().getExpensiveChecksEnabled() &&
1585 isHandleConsumed(handle, cast<transform::TransformOpInterface>(
1586 this->getOperation())) &&
1587 failed(detail::checkNestedConsumption(this->getOperation()->getLoc(),
1588 llvm::to_vector(targets)))) {
    // NOTE(review): the statement here is missing from this copy; presumably
    // `return DiagnosedSilenceableFailure::definiteFailure();` — TODO confirm.
1590 }
1591
1592 // Step 1. Handle the corner case where no target is specified.
1593 // This is typically the case when the matcher fails to apply and we need to
1594 // propagate gracefully.
1595 // In this case, every result is set to an empty list of the matching kind
 // (parameters, payload values or payload ops).
1596 if (std::empty(targets)) {
1597 SmallVector<Operation *> emptyPayload;
1598 SmallVector<Attribute> emptyParams;
1599 for (OpResult r : this->getOperation()->getResults()) {
1600 if (isa<TransformParamTypeInterface>(r.getType()))
1601 transformResults.setParams(r, emptyParams);
1602 else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1603 transformResults.setValues(r, ValueRange());
1604 else
1605 transformResults.set(r, emptyPayload);
1606 }
    // NOTE(review): the line here is missing from this copy; presumably
    // `return DiagnosedSilenceableFailure::success();` — TODO confirm.
1608 }
1609
1610 // Step 2. Call applyToOne on each target and record newly produced ops in its
1611 // corresponding results entry.
  // NOTE(review): two lines are missing here, presumably declaring `results`
  // (a SmallVector of ApplyToEachResultList) and starting
  // `DiagnosedSilenceableFailure result = detail::applyTransformToEach(`.
1614 cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1615
1616 // Step 3. Propagate the definite failure if any and bail out.
1617 if (result.isDefiniteFailure())
1618 return result;
1619
1620 // Step 4. "Transpose" the results produced by individual applications,
1621 // arranging them per result value of the transform op. The number, kind and
1622 // nullity of per-application results have been verified by
1623 // detail::checkApplyToOne inside applyTransformToEach.
1624 detail::setApplyToOneResults(this->getOperation(), transformResults, results);
1625
1626 // Step 5. ApplyToOne may have returned silenceableFailure, propagate it.
1627 return result;
1628}
1629
/// Verifies the invariants of TransformEachOpTrait. The op must have the
/// OneOperand trait, enforced at compile time because `apply` reads operand 0.
/// It must also implement TransformOpInterface, checked at run time; otherwise
/// an error is emitted and failure is returned.
1630template <typename OpTy>
1631llvm::LogicalResult
// NOTE(review): the line naming this definition is missing from this copy;
// presumably `mlir::transform::TransformEachOpTrait<OpTy>::verifyTrait(
// Operation *op) {` (the index lists `static LogicalResult verifyTrait(
// Operation *op)`).
1633 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1634 "expected single-operand op");
1635 if (!op->getName().getInterface<TransformOpInterface>()) {
1636 return op->emitError() << "TransformEachOpTrait should only be attached to "
1637 "ops that implement TransformOpInterface";
1638 }
1639
1640 return success();
1641}
1642
1643namespace llvm {
/// Makes NormalFormAttrInterface usable wherever llvm::PointerLikeTypeTraits is
/// required. It reuses mlir::Attribute's traits and casts back to the interface
/// when converting from a void pointer.
1644template <>
1645struct PointerLikeTypeTraits<mlir::transform::NormalFormAttrInterface>
1646 : public PointerLikeTypeTraits<mlir::Attribute> {
1647 static inline mlir::transform::NormalFormAttrInterface
  // NOTE(review): the line naming this function is missing from this copy; the
  // index lists it as `getFromVoidPointer(void *p)`.
1649 return cast<mlir::transform::NormalFormAttrInterface>(
      // NOTE(review): the cast's argument line is missing from this copy;
      // presumably `mlir::Attribute::getFromOpaquePointer(p));` — TODO confirm.
1651 }
1652};
1653} // namespace llvm
1654
1655#endif // MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
return success()
… if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
static Attribute getFromOpaquePointer(const void *ptr)
Construct an attribute from the opaque pointer representation.
Definition Attributes.h:75
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This is a value defined by a result of an operation.
Definition Value.h:454
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:466
Helper class for implementing traits.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
A 2D array where each row may have different length.
Definition RaggedArray.h:18
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Definition Region.cpp:45
Block & back()
Definition Region.h:64
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
This class implements the result iterators for the Operation class.
Definition ValueRange.h:248
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
This base class is used for derived effects that are non-parametric.
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
A list of results of applying a transform op with TransformEachOpTrait to a single payload operation, co-indexed with the results of the transform op.
ApplyToEachResult & operator[](size_t index)
Element access. Expects the index to be in bounds.
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
const ApplyToEachResult & operator[](size_t index) const
void assign(Range &&range)
Sets the list of results to the given range of values.
A specialized listener that keeps track of cases in which no replacement payload could be found.
bool failed() const
Return "true" if this tracking listener had a failure.
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
std::string getLatestMatchFailureMessage()
Return the latest match notification message.
void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) override
This function is called when a tracked payload op is dropped because no replacement op was found.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match, and provide a callback to populate a diagnostic...
DiagnosedSilenceableFailure checkAndResetError()
Check and return the current error state of this listener.
Trait implementing the MemoryEffectOpInterface for operations that "consume" their operands and produ...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op "consumes" the operands by reading and freeing them, "produces" the results by allocating and...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
Trait implementing the MemoryEffectOpInterface for operations that use their operands without consumi...
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
This op produces handles to the Payload IR without consuming the original handles and without modifyi...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait.
Trait implementing the MemoryEffectOpInterface for operations that produce transform dialect parameters.
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectation of this trait, i.e., that it implements the MemoryEffectOpInterface and only produces parameter-typed results.
void getEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with effect instances described in the trait documentation.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
LogicalResult mapBlockArguments(TransformState &state)
Block * getBodyBlock(unsigned region=0)
Returns the single block of the given region.
static LogicalResult verifyTrait(Operation *op)
Verifies that op satisfies the invariants of this trait.
void getPotentialTopLevelEffects(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by this trait.
LogicalResult mapBlockArguments(TransformState &state, Region &region)
Sets up the mapping between the entry block of the given region of this op and the relevant list of P...
TrackingListener failures are reported only for ops that have this trait.
void notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
Notify the listener that the pattern failed to match the given operation, and provide a callback to p...
TransformOpInterface getTransformOp() const
Return the transform op in which this TrackingListener is used.
TrackingListener(TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
static Operation * getCommonDefiningOp(ValueRange values)
Return the single op that defines all given values (if any).
virtual void notifyPayloadReplacementNotFound(Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag)
This function is called when a tracked payload op is dropped because no replacement op was found.
virtual DiagnosedSilenceableFailure findReplacementOp(Operation *&result, Operation *op, ValueRange newValues) const
Return a replacement payload op for the given op, which is going to be replaced with the given values...
Trait implementing the TransformOpInterface for operations applying a transformation to a single oper...
static LogicalResult verifyTrait(Operation *op)
Checks that the op matches the expectations of this trait.
DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, TransformResults &transformResults, TransformState &state)
Calls applyToOne for every payload operation associated with the operand of this transform IR op,...
Options controlling the application of transform operations by the TransformState.
TransformOptions & operator=(const TransformOptions &)=default
TransformOptions & enableExpensiveChecks(bool enable=true)
Requests computationally expensive checks of the transform and payload IR well-formedness to be perfo...
TransformOptions & enableEnforceSingleToplevelTransformOp(bool enable=true)
TransformOptions(const TransformOptions &)=default
bool getExpensiveChecksEnabled() const
Returns true if the expensive checks are requested.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, std::initializer_list< Operation * > ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setValues(OpResult handle, std::initializer_list< Value > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
TransformRewriter(MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
Create a new TransformRewriter.
bool hasTrackingFailures() const
Return "true" if the tracking listener had failures.
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
void silenceTrackingFailure()
Silence all tracking failures that have been encountered so far.
Base class for TransformState extensions that allow TransformState to contain user-specified informat...
Extension(TransformState &state)
Constructs an extension of the given TransformState object.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement)
Replaces the given payload op with another op.
virtual ~Extension()
Base virtual destructor.
LogicalResult replacePayloadValue(Value value, Value replacement)
Replaces the given payload value with another value.
const TransformState & getTransformState() const
Provides read-only access to the parent TransformState object.
A RAII object maintaining a "stack frame" for a transform IR region.
~RegionScope()
Forgets the mapping from or to values defined in the associated transform IR region,...
The state maintained across applications of various ops implementing the TransformOpInterface.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
LogicalResult getHandlesForPayloadValue(Value payloadValue, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given payload IR value.
Ty * getExtension()
Returns the extension of the specified type.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
auto getPayloadValues(Value handleValue) const
Returns an iterator that enumerates all payload IR values that the given transform IR value correspon...
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Attribute > getParams(Value value) const
Returns the list of parameters that the given transform IR value corresponds to.
LogicalResult mapBlockArgument(BlockArgument argument, ArrayRef< MappedValue > values)
friend LogicalResult applyTransforms(Operation *, TransformOpInterface, const RaggedArray< MappedValue > &, const TransformOptions &, bool, function_ref< void(TransformState &)>, function_ref< LogicalResult(TransformState &)>)
Entry point to the Transform dialect infrastructure.
size_t getNumTopLevelMappings() const
Returns the number of extra mappings for the top-level operation.
const TransformOptions & getOptions() const
LogicalResult getHandlesForPayloadOp(Operation *op, SmallVectorImpl< Value > &handles, bool includeOutOfScope=false) const
Populates handles with all handles pointing to the given Payload IR op.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void removeExtension()
Removes the extension of the specified type.
ArrayRef< MappedValue > getTopLevelMapping(size_t position) const
Returns the position-th extra mapping for the top-level operation.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
AttrTypeReplacer.
LogicalResult verifyTransformOpInterface(Operation *op)
Verification hook for TransformOpInterface.
SmallVector< OpOperand * > getConsumedHandleOpOperands(transform::TransformOpInterface transformOp)
Returns all operands that are handles and being consumed by the given op.
mlir::DiagnosedSilenceableFailure checkNormalForms(llvm::ArrayRef< mlir::transform::NormalFormAttrInterface > normalForms, llvm::ArrayRef< mlir::Operation * > payload)
Checks that the given payload operations satisfy the normal form constraints, reports the first encou...
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue > > mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
llvm::LogicalResult verifyNormalFormList(llvm::function_ref< mlir::InFlightDiagnostic()> emitError, llvm::ArrayRef< mlir::transform::NormalFormAttrInterface > normalForms)
Checks that the given list does not contain duplicate normal forms of the same class.
void setApplyToOneResults(Operation *transformOp, TransformResults &transformResults, ArrayRef< ApplyToEachResultList > results)
"Transpose" the results produced by individual applications, arranging them per result value of the transform op, and populate transformResults with them.
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
void getParamProducerTransformOpTraitEffects(Operation *op, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Non-template implementation of ParamProducerTransformOpTrait::getEffects().
LogicalResult checkNestedConsumption(Location loc, ArrayRef< Operation * > targets)
Reports an error and returns failure if targets contains an ancestor operation before its descendant (or a copy of itself).
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
TransformState makeTransformStateForTesting(Region *region, Operation *payloadRoot)
Make a dummy transform state for testing purposes.
LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc, const ApplyToEachResultList &partialResult)
Check that the contents of partialResult match the number, kind (payload op or parameter) and nullity (either all or none) requirements of transformOp.
LogicalResult verifyParamProducerTransformOpTrait(Operation *op)
Non-template implementation of ParamProducerTransformOpTrait::verifyTrait().
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue > > &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op)
Verification hook for PossibleTopLevelTransformOpTrait.
DiagnosedSilenceableFailure applyTransformToEach(TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, SmallVectorImpl< ApplyToEachResultList > &results, TransformState &state)
Applies a one-to-one or a one-to-many transform to each of the given targets.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
LogicalResult applyTransforms(Operation *payloadRoot, TransformOpInterface transform, const RaggedArray< MappedValue > &extraMapping={}, const TransformOptions &options=TransformOptions(), bool enforceToplevelTransformOp=true, function_ref< void(TransformState &)> stateInitializer=nullptr, function_ref< LogicalResult(TransformState &)> stateExporter=nullptr)
Entry point to the Transform dialect infrastructure.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
llvm::PointerUnion< Operation *, Param, Value > MappedValue
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
MappedValue ApplyToEachResult
A single result of applying a transform op with TransformEachOpTrait to a single payload operation.
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
static mlir::transform::NormalFormAttrInterface getFromVoidPointer(void *p)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Side effect resource corresponding to the Payload IR itself.
StringRef getName() const override
Return a string name of the resource.
A configuration object for customizing a TrackingListener.
bool requireMatchingReplacementOpName
If set to "true", the name of a replacement op must match the name of the original op.
bool skipCastOps
If set to "true", cast ops (that implement the CastOpInterface) are skipped and the replacement op se...
SkipHandleFn skipHandleFn
An optional function that returns "true" for handles that do not have to be updated.
Side effect resource corresponding to the mapping between Transform IR values and Payload IR operatio...
StringRef getName() const override
Return a string name of the resource.