MLIR 23.0.0git
FoldAddIntoDest.cpp
Go to the documentation of this file.
1//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Dominance.h"
14
15using namespace mlir;
16
17// Determine whether the value is defined to be zero.
18static bool isDefinedAsZero(Value val) {
19 if (!val)
20 return false;
21
22 // Check whether val is a constant scalar / vector splat / tensor splat float
23 // or integer zero.
24 if (isZeroIntegerOrFloat(val))
25 return true;
26
27 auto *defOp = val.getDefiningOp();
28 if (!defOp)
29 return false;
30
32 .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
33 return op.getInputs().size() == 1 && isDefinedAsZero(op.getInputs()[0]);
34 })
35 .Default([&](auto) { return false; });
36}
37
38/// Replace a linalg.add with one operand the single user of a contraction,
39/// which has a zero-filled, "identity-mapped" destination and is dominated by
40/// the `other` operand, by the contraction with `other` as its dest.
41///
42/// As an example, the following pseudo-code will be rewritten
43/// %cst = arith.constant 0.000000e+00
44/// %empty = tensor.empty()
45/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
46/// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
47/// %empty2 = tensor.empty()
48/// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
49/// %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
50/// %out = linalg.add ins(%C, %F) outs(%empty)
51/// to:
52/// %cst = arith.constant 0.000000e+00
53/// %empty = tensor.empty()
54/// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
55/// %C = linalg.matmul ins(%A, %B) outs(%zeroed)
56/// %out = linalg.matmul ins(%D, %E) outs(%C)
57///
58struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
59 using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
60
61 LogicalResult matchAndRewrite(linalg::AddOp addOp,
62 PatternRewriter &rewriter) const override {
63 // For now, pattern only applies on tensor types (memref support is TODO).
64 if (!addOp.hasPureTensorSemantics())
65 return failure();
66
67 Value dominatingOperand = nullptr;
68 linalg::LinalgOp dominatedOp = nullptr;
69 { // We will forget about which operand was left or right after this block.
70 Value lhs = addOp.getInputs()[0];
71 Value rhs = addOp.getInputs()[1];
72
73 // Can only put one of addOp's operands in the dest/out arg of the other's
74 // defining op based on suitable dominance.
75 // TODO: Can be generalized to move ops around as long as that still
76 // respects use-def chains and doesn't affect side-effects.
77 if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) {
78 DominanceInfo domInfo(rhsOp);
79 if (domInfo.properlyDominates(lhs, rhsOp)) {
80 dominatingOperand = lhs;
81 dominatedOp = rhsOp;
82 }
83 }
84 if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) {
85 DominanceInfo domInfo(lhsOp);
86 if (domInfo.properlyDominates(rhs, lhsOp)) {
87 dominatingOperand = rhs;
88 dominatedOp = lhsOp;
89 }
90 }
91 if (!dominatingOperand || !dominatedOp)
92 return failure();
93 // NB: As linalg.add's generalisation ignores the out argument in its
94 // region there is no need to perform checks on addOp's out argument.
95 }
96
97 // When dominated op is a contraction we know it accumulates on its out arg.
98 // E.g., AddOp is not a contraction and hence ignores its out arg's value.
99 // TODO: Generalize check to also pass in case of other LinalgOps that
100 // accumulate on their out arg but are not (binary) contraction ops.
101 auto dominatedDestOp =
102 dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
103 if (dominatedOp->getNumResults() != 1 ||
104 !linalg::isaContractionOpInterface(dominatedOp) ||
105 (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
106 return rewriter.notifyMatchFailure(
107 dominatedOp, "expected dominated op to be single-result "
108 "destination-passing contraction");
109
110 // To change the contraction's result, `addOp` must be its only user.
111 if (!dominatedOp->getResult(0).hasOneUse())
112 return rewriter.notifyMatchFailure(
113 dominatedOp,
114 "expected linalg.add to be single user of contraction's result");
115
116 // As `dominatedOp` was already accumulating on its out argument, it is only
117 // safe to no longer use its current out arg when it is the additive ident.
118 auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
119 if (!isDefinedAsZero(destOperand->get()))
120 return rewriter.notifyMatchFailure(
121 dominatedOp, "expected dominated op's dest to be additive zero");
122 // TODO: If the other op is a contraction and has additive ident as dest, we
123 // can swap the dests and achieve the proper sum, given suitable dominance.
124
125 // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
126 // Hence, we can only substitute `dominatingOperand` for the dest of the
127 // contraction when dest's indexing_map corresponds to an identity map
128 // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
129 SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
130 int prevDimPos = -1;
131 for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
132 auto dim = dyn_cast<AffineDimExpr>(expr);
133 if (!dim || prevDimPos > static_cast<int>(dim.getPosition()))
134 return rewriter.notifyMatchFailure(
135 dominatedOp, "expected index_map for contraction's dest to be an "
136 "ordered projection");
137 prevDimPos = dim.getPosition();
138 }
139
140 // Replace the additive-ident, i.e. zero, out arg of the dominated op by the
141 // dominating summand. This makes the dominated op's result the sum of both
142 // of addOp's arguments - therefore we replace addOp and it uses by it.
143 rewriter.modifyOpInPlace(
144 dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
145 rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
146 return success();
147 }
148};
149
151 // Replace linalg.add when destination passing suffices for achieving the sum.
152 patterns.add<FoldAddIntoDest>(patterns.getContext());
153}
return success()
static bool isDefinedAsZero(Value val)
lhs
A class for computing basic dominance information.
Definition Dominance.h:143
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
Include the generated interface declarations.
bool isZeroIntegerOrFloat(OpFoldResult v)
Return "true" if v is an integer/float value/attribute with constant value zero.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Replace a linalg.add with one operand the single user of a contraction, which has a zero-filled,...
LogicalResult matchAndRewrite(linalg::AddOp addOp, PatternRewriter &rewriter) const override
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})