MLIR 22.0.0git
FoldIntoElementwise.cpp
Go to the documentation of this file.
1//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements folding ops such as transpose and broadcast into the
10// affine maps of the elementwise op.
11//
12//===----------------------------------------------------------------------===//
13
19#include "llvm/ADT/SmallVector.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
23#include "mlir/Dialect/Linalg/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::linalg;
28
29#define DEBUG_TYPE "linalg-fold-into-elementwise"
30
31namespace {
32struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
33 using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
34
35 LogicalResult matchAndRewrite(ElementwiseOp op,
36 PatternRewriter &rewriter) const override {
37 bool changed = false;
38 SmallVector<Value> newIns;
40 for (OpOperand *operand : op.getDpsInputOperands()) {
41 AffineMap map = op.getMatchingIndexingMap(operand);
42 auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
43
44 if (!map.isIdentity() || !transposeOp) {
45 // push in original operand and its map.
46 newIns.push_back(operand->get());
47 newMaps.push_back(map);
48 continue;
49 }
50 newIns.push_back(transposeOp.getInput());
51 // push in transposeOp's inverse permutation map.
52 newMaps.push_back(transposeOp.getMatchingIndexingMap(
53 transposeOp.getDpsInputOperand(0)));
54 changed = true;
55 }
56 if (!changed)
57 return failure();
58 newMaps.push_back(op.getIndexingMapsArray().back());
59
60 rewriter.replaceOpWithNewOp<ElementwiseOp>(
61 op, newIns, op.getDpsInits()[0], op.getKindAttr(),
62 rewriter.getAffineMapArrayAttr(newMaps));
63 return success();
64 }
65};
66
67struct LinalgFoldIntoElementwisePass
69 LinalgFoldIntoElementwisePass> {
71 LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
72
73 void runOnOperation() override {
74 Operation *op = getOperation();
77
78 if (failed(applyPatternsGreedily(op, std::move(patterns))))
79 return signalPassFailure();
80 }
81};
82} // namespace
83
86 patterns.add<FoldTransposePattern>(patterns.getContext());
87}
return success()
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isIdentity() const
Returns true if this affine map is an identity affine map.
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:318
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.transpose into the indexing maps of an elementwise op.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...