MLIR  19.0.0git
DecomposeAffineOps.cpp
Go to the documentation of this file.
1 //===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functionality to progressively decompose coarse-grained
10 // affine ops into finer-grained ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
16 #include "mlir/IR/PatternMatch.h"
18 #include "llvm/Support/Debug.h"
19 
20 using namespace mlir;
21 using namespace mlir::affine;
22 
23 #define DEBUG_TYPE "decompose-affine-ops"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25 #define DBGSNL() (llvm::dbgs() << "\n")
26 
27 /// Count the number of loops surrounding `operand` such that operand could be
28 /// hoisted above.
29 /// Stop counting at the first loop over which the operand cannot be hoisted.
30 static int64_t numEnclosingInvariantLoops(OpOperand &operand) {
31  int64_t count = 0;
32  Operation *currentOp = operand.getOwner();
33  while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) {
34  if (!loopOp.isDefinedOutsideOfLoop(operand.get()))
35  break;
36  currentOp = loopOp;
37  count++;
38  }
39  return count;
40 }
41 
43  AffineApplyOp op) {
44  SmallVector<int64_t> numInvariant = llvm::to_vector(
45  llvm::map_range(op->getOpOperands(), [&](OpOperand &operand) {
46  return numEnclosingInvariantLoops(operand);
47  }));
48 
49  int64_t numOperands = op.getNumOperands();
50  SmallVector<int64_t> operandPositions =
51  llvm::to_vector(llvm::seq<int64_t>(0, numOperands));
52  llvm::stable_sort(operandPositions, [&numInvariant](size_t i1, size_t i2) {
53  return numInvariant[i1] > numInvariant[i2];
54  });
55 
56  SmallVector<AffineExpr> replacements(numOperands);
57  SmallVector<Value> operands(numOperands);
58  for (int64_t i = 0; i < numOperands; ++i) {
59  operands[i] = op.getOperand(operandPositions[i]);
60  replacements[operandPositions[i]] = getAffineSymbolExpr(i, op.getContext());
61  }
62 
63  AffineMap map = op.getAffineMap();
64  ArrayRef<AffineExpr> repls{replacements};
65  map = map.replaceDimsAndSymbols(repls.take_front(map.getNumDims()),
66  repls.drop_front(map.getNumDims()),
67  /*numResultDims=*/0,
68  /*numResultSyms=*/numOperands);
69  map = AffineMap::get(0, numOperands,
70  simplifyAffineExpr(map.getResult(0), 0, numOperands),
71  op->getContext());
72  canonicalizeMapAndOperands(&map, &operands);
73 
74  rewriter.startOpModification(op);
75  op.setMap(map);
76  op->setOperands(operands);
77  rewriter.finalizeOpModification(op);
78 }
79 
80 /// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine
81 /// map and with the same operands.
82 /// Canonicalize the map and operands to deduplicate and drop dead operands
83 /// before returning but do not perform maximal composition of AffineApplyOp
84 /// which would defeat the purpose.
85 static AffineApplyOp createSubApply(RewriterBase &rewriter,
86  AffineApplyOp originalOp, AffineExpr expr) {
87  MLIRContext *ctx = originalOp->getContext();
88  AffineMap m = originalOp.getAffineMap();
89  auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx);
90  SmallVector<Value> rhsOperands = originalOp->getOperands();
91  canonicalizeMapAndOperands(&rhsMap, &rhsOperands);
92  return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap,
93  rhsOperands);
94 }
95 
97  AffineApplyOp op) {
98  // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a
99  // top-level binary expression that we can reassociate (i.e. add or mul).
100  AffineMap m = op.getAffineMap();
101  if (m.getNumDims() > 0)
102  return rewriter.notifyMatchFailure(op, "expected no dims");
103 
104  AffineExpr remainingExp = m.getResult(0);
105  auto binExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
106  if (!binExpr)
107  return rewriter.notifyMatchFailure(op, "terminal affine.apply");
108 
109  if (!isa<AffineBinaryOpExpr>(binExpr.getLHS()) &&
110  !isa<AffineBinaryOpExpr>(binExpr.getRHS()))
111  return rewriter.notifyMatchFailure(op, "terminal affine.apply");
112 
113  bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
114  (binExpr.getKind() == AffineExprKind::Mul));
115  if (!supportedKind)
116  return rewriter.notifyMatchFailure(
117  op, "only add or mul binary expr can be reassociated");
118 
119  LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
120 
121  // 2. Iteratively extract the RHS subexpressions while the top-level binary
122  // expr kind remains the same.
123  MLIRContext *ctx = op->getContext();
124  SmallVector<AffineExpr> subExpressions;
125  while (true) {
126  auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
127  if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
128  subExpressions.push_back(remainingExp);
129  LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
130  break;
131  }
132  subExpressions.push_back(currentBinExpr.getRHS());
133  LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
134  remainingExp = currentBinExpr.getLHS();
135  }
136 
137  // 3. Reorder subExpressions by the min symbol they are a function of.
138  // This also takes care of properly reordering local variables.
139  // This however won't be able to split expression that cannot be reassociated
140  // such as ones that involve divs and multiple symbols.
141  auto getMaxSymbol = [&](AffineExpr e) -> int64_t {
142  for (int64_t i = m.getNumSymbols(); i >= 0; --i)
143  if (e.isFunctionOfSymbol(i))
144  return i;
145  return -1;
146  };
147  llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
148  return getMaxSymbol(e1) < getMaxSymbol(e2);
149  });
150  LLVM_DEBUG(
151  llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
152  llvm::dbgs() << "\n");
153 
154  // 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
155  auto s0 = getAffineSymbolExpr(0, ctx);
156  auto s1 = getAffineSymbolExpr(1, ctx);
157  AffineMap binMap = AffineMap::get(
158  /*dimCount=*/0, /*symbolCount=*/2,
159  getAffineBinaryOpExpr(binExpr.getKind(), s0, s1), ctx);
160 
161  auto current = createSubApply(rewriter, op, subExpressions[0]);
162  for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
163  Value tmp = createSubApply(rewriter, op, subExpressions[i]);
164  current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
165  ValueRange{current, tmp});
166  LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
167  }
168 
169  // 5. Replace original op.
170  rewriter.replaceOp(op, current.getResult());
171  return current;
172 }
static int64_t numEnclosingInvariantLoops(OpOperand &operand)
Count the number of loops surrounding operand such that operand could be hoisted above.
static AffineApplyOp createSubApply(RewriterBase &rewriter, AffineApplyOp originalOp, AffineExpr expr)
Build an affine.apply that is a subexpression expr of originalOp's affine map and with the same operands.
#define DBGS()
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumSymbols() const
Definition: AffineMap.cpp:384
unsigned getNumDims() const
Definition: AffineMap.cpp:380
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:486
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:397
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:708
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:612
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1429
FailureOr< AffineApplyOp > decompose(RewriterBase &rewriter, AffineApplyOp op)
Split an "affine.apply" operation into smaller ops.
void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op)
Helper function to rewrite op's affine map and reorder its operands such that operands invariant in more enclosing loops (the most hoistable) come first and the least hoistable operands come last.
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:62
AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols)
Simplify an affine expression by flattening and some amount of simple analysis.
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:609