MLIR 22.0.0git
Fusion.cpp
Go to the documentation of this file.
1//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Fusion pass.
10//
11//===----------------------------------------------------------------------===//
12
19#include "mlir/IR/AffineExpr.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/Dominance.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/SmallBitVector.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "linalg-fusion"
27
28using namespace mlir;
29using namespace mlir::linalg;
30
31/// Implements a simple high-level fusion pass on linalg structured operations.
32///
33/// In each block, linalg ops are processed in reverse textual order.
34/// Given a linalg op `O`, fusion occurs by:
35/// 1. inspecting the linalg ops that write into the views read by `O`. There
36/// are 2 cases:
37/// a) buffer case: use the SSA value of the views and a simple alias
38/// analysis on subview ops to determine producer-consumer dependences;
39/// b) tensor case: use SSA use-def chains on extract_slice ops;
40/// 2. greedily fuse the linalg ops that produce the subview/extract_slice.
41/// 3. inspect the fused ops and determine whether they have other remaining
42/// LinalgOp uses. If not, then erase the original producing linalg op.
43///
44/// More advanced use cases, analyses as well as profitability heuristics are
45/// left for future work.
46
49 unsigned dimension;
50};
51
52// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
53// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
54// guarantees at least one such dimension is found. If multiple candidates exist
55// they must agree by construction (i.e. have the same size) and we just return
56// the first one.
57static ShapeDimension
58getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
59 bool fromSubViewOpOnly = false) {
60 // Iterate over the inputs and outputs in order.
61 // Extract the subranges from the linearized ranges.
62 for (OpOperand &opOperand : op->getOpOperands()) {
63 // The method `getRangeFromOperandShape` requires using SubViewOp or
64 // ExtractSliceOps. If the value isn't defined from there continue.
65 // todo: The method should be adapted to get the values from
66 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
67 // currently returns a `linalg.range`. The fix here is to move this op to
68 // `std` dialect and add the method to `ViewInterface`.
69 if (fromSubViewOpOnly &&
70 !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
71 opOperand.get().getDefiningOp()))
72 continue;
73
74 AffineMap map = op.getMatchingIndexingMap(&opOperand);
75 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
76 << opOperand.getOperandNumber() << "\n");
77 LLVM_DEBUG(llvm::dbgs()
78 << "getShapeDefiningLoopRange map: " << map << "\n");
79 for (const auto &en : llvm::enumerate(map.getResults())) {
80 auto dimExpr = dyn_cast<AffineDimExpr>(en.value());
81 if (!dimExpr)
82 continue;
83 if (loopDepth == cast<AffineDimExpr>(en.value()).getPosition()) {
84 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
85 << loopDepth << "\n");
86 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
87 << opOperand.get() << "\n");
88 return ShapeDimension{opOperand.get(),
89 static_cast<unsigned>(en.index())};
90 }
91 }
92 }
93 llvm_unreachable("Expect to be able to extract a shape defining loop range");
94}
95
96static SmallVector<Value> getTiledOperands(LinalgOp producer) {
97 return producer->getOperands();
98}
99
100/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
101/// provides the loop range information for the fused loops. The rest are
102/// obtained from the producer itself, since they are not tiled + fused.
103static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
104 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
105 SmallVector<OpFoldResult> ivs, tileSizes, sizeBounds;
106 SmallVector<Range> loopRanges;
107 Location loc = producer.getLoc();
108
109 for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
110 auto shapeDim = getShapeDefiningLoopRange(producer, i);
111 OpFoldResult dim =
112 createFoldedDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
113 sizeBounds.push_back(dim);
114 auto it = fusedLoopsAndRanges.find(i);
115 if (it != fusedLoopsAndRanges.end()) {
116 ivs.push_back(it->second.offset);
117 tileSizes.push_back(it->second.size);
118 loopRanges.push_back(it->second);
119 LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
120 << loopRanges.back() << "\n");
121 } else {
122 tileSizes.push_back(b.getIndexAttr(0));
123 loopRanges.push_back(Range{b.getIndexAttr(0), dim, b.getIndexAttr(1)});
124 LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
125 << loopRanges.back() << "\n");
126 }
127 }
128
129 SmallVector<Value, 8> clonedShapes;
130 clonedShapes.reserve(producer->getNumOperands());
131
132 // Compute subranges for all tensor input/output operands.
133 clonedShapes.append(makeTiledShapes(
134 b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
135 /**omitPartialTileCheck=*/false));
136
137 // Take result types from the tiled init operands.
138 MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
139 SmallVector<Type, 4> resultTypes;
140 resultTypes.reserve(producer->getNumResults());
141 int64_t firstInitOperandIdx =
142 producerDpsInits.getAsOperandRange().getBeginOperandIndex();
143 for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
144 resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());
145 }
146
147 // Clone the producer with new operands and result types.
148 LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
149
150 // Shift all IndexOp results by the tile offset.
151 SmallVector<OpFoldResult> allIvs = llvm::to_vector(
152 llvm::map_range(loopRanges, [&](Range range) { return range.offset; }));
153 offsetIndices(b, clonedOp, allIvs);
154
155 return clonedOp;
156}
157
158/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
159/// expected to be defined by a subview op or an extract_slice op.
161 Value shapedOperand, unsigned dim) {
162 Operation *shapeProducingOp = shapedOperand.getDefiningOp();
163 if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
164 return subViewOp.getOrCreateRanges(b, loc)[dim];
165 if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp))
166 return sliceOp.getOrCreateRanges(b, loc)[dim];
167 llvm_unreachable("SubviewOp or ExtractSliceOp expected");
168}
169
170/// Fuses the producer into the loop immediately enclosing the consumer.
171/// This is achieved by "recomputing" the producer at the time it
172/// is needed just before the consumer.
173static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
174 OpOperand &consumerOpOperand) {
175 LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
176 DenseMap<unsigned, Range> fusedLoopsAndRanges;
177 Value shapedOperand = consumerOpOperand.get();
178 for (const auto &en : llvm::enumerate(producerMap.getResults())) {
179 unsigned posInProducerLoop = cast<AffineDimExpr>(en.value()).getPosition();
180 fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
181 b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
182 }
183 return fuse(b, producerOp, fusedLoopsAndRanges);
184}
185
186/// Walk back use-def chain through scf::For yields.
187/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
188
189// TODO(ravishankarm, ntv): This can be moved into the dependence graphs
190// dependence tracking since the dependence tracking is similar to what is done
191// w.r.t to buffers.
192static void getProducerOfTensor(Value tensor, OpResult &opResult) {
193 if (!isa<RankedTensorType>(tensor.getType()))
194 return;
195
196 while (true) {
197 LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
198 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
199 opResult = cast<OpResult>(tensor);
200 return;
201 }
202 if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
203 tensor = sliceOp.getSource();
204 continue;
205 }
206 if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
207 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
208 tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
209 continue;
210 }
211 }
212 return;
213 }
214}
215
216FailureOr<FusionInfo>
217mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
218 Value inputTensor = consumerOpOperand.get();
219 OpResult producerOpResult;
220 getProducerOfTensor(inputTensor, producerOpResult);
221 if (!producerOpResult) {
222 LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
223 return failure();
224 }
225 return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
226}
227
228FailureOr<FusionInfo>
229mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
230 OpOperand &consumerOpOperand) {
231 auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
232 if (!producerOp)
233 return failure();
234
235 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
236 if (!consumerOp)
237 return failure();
238
239 Value inputTensor = consumerOpOperand.get();
240
241 // Must be an extract_slice op to guarantee there are loops we can fuse into.
242 auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
243 if (!sliceOp) {
244 LLVM_DEBUG(llvm::dbgs()
245 << "\nNot fusable, not an extract_slice op: " << inputTensor);
246 return failure();
247 }
248
249 // If producer is already in the same block as consumer, we are done.
250 if (consumerOpOperand.get().getParentBlock() ==
251 producerOpResult.getParentBlock())
252 return failure();
253
254 // Insert fused `producer` just before `consumer`.
256 b.setInsertionPoint(consumerOp);
257 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
258 OpOperand *opOperand =
259 producerOp.getDpsInitOperand(producerOpResult.getResultNumber());
260 LinalgOp fusedProducer =
261 fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand),
262 consumerOpOperand);
263
264 // Replace use.
265 Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
266 Type consumerType = consumerOpOperand.get().getType();
267 // Check if rank-reduction occurred as part of the extract_slice. If yes,
268 // collapse the dropped dimensions.
269 if (cast<ShapedType>(consumerType).getRank() !=
270 cast<ShapedType>(def.getType()).getRank()) {
271 llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
272 def =
273 tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
274 }
275 // Canonicalizations are not guaranteed to have happened before constructing
276 // `fusedProducer`. In the tensor case this can result in temporary type
277 // mismatches. Insert a `tensor.cast` op to propagate the transformation
278 // invariant that types are compatible.
279 if (consumerType != def.getType())
280 def = tensor::CastOp::create(b, fusedProducer.getLoc(), consumerType, def);
281 consumerOpOperand.set(def);
282 return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
283}
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, const DenseMap< unsigned, Range > &fusedLoopsAndRanges)
Fuses the producer by cloning the producer.
Definition Fusion.cpp:103
static void getProducerOfTensor(Value tensor, OpResult &opResult)
Walk back use-def chain through scf::For yields.
Definition Fusion.cpp:192
static SmallVector< Value > getTiledOperands(LinalgOp producer)
Definition Fusion.cpp:96
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, bool fromSubViewOpOnly=false)
Definition Fusion.cpp:58
static Range getRangeFromOperandShape(OpBuilder &b, Location loc, Value shapedOperand, unsigned dim)
Get the loop range for a dimension dim based on the shapedOperand.
Definition Fusion.cpp:160
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Definition Utils.cpp:860
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offsets)
Definition Utils.cpp:882
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src, const llvm::SmallBitVector &dropDims)
Create tensor.collapse_shape to drop unit dimensions in dropDims in tensor src.
Definition Utils.cpp:95
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
Implements a simple high-level fusion pass on linalg structured operations.
Definition Fusion.cpp:47
unsigned dimension
Definition Fusion.cpp:49
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult offset