MLIR  21.0.0git
FoldIntoElementwise.cpp
Go to the documentation of this file.
1 //===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements folding ops such as transpose and broadcast into the
10 // affine maps of the elementwise op.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
24 #include "mlir/Dialect/Linalg/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 #define DEBUG_TYPE "linalg-fold-into-elementwise"
31 
32 namespace {
33 struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
35 
36  LogicalResult matchAndRewrite(ElementwiseOp op,
37  PatternRewriter &rewriter) const override {
38  bool changed = false;
39  SmallVector<Value> newIns;
40  SmallVector<AffineMap> newMaps;
41  for (OpOperand *operand : op.getDpsInputOperands()) {
42  AffineMap map = op.getMatchingIndexingMap(operand);
43  auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
44 
45  if (!map.isIdentity() || !transposeOp) {
46  // push in original operand and its map.
47  newIns.push_back(operand->get());
48  newMaps.push_back(map);
49  continue;
50  }
51  newIns.push_back(transposeOp.getInput());
52  // push in transposeOp's inverse permutation map.
53  newMaps.push_back(transposeOp.getMatchingIndexingMap(
54  transposeOp.getDpsInputOperand(0)));
55  changed = true;
56  }
57  if (!changed)
58  return failure();
59  newMaps.push_back(op.getIndexingMapsArray().back());
60 
61  rewriter.replaceOpWithNewOp<ElementwiseOp>(
62  op, newIns, op.getDpsInits()[0], op.getKindAttr(),
63  rewriter.getAffineMapArrayAttr(newMaps));
64  return success();
65  }
66 };
67 
68 struct LinalgFoldIntoElementwisePass
69  : public impl::LinalgFoldIntoElementwisePassBase<
70  LinalgFoldIntoElementwisePass> {
71  using impl::LinalgFoldIntoElementwisePassBase<
72  LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
73 
74  void runOnOperation() override {
75  llvm::outs() << "Hellow from fold into elemenwise \n";
76  Operation *op = getOperation();
79 
80  if (failed(applyPatternsGreedily(op, std::move(patterns))))
81  return signalPassFailure();
82  }
83 };
84 } // namespace
85 
88  patterns.add<FoldTransposePattern>(patterns.getContext());
89 }
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isIdentity() const
Returns true if this affine map is an identity affine map.
Definition: AffineMap.cpp:345
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.transpose into the indexing maps of the elementwise op.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358