MLIR  20.0.0git
FoldAddIntoDest.cpp
Go to the documentation of this file.
1 //===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Dominance.h"
14 
15 using namespace mlir;
16 
17 // Determine whether the value is defined to be zero.
18 static bool isDefinedAsZero(Value val) {
19  if (!val)
20  return false;
21 
22  // Check whether val is a constant scalar / vector splat / tensor splat float
23  // or integer zero.
24  if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
25  return true;
26 
28  .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
29  return op && op.getInputs().size() == 1 &&
30  isDefinedAsZero(op.getInputs()[0]);
31  })
32  .Default([&](auto) { return false; });
33 }
34 
35 /// Replace a linalg.add with one operand the single user of a contraction,
36 /// which has a zero-filled, "identity-mapped" destination and is dominated by
37 /// the `other` operand, by the contraction with `other` as its dest.
38 ///
39 /// As an example, the following pseudo-code will be rewritten
40 /// %cst = arith.constant 0.000000e+00
41 /// %empty = tensor.empty()
42 /// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
43 /// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
44 /// %empty2 = tensor.empty()
45 /// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
46 /// %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
47 /// %out = linalg.add ins(%C, %F) outs(%empty)
48 /// to:
49 /// %cst = arith.constant 0.000000e+00
50 /// %empty = tensor.empty()
51 /// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
52 /// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
53 /// %out = linalg.matmul ins(%D, %E) outs(%C)
54 ///
55 struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
57 
58  LogicalResult matchAndRewrite(linalg::AddOp addOp,
59  PatternRewriter &rewriter) const override {
60  // For now, pattern only applies on tensor types (memref support is TODO).
61  if (!addOp.hasPureTensorSemantics())
62  return failure();
63 
64  Value dominatingOperand = nullptr;
65  linalg::LinalgOp dominatedOp = nullptr;
66  { // We will forget about which operand was left or right after this block.
67  Value lhs = addOp.getInputs()[0];
68  Value rhs = addOp.getInputs()[1];
69 
70  // Can only put one of addOp's operands in the dest/out arg of the other's
71  // defining op based on suitable dominance.
72  // TODO: Can be generalized to move ops around as long as that still
73  // respects use-def chains and doesn't affect side-effects.
74  if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) {
75  DominanceInfo domInfo(rhsOp);
76  if (domInfo.properlyDominates(lhs, rhsOp)) {
77  dominatingOperand = lhs;
78  dominatedOp = rhsOp;
79  }
80  }
81  if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) {
82  DominanceInfo domInfo(lhsOp);
83  if (domInfo.properlyDominates(rhs, lhsOp)) {
84  dominatingOperand = rhs;
85  dominatedOp = lhsOp;
86  }
87  }
88  if (!dominatingOperand || !dominatedOp)
89  return failure();
90  // NB: As linalg.add's generalisation ignores the out argument in its
91  // region there is no need to perform checks on addOp's out argument.
92  }
93 
94  // When dominated op is a contraction we know it accumulates on its out arg.
95  // E.g., AddOp is not a contraction and hence ignores its out arg's value.
96  // TODO: Generalize check to also pass in case of other LinalgOps that
97  // accumulate on their out arg but are not (binary) contraction ops.
98  auto dominatedDestOp =
99  dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
100  if (dominatedOp->getNumResults() != 1 ||
101  !linalg::isaContractionOpInterface(dominatedOp) ||
102  (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
103  return rewriter.notifyMatchFailure(
104  dominatedOp, "expected dominated op to be single-result "
105  "destination-passing contraction");
106 
107  // To change the contraction's result, `addOp` must be its only user.
108  if (!dominatedOp->getResult(0).hasOneUse())
109  return rewriter.notifyMatchFailure(
110  dominatedOp,
111  "expected linalg.add to be single user of contraction's result");
112 
113  // As `dominatedOp` was already accumulating on its out argument, it is only
114  // safe to no longer use its current out arg when it is the additive ident.
115  auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
116  if (!isDefinedAsZero(destOperand->get()))
117  return rewriter.notifyMatchFailure(
118  dominatedOp, "expected dominated op's dest to be additive zero");
119  // TODO: If the other op is a contraction and has additive ident as dest, we
120  // can swap the dests and achieve the proper sum, given suitable dominance.
121 
122  // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
123  // Hence, we can only substitute `dominatingOperand` for the dest of the
124  // contraction when dest's indexing_map corresponds to an identity map
125  // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
126  SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
127  int prevDimPos = -1;
128  for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
129  auto dim = dyn_cast<AffineDimExpr>(expr);
130  if (!dim || prevDimPos > static_cast<int>(dim.getPosition()))
131  return rewriter.notifyMatchFailure(
132  dominatedOp, "expected index_map for contraction's dest to be an "
133  "ordered projection");
134  prevDimPos = dim.getPosition();
135  }
136 
137  // Replace the additive-ident, i.e. zero, out arg of the dominated op by the
138  // dominating summand. This makes the dominated op's result the sum of both
139  // of addOp's arguments - therefore we replace addOp and it uses by it.
140  rewriter.modifyOpInPlace(
141  dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
142  rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
143  return success();
144  }
145 };
146 
148  // Replace linalg.add when destination passing suffices for achieving the sum.
149  patterns.add<FoldAddIntoDest>(patterns.getContext());
150 }
static bool isDefinedAsZero(Value val)
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:153
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition: Matchers.h:399
Replace a linalg.add with one operand the single user of a contraction, which has a zero-filled,...
LogicalResult matchAndRewrite(linalg::AddOp addOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358