MLIR
20.0.0git
A listener that updates a TransformState based on IR modifications.
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
Public Member Functions
  TrackingListener (TransformState &state, TransformOpInterface op, TrackingListenerConfig config = TrackingListenerConfig())
      Create a new TrackingListener for usage in the specified transform op.
Public Member Functions inherited from mlir::RewriterBase::Listener
  Listener ()
  virtual void notifyBlockErased (Block *block)
      Notify the listener that the specified block is about to be erased.
  virtual void notifyOperationModified (Operation *op)
      Notify the listener that the specified operation was modified in-place.
  virtual void notifyOperationReplaced (Operation *op, Operation *replacement)
      Notify the listener that all uses of the specified operation's results are about to be replaced with the results of another operation.
  virtual void notifyPatternBegin (const Pattern &pattern, Operation *op)
      Notify the listener that the specified pattern is about to be applied at the specified root operation.
  virtual void notifyPatternEnd (const Pattern &pattern, LogicalResult status)
      Notify the listener that a pattern application finished with the specified status.
Public Member Functions inherited from mlir::OpBuilder::Listener
  Listener ()
  virtual ~Listener () = default
  virtual void notifyOperationInserted (Operation *op, InsertPoint previous)
      Notify the listener that the specified operation was inserted.
  virtual void notifyBlockInserted (Block *block, Region *previous, Region::iterator previousIt)
      Notify the listener that the specified block was inserted.
Public Member Functions inherited from mlir::OpBuilder::ListenerBase
  Kind getKind () const
Public Member Functions inherited from mlir::transform::TransformState::Extension
  virtual ~Extension ()
      Base virtual destructor.
Protected Member Functions
  virtual DiagnosedSilenceableFailure findReplacementOp (Operation *&result, Operation *op, ValueRange newValues) const
      Return a replacement payload op for the given op, which is going to be replaced with the given values.
  void notifyMatchFailure (Location loc, function_ref<void(Diagnostic &)> reasonCallback) override
      Notify the listener that the pattern failed to match the given operation, and provide a callback to populate a diagnostic with the reason why the failure occurred.
  virtual void notifyPayloadReplacementNotFound (Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag)
      This function is called when a tracked payload op is dropped because no replacement op was found.
  TransformOpInterface getTransformOp () const
      Return the transform op in which this TrackingListener is used.
Protected Member Functions inherited from mlir::OpBuilder::Listener
  Listener (Kind kind)
Protected Member Functions inherited from mlir::OpBuilder::ListenerBase
  ListenerBase (Kind kind)
Protected Member Functions inherited from mlir::transform::TransformState::Extension
  Extension (TransformState &state)
      Constructs an extension of the given TransformState object.
  const TransformState & getTransformState () const
      Provides read-only access to the parent TransformState object.
  LogicalResult replacePayloadOp (Operation *op, Operation *replacement)
      Replaces the given payload op with another op.
  LogicalResult replacePayloadValue (Value value, Value replacement)
      Replaces the given payload value with another value.
Static Protected Member Functions
  static Operation * getCommonDefiningOp (ValueRange values)
      Return the single op that defines all given values (if any).
Friends
  class TransformRewriter
Additional Inherited Members
Public Types inherited from mlir::OpBuilder::ListenerBase
  enum class Kind { OpBuilderListener = 0, RewriterBaseListener = 1 }
      The kind of listener.
Static Public Member Functions inherited from mlir::RewriterBase::Listener
  static bool classof (const OpBuilder::Listener *base)
A listener that updates a TransformState based on IR modifications.
This listener can be used during a greedy pattern rewrite to keep the transform state up-to-date.
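To make "keeping the transform state up-to-date" concrete, here is a minimal, self-contained sketch that does not use MLIR at all. ToyOp, ToyMapping, and ToyTrackingListener are hypothetical stand-ins for payload ops, the handle-to-payload mapping, and the listener; the assumption is simply that the state maps each handle to a list of payload ops, and a replacement notification rewrites every entry that refers to the replaced op (or drops it when there is no replacement).

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a payload operation (NOT mlir::Operation).
struct ToyOp {
  std::string name;
};

// Hypothetical stand-in for the handle -> payload-op mapping of a TransformState.
using ToyMapping = std::map<std::string, std::vector<ToyOp *>>;

// Toy listener: on replacement, every handle tracking `op` is redirected to
// `replacement`; a null replacement drops `op` from all handles.
class ToyTrackingListener {
public:
  explicit ToyTrackingListener(ToyMapping &mapping) : mapping(mapping) {}

  void notifyOperationReplaced(ToyOp *op, ToyOp *replacement) {
    for (auto &entry : mapping) {
      std::vector<ToyOp *> &ops = entry.second;
      if (replacement)
        std::replace(ops.begin(), ops.end(), op, replacement);
      else
        ops.erase(std::remove(ops.begin(), ops.end(), op), ops.end());
    }
  }

private:
  ToyMapping &mapping;
};
```

In real code the listener is attached to the rewriter that drives the pattern application; that wiring is not modeled here.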
Definition at line 963 of file TransformInterfaces.h.
transform::TrackingListener::TrackingListener (TransformState &state, TransformOpInterface op, TrackingListenerConfig config = TrackingListenerConfig())
Create a new TrackingListener for usage in the specified transform op.
Optionally, a function can be specified via config to identify handles that do not have to be updated.
Definition at line 1190 of file TransformInterfaces.cpp.
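The optional "skip this handle" function can be pictured as a predicate stored in the config and consulted before a handle is updated. The sketch below is a toy model only: ToyTrackingConfig, its field skipHandleFn, and mustUpdateHandle are hypothetical names chosen for illustration, and handles are modeled as plain strings rather than MLIR values.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical config (NOT the real TrackingListenerConfig layout).
struct ToyTrackingConfig {
  // Returns true for handles that do not have to be updated.
  // A null function means "update every handle".
  std::function<bool(const std::string &handle)> skipHandleFn = nullptr;
};

// Decide whether a handle must be updated after its payload op is replaced.
bool mustUpdateHandle(const ToyTrackingConfig &config,
                      const std::string &handle) {
  return !(config.skipHandleFn && config.skipHandleFn(handle));
}
```

With a default-constructed config every handle is updated; supplying a predicate lets a transform opt out for handles it knows are irrelevant (for example, dead ones).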
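virtual DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp (Operation *&result, Operation *op, ValueRange newValues) const [protected, virtual]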
Return a replacement payload op for the given op, which is going to be replaced with the given values.
By default, if all values are defined by the same op, which also has the same type as the given op, that defining op is used as a replacement.
A "failure" return value indicates that no replacement operation could be found. A "nullptr" return value indicates that no replacement op is needed (e.g., handle is dead or was consumed) and that the payload op should be dropped from the mapping.
Example: A tracked "linalg.generic" with two results is replaced with two values defined by (another) "linalg.generic". It is reasonable to assume that the replacement "linalg.generic" represents the same "computation". Therefore, the payload op mapping is updated to the defining op of the replacement values.
Counter Example: A "linalg.generic" is replaced with values defined by an "scf.for". Without further investigation, the relationship between the "linalg.generic" and the "scf.for" is unclear. They may not represent the same computation; e.g., there may be a tiled "linalg.generic" inside the loop body that represents the original computation. Therefore, the TrackingListener is conservative by default: it drops the mapping and triggers the "payload replacement not found" notification. This default behavior can be customized in TrackingListenerConfig.
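The default rule from the example and counter example above can be sketched as follows. This is a toy model, not the MLIR implementation: ToyOp, ToyValue, and findReplacementByDefaultRule are hypothetical, op "type" is modeled by its name string, and std::nullopt stands for the "failure" outcome. The "nullptr = drop the mapping" outcome (dead or consumed handle) is deliberately not modeled.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;  // e.g. "linalg.generic", "scf.for"
};

struct ToyValue {
  ToyOp *definingOp;  // nullptr models a block argument
};

// Default rule: if every replacement value is defined by one and the same op,
// and that op has the same name as the replaced op, use it as the replacement.
// Otherwise report "no replacement found" (std::nullopt).
std::optional<ToyOp *>
findReplacementByDefaultRule(const ToyOp &op,
                             const std::vector<ToyValue> &newValues) {
  if (newValues.empty())
    return std::nullopt;
  ToyOp *common = newValues.front().definingOp;
  for (const ToyValue &v : newValues)
    if (v.definingOp != common)
      return std::nullopt;
  if (!common || common->name != op.name)
    return std::nullopt;
  return common;
}
```

A two-result "linalg.generic" replaced by two values of another "linalg.generic" is accepted, while values produced by an "scf.for" are rejected, mirroring the conservative default.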
If no replacement op could be found according to the rules mentioned above, this function tries to skip over cast-like ops that implement CastOpInterface.
Example: A tracked "linalg.generic" is replaced with "linalg.generic", wrapped in a "tensor.cast". A cast is a metadata-only operation and it is reasonable to assume that the wrapped "linalg.generic" represents the same computation as the original "linalg.generic". The mapping is updated accordingly.
Certain ops (typically also metadata-only ops) are not considered casts but should be skipped nonetheless. Such ops should implement FindPayloadReplacementOpInterface to specify with which operands the lookup should continue.
Example: A tracked "linalg.generic" is replaced with "linalg.generic", wrapped in a "tensor.reshape". A reshape is a metadata-only operation but not a cast. (Implementing CastOpInterface would be incorrect and cause invalid foldings.) However, due to its FindPayloadReplacementOpInterface implementation, the replacement op lookup continues with the wrapped "linalg.generic" and the mapping is updated accordingly.
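The skipping behavior for casts and "forwarding" ops amounts to a walk that steps through wrapper ops until an op matching the replaced one is found. The sketch below is a toy model under stated assumptions: the isCast flag stands in for CastOpInterface, the forwardedOperands index list stands in for what a FindPayloadReplacementOpInterface implementation would report, and all names (ToyOp, ToyValue, commonDefiningOp, findReplacementSkippingWrappers) are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

struct ToyOp;

struct ToyValue {
  ToyOp *definingOp;  // nullptr models a block argument
};

struct ToyOp {
  std::string name;
  std::vector<ToyValue> operands;
  bool isCast = false;  // models "implements CastOpInterface"
  // Models FindPayloadReplacementOpInterface: indices of the operands with
  // which the lookup should continue (empty = op cannot be skipped).
  std::vector<std::size_t> forwardedOperands;
};

// Returns the single op defining all values, or nullptr.
ToyOp *commonDefiningOp(const std::vector<ToyValue> &values) {
  ToyOp *common = nullptr;
  for (const ToyValue &v : values) {
    if (!v.definingOp || (common && v.definingOp != common))
      return nullptr;
    common = v.definingOp;
  }
  return common;
}

// Step through casts and forwarding ops until an op with the same name as the
// replaced op is found. Returns nullptr if the walk gets stuck. Assumes acyclic
// use-def chains, as in ordinary SSA regions.
ToyOp *findReplacementSkippingWrappers(const ToyOp &op,
                                       std::vector<ToyValue> values) {
  while (true) {
    ToyOp *def = commonDefiningOp(values);
    if (!def)
      return nullptr;
    if (def->name == op.name)
      return def;
    std::vector<ToyValue> next;
    if (def->isCast) {
      next = def->operands;
    } else {
      for (std::size_t idx : def->forwardedOperands)
        next.push_back(def->operands.at(idx));
    }
    if (next.empty())
      return nullptr;
    values = std::move(next);
  }
}
```

Both the "tensor.cast" and the "tensor.reshape" examples resolve to the wrapped "linalg.generic"; an "scf.for", which neither casts nor forwards, still yields no replacement.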
Derived classes may override findReplacementOp to specify custom replacement rules.
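Overriding the lookup follows the usual virtual-function pattern: try the default rule first, then apply a domain-specific fallback. In this toy model (hypothetical ToyListener and ToyLoopAwareListener classes, not MLIR types), nullptr stands for "no replacement found", and the derived class assumes its transform legitimately turns a "linalg.generic" into an "scf.for" that the handle should follow.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
};

struct ToyValue {
  ToyOp *definingOp;
};

class ToyListener {
public:
  virtual ~ToyListener() = default;

  // Conservative default: common defining op with a matching name.
  virtual ToyOp *findReplacementOp(const ToyOp &op,
                                   const std::vector<ToyValue> &newValues) const {
    ToyOp *def = commonDefiningOp(newValues);
    return (def && def->name == op.name) ? def : nullptr;
  }

protected:
  static ToyOp *commonDefiningOp(const std::vector<ToyValue> &values) {
    if (values.empty())
      return nullptr;
    ToyOp *first = values.front().definingOp;
    for (const ToyValue &v : values)
      if (v.definingOp != first)
        return nullptr;
    return first;
  }
};

// Hypothetical derived listener that also accepts an "scf.for" replacement.
class ToyLoopAwareListener : public ToyListener {
public:
  ToyOp *findReplacementOp(const ToyOp &op,
                           const std::vector<ToyValue> &newValues) const override {
    if (ToyOp *def = ToyListener::findReplacementOp(op, newValues))
      return def;
    ToyOp *fallback = commonDefiningOp(newValues);
    return (fallback && fallback->name == "scf.for") ? fallback : nullptr;
  }
};
```

Keeping the default rule as the first step preserves the conservative behavior for every case the fallback does not explicitly handle.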
Definition at line 1217 of file TransformInterfaces.cpp.
References mlir::config, diag(), mlir::emitSilenceableFailure(), mlir::Operation::getLoc(), mlir::Operation::getName(), mlir::Operation::getNumResults(), mlir::Operation::getOperands(), mlir::Operation::hasTrait(), and mlir::DiagnosedSilenceableFailure::success().
static Operation * transform::TrackingListener::getCommonDefiningOp (ValueRange values) [static, protected]
Return the single op that defines all given values (if any).
Definition at line 1201 of file TransformInterfaces.cpp.
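One plausible reading of "the single op that defines all given values (if any)" is sketched below as a toy model. getCommonDefiningOpToy, ToyOp, and ToyValue are hypothetical; the treatment of block arguments (no defining op) and of an empty range as "no common op" is an assumption of this sketch, not a statement about the MLIR implementation.

```cpp
#include <cassert>
#include <vector>

struct ToyOp {};

struct ToyValue {
  ToyOp *definingOp;  // nullptr models a block argument
};

// Returns the op that defines every value, or nullptr if the values come from
// different ops, any value is a block argument, or the range is empty.
ToyOp *getCommonDefiningOpToy(const std::vector<ToyValue> &values) {
  ToyOp *common = nullptr;
  for (const ToyValue &v : values) {
    if (!v.definingOp)
      return nullptr;
    if (!common)
      common = v.definingOp;
    else if (v.definingOp != common)
      return nullptr;
  }
  return common;
}
```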
TransformOpInterface transform::TrackingListener::getTransformOp () const [inline, protected]
Return the transform op in which this TrackingListener is used.
Definition at line 1044 of file TransformInterfaces.h.
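void transform::TrackingListener::notifyMatchFailure (Location loc, function_ref<void(Diagnostic &)> reasonCallback) [override, protected, virtual]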
Notify the listener that the pattern failed to match the given operation, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Reimplemented from mlir::RewriterBase::Listener.
Definition at line 1275 of file TransformInterfaces.cpp.
References DBGS, diag(), and mlir::Remark.
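Passing the failure reason as a callback lets the listener decide whether the diagnostic is worth building at all. The References list (DBGS, mlir::Remark) suggests the reason is only rendered as debug output; the sketch below models that assumption with a simple boolean flag. ToyDiagnostic and ToyMatchFailureListener are hypothetical, and std::function stands in for llvm::function_ref.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical diagnostic (NOT mlir::Diagnostic).
struct ToyDiagnostic {
  std::string location;
  std::string message;
};

class ToyMatchFailureListener {
public:
  explicit ToyMatchFailureListener(bool enabled) : debugEnabled(enabled) {}

  // The reason is only materialized when debug output is enabled, so
  // expensive message construction is skipped otherwise.
  void notifyMatchFailure(const std::string &loc,
                          const std::function<void(ToyDiagnostic &)> &reasonCallback) {
    if (!debugEnabled)
      return;
    ToyDiagnostic diag{loc, ""};
    reasonCallback(diag);
    log.push_back(diag.location + ": " + diag.message);
  }

  std::vector<std::string> log;

private:
  bool debugEnabled;
};
```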
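virtual void transform::TrackingListener::notifyPayloadReplacementNotFound (Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) [inline, protected, virtual]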
This function is called when a tracked payload op is dropped because no replacement op was found.
Derived classes can implement this function for custom error handling.
Reimplemented in mlir::transform::ErrorCheckingTrackingListener.
Definition at line 1037 of file TransformInterfaces.h.
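The customization hook works like any protected virtual callback: the base class invokes it whenever a tracked op is dropped, and a subclass decides what counts as an error. In this toy model, ToyListenerBase assumes a do-nothing default, and ToyErrorCollectingListener is a hypothetical subclass written in the spirit of, but not copied from, ErrorCheckingTrackingListener; diagnostics are reduced to strings.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
};

class ToyListenerBase {
public:
  virtual ~ToyListenerBase() = default;

  // Called when a tracked op is dropped because no replacement was found.
  // Assumed default: do nothing, so the drop is silent.
  virtual void notifyPayloadReplacementNotFound(const ToyOp &,
                                                const std::string &) {}

  // Simulates the point where the listener gives up on a tracked op.
  void dropTrackedOp(const ToyOp &op, const std::string &reason) {
    notifyPayloadReplacementNotFound(op, reason);
  }
};

// Hypothetical subclass that records every dropped op as an error.
class ToyErrorCollectingListener : public ToyListenerBase {
public:
  void notifyPayloadReplacementNotFound(const ToyOp &op,
                                        const std::string &reason) override {
    errors.push_back(op.name + ": " + reason);
  }

  bool hadError() const { return !errors.empty(); }

  std::vector<std::string> errors;
};
```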
friend class TransformRewriter [friend]
Definition at line 1047 of file TransformInterfaces.h.