MLIR  20.0.0git
mlir::transform::TrackingListener Class Reference

A listener that updates a TransformState based on IR modifications.

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"


Public Member Functions

 TrackingListener (TransformState &state, TransformOpInterface op, TrackingListenerConfig config=TrackingListenerConfig())
 Create a new TrackingListener for usage in the specified transform op.
 
- Public Member Functions inherited from mlir::RewriterBase::Listener
 Listener ()
 
virtual void notifyBlockErased (Block *block)
 Notify the listener that the specified block is about to be erased.
 
virtual void notifyOperationModified (Operation *op)
 Notify the listener that the specified operation was modified in-place.
 
virtual void notifyOperationReplaced (Operation *op, Operation *replacement)
 Notify the listener that all uses of the specified operation's results are about to be replaced with the results of another operation.
 
virtual void notifyPatternBegin (const Pattern &pattern, Operation *op)
 Notify the listener that the specified pattern is about to be applied at the specified root operation.
 
virtual void notifyPatternEnd (const Pattern &pattern, LogicalResult status)
 Notify the listener that a pattern application finished with the specified status.
 
- Public Member Functions inherited from mlir::OpBuilder::Listener
 Listener ()
 
virtual ~Listener ()=default
 
virtual void notifyOperationInserted (Operation *op, InsertPoint previous)
 Notify the listener that the specified operation was inserted.
 
virtual void notifyBlockInserted (Block *block, Region *previous, Region::iterator previousIt)
 Notify the listener that the specified block was inserted.
 
- Public Member Functions inherited from mlir::OpBuilder::ListenerBase
Kind getKind () const
 
- Public Member Functions inherited from mlir::transform::TransformState::Extension
virtual ~Extension ()
 Base virtual destructor.
 

Protected Member Functions

virtual DiagnosedSilenceableFailure findReplacementOp (Operation *&result, Operation *op, ValueRange newValues) const
 Return a replacement payload op for the given op, which is going to be replaced with the given values.
 
void notifyMatchFailure (Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
 Notify the listener that the pattern failed to match the given operation, and provide a callback to populate a diagnostic with the reason why the failure occurred.
 
virtual void notifyPayloadReplacementNotFound (Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag)
 This function is called when a tracked payload op is dropped because no replacement op was found.
 
TransformOpInterface getTransformOp () const
 Return the transform op in which this TrackingListener is used.
 
- Protected Member Functions inherited from mlir::OpBuilder::Listener
 Listener (Kind kind)
 
- Protected Member Functions inherited from mlir::OpBuilder::ListenerBase
 ListenerBase (Kind kind)
 
- Protected Member Functions inherited from mlir::transform::TransformState::Extension
 Extension (TransformState &state)
 Constructs an extension of the given TransformState object.
 
const TransformState & getTransformState () const
 Provides read-only access to the parent TransformState object.
 
LogicalResult replacePayloadOp (Operation *op, Operation *replacement)
 Replaces the given payload op with another op.
 
LogicalResult replacePayloadValue (Value value, Value replacement)
 Replaces the given payload value with another value.
 

Static Protected Member Functions

static Operation * getCommonDefiningOp (ValueRange values)
 Return the single op that defines all given values (if any).
 

Friends

class TransformRewriter
 

Additional Inherited Members

- Public Types inherited from mlir::OpBuilder::ListenerBase
enum class  Kind { OpBuilderListener = 0 , RewriterBaseListener = 1 }
 The kind of listener.
 
- Static Public Member Functions inherited from mlir::RewriterBase::Listener
static bool classof (const OpBuilder::Listener *base)
 

Detailed Description

A listener that updates a TransformState based on IR modifications.

This listener can be used during a greedy pattern rewrite to keep the transform state up-to-date.

Definition at line 963 of file TransformInterfaces.h.
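To make the bookkeeping concrete, here is a minimal C++ sketch of what "keeping the transform state up-to-date" means. It is a simplified model under stated assumptions, not the MLIR implementation: Op, MiniTransformState, and MiniTrackingListener are hypothetical stand-ins, the state is assumed to be a plain map from handle names to payload ops, and only the replacement notification is modeled.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a payload operation (not mlir::Operation).
struct Op {
  std::string name;
};

// Hypothetical, simplified transform state: each handle maps to payload ops.
struct MiniTransformState {
  std::map<std::string, std::vector<Op *>> handles;
};

// Hypothetical listener: on replacement, every handle that tracked `op` now
// tracks `replacement`; a null replacement drops `op` from the mapping.
struct MiniTrackingListener {
  MiniTransformState &state;

  void notifyOperationReplaced(Op *op, Op *replacement) {
    for (auto &entry : state.handles) {
      std::vector<Op *> updated;
      for (Op *tracked : entry.second) {
        if (tracked != op)
          updated.push_back(tracked);      // unrelated op: keep as-is
        else if (replacement)
          updated.push_back(replacement);  // tracked op: swap in replacement
        // otherwise: no replacement, so the op is dropped
      }
      entry.second = std::move(updated);
    }
  }
};
```

In this model, a handle that tracked a replaced op ends up tracking the replacement, and a null replacement removes the op from every handle. The real listener updates the TransformState through its own API (replacePayloadOp is listed above among the inherited members) rather than editing a map directly.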

Constructor & Destructor Documentation

◆ TrackingListener()

transform::TrackingListener::TrackingListener ( TransformState & state,
TransformOpInterface  op,
TrackingListenerConfig  config = TrackingListenerConfig() 
)

Create a new TrackingListener for usage in the specified transform op.

Optionally, a function can be specified (through config) to identify handles that do not have to be updated.

Definition at line 1190 of file TransformInterfaces.cpp.
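One plausible way such a skip predicate can be used is sketched below, assuming that a replacement only has to be found when at least one handle mapped to the replaced op is not skipped (for example, because all remaining handles were consumed). Handle, MiniConfig, skipHandleFn, and needsReplacement are hypothetical names chosen for illustration; the actual fields of TrackingListenerConfig are not shown on this page.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical handle model: a name plus a flag recording whether the handle
// was consumed by an earlier transform op.
struct Handle {
  std::string name;
  bool consumed;
};

// Hypothetical configuration carrying an optional "skip this handle" predicate.
struct MiniConfig {
  std::function<bool(const Handle &)> skipHandleFn;
};

// Sketch: a replacement op has to be found only if some handle that maps to
// the replaced op is not skipped; otherwise the op can simply be dropped.
bool needsReplacement(const std::vector<Handle> &handlesOfOp,
                      const MiniConfig &config) {
  for (const Handle &h : handlesOfOp)
    if (!config.skipHandleFn || !config.skipHandleFn(h))
      return true;
  return false;
}
```

Without a predicate, every handle counts as live, so any tracked op requires a replacement search in this model.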

Member Function Documentation

◆ findReplacementOp()

DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp ( Operation *&  result,
Operation * op,
ValueRange  newValues 
) const
protected virtual

Return a replacement payload op for the given op, which is going to be replaced with the given values.

By default, if all values are defined by the same op, and that op also has the same type (i.e., the same operation name) as the given op, the defining op is used as the replacement.

A failure return value indicates that no replacement operation could be found. A success return value with result set to nullptr indicates that no replacement op is needed (e.g., the handle is dead or was consumed) and that the payload op should be dropped from the mapping.

Example: A tracked "linalg.generic" with two results is replaced with two values defined by (another) "linalg.generic". It is reasonable to assume that the replacement "linalg.generic" represents the same "computation". Therefore, the payload op mapping is updated to the defining op of the replacement values.

Counter-example: A "linalg.generic" is replaced with values defined by an "scf.for". Without further investigation, the relationship between the "linalg.generic" and the "scf.for" is unclear. They may not represent the same computation; e.g., there may be a tiled "linalg.generic" inside the loop body that represents the original computation. Therefore, the TrackingListener is conservative by default: it drops the mapping and triggers the "payload replacement not found" notification. This default behavior can be customized in TrackingListenerConfig.
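The default rule and the counter-example can be modeled with a small C++ sketch. It assumes a toy IR in which each value records its defining op, and it only distinguishes "found" from "not found"; the case in which no replacement is needed (result set to nullptr on success) is not modeled. Op, Value, LookupStatus, MiniConfig, and requireMatchingName are hypothetical; requireMatchingName merely stands in for the kind of customization TrackingListenerConfig offers.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins: an op has a name; a value records its defining op
// (nullptr models a block argument).
struct Op {
  std::string name;
};
struct Value {
  Op *definingOp;
};

// Hypothetical outcome of the lookup.
enum class LookupStatus { Found, NotFound };

// Hypothetical flag standing in for a TrackingListenerConfig customization.
struct MiniConfig {
  bool requireMatchingName = true;
};

// Sketch of the default rule: all new values must be defined by one op, and
// (unless relaxed) that op must have the same name as the replaced op.
LookupStatus findReplacement(Op *&result, const Op &op,
                             const std::vector<Value> &newValues,
                             const MiniConfig &config = MiniConfig()) {
  result = nullptr;
  Op *common = nullptr;
  for (const Value &v : newValues) {
    if (!v.definingOp || (common && v.definingOp != common))
      return LookupStatus::NotFound;  // block argument or mixed defining ops
    common = v.definingOp;
  }
  if (!common)
    return LookupStatus::NotFound;  // no values at all
  if (config.requireMatchingName && common->name != op.name)
    return LookupStatus::NotFound;  // e.g. linalg.generic -> scf.for
  result = common;
  return LookupStatus::Found;
}
```

With the default config, the "scf.for" replacement is rejected, matching the conservative behavior described above; relaxing the name check accepts it.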

If no replacement op could be found according to the rules mentioned above, this function tries to skip over cast-like ops that implement CastOpInterface.

Example: A tracked "linalg.generic" is replaced with a "linalg.generic" wrapped in a "tensor.cast". A cast is a metadata-only operation, and it is reasonable to assume that the wrapped "linalg.generic" represents the same computation as the original "linalg.generic". The mapping is updated accordingly.
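Looking through casts can be sketched as a loop that follows a cast's source operand until an op with the expected name is reached. This is a simplified model, assuming single-operand casts and a toy IR; Op, Value, isCast (standing in for implementing CastOpInterface), and findThroughCasts are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct Op;
// Hypothetical value model: a null definingOp models a block argument.
struct Value {
  Op *definingOp;
};
// Hypothetical op model; `isCast` stands in for implementing CastOpInterface.
struct Op {
  std::string name;
  bool isCast = false;
  std::vector<Value> operands;
};

// Sketch: accept the defining op if its name matches; otherwise, if it is a
// single-operand cast, continue the search with the cast's source value.
Op *findThroughCasts(const Op &original, Value v) {
  while (Op *def = v.definingOp) {
    if (def->name == original.name)
      return def;
    if (!def->isCast || def->operands.size() != 1)
      return nullptr;  // not a cast we can look through
    v = def->operands[0];
  }
  return nullptr;  // reached a block argument
}
```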

Certain ops (typically also metadata-only ops) are not considered casts, but should be skipped nonetheless. Such ops should implement FindPayloadReplacementOpInterface to specify with which operands the lookup should continue.

Example: A tracked "linalg.generic" is replaced with a "linalg.generic" wrapped in a "tensor.reshape". A reshape is a metadata-only operation but not a cast. (Implementing CastOpInterface would be incorrect and cause invalid foldings.) However, due to its FindPayloadReplacementOpInterface implementation, the replacement op lookup continues with the wrapped "linalg.generic", and the mapping is updated accordingly.
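The same walk can be generalized so that an op, rather than a fixed cast rule, decides which operands the lookup continues with. In the hypothetical sketch below, the nextValues hook loosely stands in for FindPayloadReplacementOpInterface; its shape is an assumption and does not reproduce the interface's actual methods.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

struct Op;
// Hypothetical value model: a null definingOp models a block argument.
struct Value {
  Op *definingOp;
};
// Hypothetical op model. When `nextValues` is set, it returns the values with
// which the replacement lookup should continue.
struct Op {
  std::string name;
  std::vector<Value> operands;
  std::function<std::vector<Value>(const Op &)> nextValues;
};

// Sketch: the values must share one defining op; if its name does not match,
// continue through the op's hook (if any) instead of giving up.
Op *lookThrough(const std::string &expectedName, std::vector<Value> values) {
  while (!values.empty()) {
    Op *common = values.front().definingOp;
    for (const Value &v : values)
      if (!common || v.definingOp != common)
        return nullptr;
    if (common->name == expectedName)
      return common;
    if (!common->nextValues)
      return nullptr;  // no hook: the lookup stops here
    values = common->nextValues(*common);
  }
  return nullptr;
}
```

For a "tensor.reshape" with a source operand and a shape operand, such a hook would return only the source operand, so the lookup reaches the wrapped "linalg.generic" without having to treat the reshape as a cast.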

Derived classes may override findReplacementOp to specify custom replacement rules.

Definition at line 1217 of file TransformInterfaces.cpp.

References diag(), mlir::emitSilenceableFailure(), mlir::Operation::getLoc(), mlir::Operation::getName(), mlir::Operation::getNumResults(), mlir::Operation::getOperands(), mlir::Operation::hasTrait(), and mlir::DiagnosedSilenceableFailure::success().

◆ getCommonDefiningOp()

Operation * transform::TrackingListener::getCommonDefiningOp ( ValueRange  values)
static protected

Return the single op that defines all given values (if any).

Definition at line 1201 of file TransformInterfaces.cpp.
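A direct reading of this description as C++ is sketched below, using a toy IR in which a value with no defining op models a block argument. It is an illustrative assumption, not the MLIR source; in particular, the real function may treat null or empty values differently. Op, Value, and getCommonDefiningOpSketch are hypothetical.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins: a value records its defining op, with nullptr
// modelling a block argument.
struct Op {};
struct Value {
  Op *definingOp;
};

// Sketch of "the single op that defines all given values (if any)": returns
// nullptr for an empty range, a block argument, or values from different ops.
Op *getCommonDefiningOpSketch(const std::vector<Value> &values) {
  Op *common = nullptr;
  for (const Value &v : values) {
    if (!v.definingOp)
      return nullptr;
    if (!common)
      common = v.definingOp;
    else if (v.definingOp != common)
      return nullptr;
  }
  return common;
}
```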

◆ getTransformOp()

TransformOpInterface mlir::transform::TrackingListener::getTransformOp ( ) const
inline protected

Return the transform op in which this TrackingListener is used.

Definition at line 1044 of file TransformInterfaces.h.

◆ notifyMatchFailure()

void transform::TrackingListener::notifyMatchFailure ( Location  loc,
function_ref< void(Diagnostic &)>  reasonCallback 
)
override protected virtual

Notify the listener that the pattern failed to match the given operation, and provide a callback to populate a diagnostic with the reason why the failure occurred.

Reimplemented from mlir::RewriterBase::Listener.

Definition at line 1275 of file TransformInterfaces.cpp.

References DBGS, diag(), and mlir::Remark.
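Passing a callback instead of a finished message lets the reason be built only when it is actually used. The sketch below assumes, based on the References line above (DBGS, mlir::Remark), that the reason is only formatted for debug output; MiniDiagnostic, debugOutputEnabled, and notifyMatchFailureSketch are hypothetical, and std::function stands in for the non-owning function_ref.

```cpp
#include <cassert>
#include <functional>
#include <iostream>
#include <string>

// Hypothetical stand-in for mlir::Diagnostic that only accumulates text.
struct MiniDiagnostic {
  std::string text;
  MiniDiagnostic &operator<<(const std::string &s) {
    text += s;
    return *this;
  }
};

// Hypothetical switch standing in for debug-only output.
bool debugOutputEnabled = false;

// Sketch: the reason is built lazily, so the callback runs only when the
// message will actually be printed.
void notifyMatchFailureSketch(
    const std::string &loc,
    const std::function<void(MiniDiagnostic &)> &reasonCallback) {
  if (!debugOutputEnabled)
    return;
  MiniDiagnostic diag;
  reasonCallback(diag);
  std::cerr << "Match failure at " << loc << ": " << diag.text << "\n";
}
```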

◆ notifyPayloadReplacementNotFound()

virtual void mlir::transform::TrackingListener::notifyPayloadReplacementNotFound ( Operation * op,
ValueRange  values,
DiagnosedSilenceableFailure &&  diag 
)
inline protected virtual

This function is called when a tracked payload op is dropped because no replacement op was found.

Derived classes can override this function for custom error handling.

Reimplemented in mlir::transform::ErrorCheckingTrackingListener.

Definition at line 1037 of file TransformInterfaces.h.
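A derived listener that records failures instead of ignoring them could look like the following sketch. In this model the base hook does nothing, and the DiagnosedSilenceableFailure argument is replaced with a plain string; MiniTrackingListener, RecordingTrackingListener, and dropTrackedOp are hypothetical and do not reproduce ErrorCheckingTrackingListener.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a payload op.
struct Op {
  std::string name;
};

// Hypothetical base listener; `reason` stands in for the
// DiagnosedSilenceableFailure passed to the real hook.
class MiniTrackingListener {
public:
  virtual ~MiniTrackingListener() = default;

  // Called by the (omitted) replacement logic when no replacement was found.
  void dropTrackedOp(Op *op, const std::string &reason) {
    notifyPayloadReplacementNotFound(op, reason);
  }

protected:
  // Default behavior in this sketch: ignore the dropped op.
  virtual void notifyPayloadReplacementNotFound(Op *, const std::string &) {}
};

// Hypothetical derived listener that records every failure.
class RecordingTrackingListener : public MiniTrackingListener {
public:
  std::vector<std::string> errors;

protected:
  void notifyPayloadReplacementNotFound(Op *op,
                                        const std::string &reason) override {
    errors.push_back(op->name + ": " + reason);
  }
};
```

Keeping the hook protected and virtual, as the real class does, lets the replacement logic stay in the base class while each derived listener decides whether a dropped op is an error.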

Friends And Related Function Documentation

◆ TransformRewriter

friend class TransformRewriter
friend

Definition at line 1047 of file TransformInterfaces.h.


The documentation for this class was generated from the following files: