19#include "llvm/ADT/SmallVector.h"
22#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
23#include "mlir/Dialect/Linalg/Passes.h.inc"
29#define DEBUG_TYPE "linalg-fold-into-elementwise"
32template <
typename ProducerOpTy>
33struct ElementwiseOpFolder {
42 newIns.push_back(producerOp.getInput());
45 producerOp.getMatchingIndexingMap(producerOp.getDpsInputOperand(0))
51template <
typename... ProducerOps>
55 LogicalResult matchAndRewrite(ElementwiseOp op,
60 for (
OpOperand *operand : op.getDpsInputOperands()) {
61 AffineMap consumerMap = op.getMatchingIndexingMap(operand);
62 const bool folded = (ElementwiseOpFolder<ProducerOps>::fold(
63 operand, consumerMap, newIns, newMaps) ||
69 newIns.push_back(operand->get());
70 newMaps.push_back(consumerMap);
75 newMaps.push_back(op.getIndexingMapsArray().back());
78 op, newIns, op.getDpsInits()[0], op.getKindAttr(),
84struct LinalgFoldIntoElementwisePass
85 :
public impl::LinalgFoldIntoElementwisePassBase<
86 LinalgFoldIntoElementwisePass> {
87 using impl::LinalgFoldIntoElementwisePassBase<
88 LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
90 void runOnOperation()
override {
96 return signalPassFailure();
103 patterns.
add<FoldIntoElementwisePattern<TransposeOp, BroadcastOp>>(
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.transpose (and linalg.broadcast) into the indexing maps of elementwise ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...