MLIR 23.0.0git
FoldIntoElementwise.cpp
Go to the documentation of this file.
1//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements folding ops such as transpose and broadcast into the
10// affine maps of the elementwise op.
11//
12//===----------------------------------------------------------------------===//
13
19#include "llvm/ADT/SmallVector.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
23#include "mlir/Dialect/Linalg/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::linalg;
28
29#define DEBUG_TYPE "linalg-fold-into-elementwise"
30
31namespace {
32template <typename ProducerOpTy>
33struct ElementwiseOpFolder {
34 // Helper function to fold broadcast etc into elementwise op.
35 // Producer in this context is `broadcast op` etc, consumer is elwise operand.
36 static bool fold(OpOperand *elwiseOperand, AffineMap elwiseMap,
37 SmallVector<Value> &newIns,
38 SmallVector<AffineMap> &newMaps) {
39 auto producerOp = elwiseOperand->get().getDefiningOp<ProducerOpTy>();
40 if (!producerOp || !elwiseMap.isProjectedPermutation())
41 return false;
42 newIns.push_back(producerOp.getInput());
43 // push in the new composed affine map
44 newMaps.push_back(
45 producerOp.getMatchingIndexingMap(producerOp.getDpsInputOperand(0))
46 .compose(elwiseMap));
47 return true;
48 }
49};
50
51template <typename... ProducerOps>
52struct FoldIntoElementwisePattern : public OpRewritePattern<ElementwiseOp> {
53 using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
54
55 LogicalResult matchAndRewrite(ElementwiseOp op,
56 PatternRewriter &rewriter) const override {
57 bool changed = false;
58 SmallVector<Value> newIns;
60 for (OpOperand *operand : op.getDpsInputOperands()) {
61 AffineMap consumerMap = op.getMatchingIndexingMap(operand);
62 const bool folded = (ElementwiseOpFolder<ProducerOps>::fold(
63 operand, consumerMap, newIns, newMaps) ||
64 ...);
65 if (folded) {
66 changed = true;
67 } else {
68 // push in original operand and its map.
69 newIns.push_back(operand->get());
70 newMaps.push_back(consumerMap);
71 }
72 }
73 if (!changed)
74 return failure();
75 newMaps.push_back(op.getIndexingMapsArray().back());
76
77 rewriter.replaceOpWithNewOp<ElementwiseOp>(
78 op, newIns, op.getDpsInits()[0], op.getKindAttr(),
79 rewriter.getAffineMapArrayAttr(newMaps));
80 return success();
81 }
82};
83
84struct LinalgFoldIntoElementwisePass
85 : public impl::LinalgFoldIntoElementwisePassBase<
86 LinalgFoldIntoElementwisePass> {
87 using impl::LinalgFoldIntoElementwisePassBase<
88 LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
89
90 void runOnOperation() override {
91 Operation *op = getOperation();
92 RewritePatternSet patterns(op->getContext());
94
95 if (failed(applyPatternsGreedily(op, std::move(patterns))))
96 return signalPassFailure();
97 }
98};
99} // namespace
100
102 RewritePatternSet &patterns) {
103 patterns.add<FoldIntoElementwisePattern<TransposeOp, BroadcastOp>>(
104 patterns.getContext());
105}
return success()
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
IRValueT get() const
Return the current value being used by this operand.
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns)
Populates `patterns` with patterns that fold operations like linalg.transpose and linalg.broadcast into the indexing maps of an elementwise op.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...