MLIR 22.0.0git
For vectors with at least one unit dim, rewrites an elementwise op to operate on vectors with the unit dims dropped, using vector.shape_cast ops to remove them before the op and restore them after it.
Public Member Functions

| Type | Member | Description |
|---|---|---|
| `LogicalResult` | `matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override` | Attempt to match against code rooted at the specified operation, which is the same operation code as `getRootKind()`. |
| | `OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)` | |
Public Member Functions inherited from `mlir::OpTraitRewritePattern<OpTrait::Elementwise>`

| Type | Member | Description |
|---|---|---|
| | `OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)` | |
Public Member Functions inherited from `mlir::RewritePattern`

| Type | Member | Description |
|---|---|---|
| `virtual` | `~RewritePattern() = default` | |
Public Member Functions inherited from `mlir::Pattern`

| Type | Member | Description |
|---|---|---|
| `ArrayRef<OperationName>` | `getGeneratedOps() const` | Return a list of operations that may be generated when rewriting an operation instance with this pattern. |
| `std::optional<OperationName>` | `getRootKind() const` | Return the root node that this pattern matches. |
| `std::optional<TypeID>` | `getRootInterfaceID() const` | Return the interface ID used to match the root operation of this pattern. |
| `std::optional<TypeID>` | `getRootTraitID() const` | Return the trait ID used to match the root operation of this pattern. |
| `PatternBenefit` | `getBenefit() const` | Return the benefit (the inverse of "cost") of matching this pattern. |
| `bool` | `hasBoundedRewriteRecursion() const` | Returns true if this pattern is known to result in recursive application, i.e. this pattern may generate IR that also matches this pattern, but is known to bound the recursion. |
| `MLIRContext *` | `getContext() const` | Return the MLIRContext used to create this pattern. |
| `StringRef` | `getDebugName() const` | Return a readable name for this pattern. |
| `void` | `setDebugName(StringRef name)` | Set the human-readable debug name used for this pattern. |
| `ArrayRef<StringRef>` | `getDebugLabels() const` | Return the set of debug labels attached to this pattern. |
| `void` | `addDebugLabels(ArrayRef<StringRef> labels)` | Add the provided debug labels to this pattern. |
| `void` | `addDebugLabels(StringRef label)` | |
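getBenefit() returns the benefit, the inverse of cost. A rewrite driver typically uses it to order candidate patterns so that higher-benefit patterns are tried first. The sketch below models only that ordering, using a hypothetical ToyPattern struct and firstMatchingPattern function; it is not MLIR's PatternApplicator, and the tie-breaking rule is an assumption of the sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy pattern (not mlir::Pattern): a benefit, a debug name,
// and a predicate deciding whether it matches an op name.
struct ToyPattern {
  unsigned benefit;
  std::string debugName;
  std::function<bool(const std::string &)> matches;
};

// Tries patterns from highest to lowest benefit and returns the debug name
// of the first one that matches `opName`, or "" if none does. stable_sort
// keeps registration order among equal-benefit patterns (a choice of this
// sketch, not a claim about MLIR).
std::string firstMatchingPattern(std::vector<ToyPattern> patterns,
                                 const std::string &opName) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &p : patterns) {
    assert(p.matches && "every toy pattern needs a match predicate");
    if (p.matches(opName))
      return p.debugName;
  }
  return "";
}
```

Registering a catch-all pattern with benefit 1 before a more specific pattern with benefit 2 still lets the specific one win on the ops it handles, because ordering depends on benefit rather than registration order.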
Additional Inherited Members

Public Types inherited from `mlir::OpTraitRewritePattern<OpTrait::Elementwise>`

| Type | Member | Description |
|---|---|---|
| `using` | `Base` | Type alias to allow derived classes to inherit constructors with `using Base::Base;`. |
Static Public Member Functions inherited from `mlir::RewritePattern`

| Type | Member | Description |
|---|---|---|
| `template <typename T, typename... Args> static std::unique_ptr<T>` | `create(Args &&...args)` | This method provides a convenient interface for creating and initializing derived rewrite patterns of the given type `T`. |
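The create<T>() factory combines construction and initialization in a single call. A minimal sketch of that idiom follows, assuming a hypothetical ToyPattern base class and ExamplePattern subclass rather than MLIR's classes. The default-debug-name step is only a stand-in for whatever post-construction initialization the real factory performs.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <typeinfo>
#include <utility>

// Hypothetical pattern base class, loosely modeled on the create<T>() factory
// described above; it is not mlir::RewritePattern.
class ToyPattern {
public:
  virtual ~ToyPattern() = default;

  const std::string &getDebugName() const { return debugName; }
  void setDebugName(std::string name) { debugName = std::move(name); }

  // Construct T from the forwarded arguments, then run post-construction
  // initialization in one place. In this sketch the only extra step is a
  // default debug name (typeid(T).name(), which is implementation-defined)
  // when the constructor did not set one.
  template <typename T, typename... Args>
  static std::unique_ptr<T> create(Args &&...args) {
    auto pattern = std::make_unique<T>(std::forward<Args>(args)...);
    if (pattern->getDebugName().empty())
      pattern->setDebugName(typeid(T).name());
    assert(!pattern->getDebugName().empty());
    return pattern;
  }

private:
  std::string debugName;
};

// Hypothetical derived pattern used to exercise create<T>().
class ExamplePattern : public ToyPattern {
public:
  explicit ExamplePattern(unsigned benefit, std::string name = {})
      : benefit(benefit) {
    if (!name.empty())
      setDebugName(std::move(name));
  }
  unsigned getBenefit() const { return benefit; }

private:
  unsigned benefit;
};
```

Callers write ToyPattern::create<ExamplePattern>(2) instead of std::make_unique<ExamplePattern>(2), so every pattern built through the factory is guaranteed to have gone through the same initialization step.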
Protected Member Functions inherited from `mlir::RewritePattern`

| Type | Member | Description |
|---|---|---|
| | `Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef<StringRef> generatedNames = {})` | Inherit the base constructors from `Pattern`. |
| | `Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, ArrayRef<StringRef> generatedNames = {})` | Inherit the base constructors from `Pattern`. |
| | `Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, PatternBenefit benefit, MLIRContext *context, ArrayRef<StringRef> generatedNames = {})` | Inherit the base constructors from `Pattern`. |
| | `Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, MLIRContext *context, ArrayRef<StringRef> generatedNames = {})` | Inherit the base constructors from `Pattern`. |
Protected Member Functions inherited from `mlir::Pattern`

| Type | Member | Description |
|---|---|---|
| | `Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, ArrayRef<StringRef> generatedNames = {})` | Construct a pattern with a certain benefit that matches the operation with the given root name. |
| | `Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, ArrayRef<StringRef> generatedNames = {})` | Construct a pattern that may match any operation type. |
| | `Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, PatternBenefit benefit, MLIRContext *context, ArrayRef<StringRef> generatedNames = {})` | Construct a pattern that may match any operation that implements the interface defined by the provided `interfaceID`. |
| | `Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, MLIRContext *context, ArrayRef<StringRef> generatedNames = {})` | Construct a pattern that may match any operation that implements the trait defined by the provided `traitID`. |
| `void` | `setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true)` | Set the flag detailing if this pattern has bounded rewrite recursion or not. |
Detailed Description

For vectors with at least one unit dim, replaces:

```
elementwise(a, b)
```

with:

```
sc_a = shape_cast(a)
sc_b = shape_cast(b)
res = elementwise(sc_a, sc_b)
return shape_cast(res)
```

The newly inserted shape_cast ops drop the unit dims before the elementwise op and restore them after it. Only non-scalable unit dims are dropped; a scalable dim such as [1] is kept. Vectors a and b are required to have rank > 1.
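The type computation at the heart of this rewrite can be modeled without MLIR. The sketch below uses a hypothetical ToyVectorType struct in place of mlir::VectorType and a hypothetical dropNonScalableUnitDims function that approximates what the dropNonScalableUnitDimFromType() helper referenced in the matchAndRewrite documentation presumably does. The fallback to vector<1xT> when every dim is dropped is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::VectorType (not the MLIR API): a shape,
// one scalability flag per dim, and an element type name.
struct ToyVectorType {
  std::vector<std::int64_t> shape;
  std::vector<bool> scalable;
  std::string elementType;

  bool operator==(const ToyVectorType &other) const {
    return shape == other.shape && scalable == other.scalable &&
           elementType == other.elementType;
  }

  // Prints MLIR-style syntax, e.g. "vector<1x[4]xf32>".
  std::string str() const {
    std::string s = "vector<";
    for (std::size_t i = 0; i < shape.size(); ++i) {
      std::string dim = std::to_string(shape[i]);
      s += (scalable[i] ? "[" + dim + "]" : dim) + "x";
    }
    return s + elementType + ">";
  }
};

// Drops every unit dim that is not scalable. A scalable [1] dim is kept,
// because it stands for vscale elements rather than exactly one.
// Assumption: if all dims are dropped, fall back to vector<1xT>.
ToyVectorType dropNonScalableUnitDims(const ToyVectorType &t) {
  assert(t.shape.size() == t.scalable.size());
  ToyVectorType out{{}, {}, t.elementType};
  for (std::size_t i = 0; i < t.shape.size(); ++i) {
    if (t.shape[i] == 1 && !t.scalable[i])
      continue;  // fixed unit dim: drop it
    out.shape.push_back(t.shape[i]);
    out.scalable.push_back(t.scalable[i]);
  }
  if (out.shape.empty()) {
    out.shape.push_back(1);
    out.scalable.push_back(false);
  }
  return out;
}

// The preconditions as described above: rank > 1 and at least one fixed
// unit dim to drop (otherwise the rewrite would change nothing).
bool canDropUnitDims(const ToyVectorType &t) {
  return t.shape.size() > 1 && !(dropNonScalableUnitDims(t) == t);
}
```

For instance, the vector<1x[4]xf32> type from the example maps to vector<[4]xf32>, while vector<1x[1]xf32> maps to vector<[1]xf32> because its scalable dim is kept.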
For example:

```mlir
%mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
%cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
```

gets converted to:

```mlir
%B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
%A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
%mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
%cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
%cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
```
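The existing shape_cast folding patterns should then eliminate cast_new and cast: together they form an identity vector<[4]xf32> to vector<[4]xf32> cast, so uses of cast can be replaced by mul directly.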
Definition at line 1960 of file VectorTransforms.cpp.
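Why cast_new and cast disappear can be checked on a toy model. The sketch below uses hypothetical ToyValue, makeValue, makeShapeCast, and foldShapeCasts names and assumes two folding rules: consecutive shape_casts compose into one, and a shape_cast to its source's own type is a no-op. It illustrates the idea only; it is not MLIR's folder.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy SSA value: a name, an MLIR-like type string, and, if the
// value is produced by a vector.shape_cast, the value being cast.
struct ToyValue {
  std::string name;
  std::string type;
  std::shared_ptr<const ToyValue> castSource;  // null unless a shape_cast result
};
using ValuePtr = std::shared_ptr<const ToyValue>;

ValuePtr makeValue(std::string name, std::string type) {
  return std::make_shared<ToyValue>(
      ToyValue{std::move(name), std::move(type), nullptr});
}

ValuePtr makeShapeCast(std::string name, ValuePtr source, std::string type) {
  assert(source != nullptr);
  return std::make_shared<ToyValue>(
      ToyValue{std::move(name), std::move(type), std::move(source)});
}

// Folds the chain of shape_casts ending at `v`:
//  1. shape_cast(shape_cast(x)) -> shape_cast(x): intermediate shapes do not
//     matter, because a shape_cast never reorders elements.
//  2. A shape_cast whose result type equals its source type is a no-op.
ValuePtr foldShapeCasts(const ValuePtr &v) {
  if (!v->castSource)
    return v;  // not a shape_cast, nothing to fold
  ValuePtr root = v->castSource;
  while (root->castSource)
    root = root->castSource;
  if (root->type == v->type)
    return root;  // rule 2: the composed cast is an identity
  if (root == v->castSource)
    return v;  // already a single cast of a non-cast value
  return makeShapeCast(v->name, root, v->type);  // rule 1
}
```

Applied to the rewritten IR in the example, foldShapeCasts on cast returns mul itself, since cast_new and cast compose to an identity cast.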
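Member Function Documentation

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const [inline] [override] [virtual]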
Attempt to match against code rooted at the specified operation, which is the same operation code as getRootKind().
If successful, perform the rewrite.
Note: Implementations must modify the IR if and only if the function returns "success".
Implements mlir::RewritePattern.
Definition at line 1963 of file VectorTransforms.cpp.
References mlir::OpBuilder::create(), dropNonScalableUnitDimFromType(), mlir::Operation::getAttrs(), mlir::OperationName::getIdentifier(), mlir::Operation::getLoc(), mlir::Operation::getName(), mlir::Operation::getNumRegions(), mlir::Operation::getNumResults(), mlir::Operation::getOperand(), mlir::Operation::getOperands(), mlir::Operation::getResult(), mlir::Value::getType(), mlir::RewriterBase::notifyMatchFailure(), mlir::RewriterBase::replaceOpWithNewOp(), and success().
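The matchAndRewrite contract (modify the IR if and only if returning success) and the steps in the class description can be sketched on toy data. Everything here is hypothetical: ToyType, ToyOp, and toyMatchAndRewrite stand in for MLIR's VectorType, Operation, and the real method, and the "IR" is just a list of strings. The checks loosely mirror the referenced getNumResults(), getNumRegions(), and the rank > 1 requirement; they are not a transcript of VectorTransforms.cpp.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy types standing in for MLIR's VectorType and Operation.
// A dim is a string: "4" is fixed, "[4]" is scalable, "1" is a fixed unit dim.
struct ToyType {
  std::vector<std::string> dims;  // empty for a scalar
  std::string elem;               // element type, e.g. "f32"
};

struct ToyOp {
  std::string name;                   // e.g. "arith.mulf"
  std::vector<std::string> operands;  // SSA names, e.g. "%B_row"
  std::vector<ToyType> operandTypes;
  std::vector<ToyType> resultTypes;
  std::size_t numRegions = 0;
};

std::string str(const ToyType &t) {
  std::string s = "vector<";
  for (const auto &d : t.dims)
    s += d + "x";
  return s + t.elem + ">";
}

// Drop fixed (non-scalable) unit dims, keeping at least one dim.
ToyType dropUnitDims(const ToyType &t) {
  ToyType out{{}, t.elem};
  for (const auto &d : t.dims)
    if (d != "1")
      out.dims.push_back(d);
  if (out.dims.empty())
    out.dims.push_back("1");
  return out;
}

// Appends the rewritten IR to `ir` and returns true, or returns false and
// leaves `ir` untouched.
bool toyMatchAndRewrite(const ToyOp &op, const std::string &result,
                        std::vector<std::string> &ir) {
  // Match phase: inspect only, never modify.
  if (op.resultTypes.size() != 1 || op.numRegions != 0)
    return false;
  const ToyType &resTy = op.resultTypes[0];
  if (resTy.dims.empty() || op.operandTypes.empty())
    return false;  // scalar result or no operands
  if (op.operandTypes[0].dims.size() < 2)
    return false;  // rank > 1 required
  std::vector<ToyType> newTypes;
  for (const ToyType &t : op.operandTypes) {
    ToyType nt = dropUnitDims(t);
    if (t.dims.empty() || nt.dims == t.dims)
      return false;  // scalar operand, or no unit dim to remove
    newTypes.push_back(nt);
  }
  assert(op.operands.size() == op.operandTypes.size());

  // Rewrite phase: shape_cast each operand, recreate the op, restore dims.
  std::string newOperands;
  for (std::size_t i = 0; i < op.operands.size(); ++i) {
    std::string sc = op.operands[i] + "_sc";
    ir.push_back(sc + " = vector.shape_cast " + op.operands[i] + " : " +
                 str(op.operandTypes[i]) + " to " + str(newTypes[i]));
    newOperands += (i ? ", " : "") + sc;
  }
  ToyType newResTy = dropUnitDims(resTy);
  std::string newResult = result + "_sc";
  ir.push_back(newResult + " = " + op.name + " " + newOperands + " : " +
               str(newResTy));
  ir.push_back(result + " = vector.shape_cast " + newResult + " : " +
               str(newResTy) + " to " + str(resTy));
  return true;
}
```

The sketch finishes every check before emitting any line, which is the simplest way to guarantee that a failed match leaves the IR untouched.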
[inline]
Definition at line 354 of file PatternMatch.h.