25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Debug.h"
29 #define GEN_PASS_DEF_LINALGGENERALIZENAMEDOPSPASS
30 #include "mlir/Dialect/Linalg/Passes.h.inc"
33 #define DEBUG_TYPE "linalg-generalization"
42 if (isa<GenericOp>(linalgOp) || isa<MapOp>(linalgOp))
45 if (linalgOp->getNumRegions() != 1) {
46 assert(linalgOp->getNumRegions() == 0 &&
"op with multiple regions");
67 assert(linalgOp->getNumRegions() == 1 &&
68 "expect named op to have one region attached");
69 GenericOp genericOp = rewriter.
create<GenericOp>(
70 linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
72 genericOp.getRegion().begin());
73 rewriter.
replaceOp(linalgOp, genericOp->getResults());
79 struct LinalgGeneralizeNamedOpsPass
80 :
public impl::LinalgGeneralizeNamedOpsPassBase<
81 LinalgGeneralizeNamedOpsPass> {
82 using impl::LinalgGeneralizeNamedOpsPassBase<
83 LinalgGeneralizeNamedOpsPass>::LinalgGeneralizeNamedOpsPassBase;
84 void runOnOperation()
override;
89 void LinalgGeneralizeNamedOpsPass::runOnOperation() {
static LogicalResult generalizeNamedOpPrecondition(LinalgOp linalgOp)
static MLIRContext * getContext(OpFoldResult val)
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the new values must match exactly; the original op is erased.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
This class provides an abstraction over the different types of ranges over Values.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns)
Linalg generalization patterns.
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
Linalg generalization pattern.