MLIR 23.0.0git
Fusion.cpp
Go to the documentation of this file.
1//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Fusion pass.
10//
11//===----------------------------------------------------------------------===//
12
19#include "mlir/IR/AffineExpr.h"
20#include "mlir/IR/AffineMap.h"
21#include "mlir/IR/Dominance.h"
22#include "mlir/Support/LLVM.h"
23#include "llvm/ADT/SmallBitVector.h"
24#include "llvm/ADT/SmallVectorExtras.h"
25#include "llvm/Support/Debug.h"
26
27#define DEBUG_TYPE "linalg-fusion"
28
29using namespace mlir;
30using namespace mlir::linalg;
31
32/// Implements a simple high-level fusion pass on linalg structured operations.
33///
34/// In each block, linalg ops are processed in reverse textual order.
35/// Given a linalg op `O`, fusion occurs by:
36/// 1. inspecting the linalg ops that write into the views read by `O`. There
37/// are 2 cases:
38/// a) buffer case: use the SSA value of the views and a simple alias
39/// analysis on subview ops to determine producer-consumer dependences;
40/// b) tensor case: use SSA use-def chains on extract_slice ops;
41/// 2. greedily fuse the linalg ops that produce the subview/extract_slice.
42/// 3. inspect the fused ops and determine whether they have other remaining
43/// LinalgOp uses. If not, then erase the original producing linalg op.
44///
45/// More advanced use cases, analyses as well as profitability heuristics are
46/// left for future work.
47
50 unsigned dimension;
51};
52
53// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
54// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
55// guarantees at least one such dimension is found. If multiple candidates exist
56// they must agree by construction (i.e. have the same size) and we just return
57// the first one.
58static ShapeDimension
59getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
60 bool fromSubViewOpOnly = false) {
61 // Iterate over the inputs and outputs in order.
62 // Extract the subranges from the linearized ranges.
63 for (OpOperand &opOperand : op->getOpOperands()) {
64 // The method `getRangeFromOperandShape` requires using SubViewOp or
65 // ExtractSliceOps. If the value isn't defined from there continue.
66 // todo: The method should be adapted to get the values from
67 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
68 // currently returns a `linalg.range`. The fix here is to move this op to
69 // `std` dialect and add the method to `ViewInterface`.
70 if (fromSubViewOpOnly &&
71 !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
72 opOperand.get().getDefiningOp()))
73 continue;
74
75 AffineMap map = op.getMatchingIndexingMap(&opOperand);
76 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
77 << opOperand.getOperandNumber() << "\n");
78 LLVM_DEBUG(llvm::dbgs()
79 << "getShapeDefiningLoopRange map: " << map << "\n");
80 for (const auto &en : llvm::enumerate(map.getResults())) {
81 auto dimExpr = dyn_cast<AffineDimExpr>(en.value());
82 if (!dimExpr)
83 continue;
84 if (loopDepth == cast<AffineDimExpr>(en.value()).getPosition()) {
85 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
86 << loopDepth << "\n");
87 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
88 << opOperand.get() << "\n");
89 return ShapeDimension{opOperand.get(),
90 static_cast<unsigned>(en.index())};
91 }
92 }
93 }
94 llvm_unreachable("Expect to be able to extract a shape defining loop range");
95}
96
97static SmallVector<Value> getTiledOperands(LinalgOp producer) {
98 return producer->getOperands();
99}
100
101/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
102/// provides the loop range information for the fused loops. The rest are
103/// obtained from the producer itself, since they are not tiled + fused.
104static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
105 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
106 SmallVector<OpFoldResult> ivs, tileSizes, sizeBounds;
107 SmallVector<Range> loopRanges;
108 Location loc = producer.getLoc();
109
110 for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
111 auto shapeDim = getShapeDefiningLoopRange(producer, i);
112 OpFoldResult dim =
113 createFoldedDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
114 sizeBounds.push_back(dim);
115 auto it = fusedLoopsAndRanges.find(i);
116 if (it != fusedLoopsAndRanges.end()) {
117 ivs.push_back(it->second.offset);
118 tileSizes.push_back(it->second.size);
119 loopRanges.push_back(it->second);
120 LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
121 << loopRanges.back() << "\n");
122 } else {
123 tileSizes.push_back(b.getIndexAttr(0));
124 loopRanges.push_back(Range{b.getIndexAttr(0), dim, b.getIndexAttr(1)});
125 LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
126 << loopRanges.back() << "\n");
127 }
128 }
129
130 SmallVector<Value, 8> clonedShapes;
131 clonedShapes.reserve(producer->getNumOperands());
132
133 // Compute subranges for all tensor input/output operands.
134 clonedShapes.append(makeTiledShapes(
135 b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
136 /**omitPartialTileCheck=*/false));
137
138 // Take result types from the tiled init operands.
139 MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
140 SmallVector<Type, 4> resultTypes;
141 resultTypes.reserve(producer->getNumResults());
142 int64_t firstInitOperandIdx =
143 producerDpsInits.getAsOperandRange().getBeginOperandIndex();
144 for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
145 resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());
146 }
147
148 // Clone the producer with new operands and result types.
149 LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
150
151 // Shift all IndexOp results by the tile offset.
152 SmallVector<OpFoldResult> allIvs = llvm::map_to_vector(
153 loopRanges, [&](Range range) { return range.offset; });
154 offsetIndices(b, clonedOp, allIvs);
155
156 return clonedOp;
157}
158
159/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
160/// expected to be defined by a subview op or an extract_slice op.
162 Value shapedOperand, unsigned dim) {
163 Operation *shapeProducingOp = shapedOperand.getDefiningOp();
164 if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
165 return subViewOp.getOrCreateRanges(b, loc)[dim];
166 if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp))
167 return sliceOp.getOrCreateRanges(b, loc)[dim];
168 llvm_unreachable("SubviewOp or ExtractSliceOp expected");
169}
170
171/// Fuses the producer into the loop immediately enclosing the consumer.
172/// This is achieved by "recomputing" the producer at the time it
173/// is needed just before the consumer.
174static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
175 OpOperand &consumerOpOperand) {
176 LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
177 DenseMap<unsigned, Range> fusedLoopsAndRanges;
178 Value shapedOperand = consumerOpOperand.get();
179 for (const auto &en : llvm::enumerate(producerMap.getResults())) {
180 unsigned posInProducerLoop = cast<AffineDimExpr>(en.value()).getPosition();
181 fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
182 b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
183 }
184 return fuse(b, producerOp, fusedLoopsAndRanges);
185}
186
187/// Walk back use-def chain through scf::For yields.
188/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
189
190// TODO(ravishankarm, ntv): This can be moved into the dependence graphs
191// dependence tracking since the dependence tracking is similar to what is done
192// w.r.t to buffers.
193static void getProducerOfTensor(Value tensor, OpResult &opResult) {
194 if (!isa<RankedTensorType>(tensor.getType()))
195 return;
196
197 while (true) {
198 LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
199 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
200 opResult = cast<OpResult>(tensor);
201 return;
202 }
203 if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
204 tensor = sliceOp.getSource();
205 continue;
206 }
207 if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
208 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
209 tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
210 continue;
211 }
212 }
213 return;
214 }
215}
216
217FailureOr<FusionInfo>
219 Value inputTensor = consumerOpOperand.get();
220 OpResult producerOpResult;
221 getProducerOfTensor(inputTensor, producerOpResult);
222 if (!producerOpResult) {
223 LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
224 return failure();
225 }
226 return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
227}
228
229FailureOr<FusionInfo>
231 OpOperand &consumerOpOperand) {
232 auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
233 if (!producerOp)
234 return failure();
235
236 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
237 if (!consumerOp)
238 return failure();
239
240 Value inputTensor = consumerOpOperand.get();
241
242 // Must be an extract_slice op to guarantee there are loops we can fuse into.
243 auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
244 if (!sliceOp) {
245 LLVM_DEBUG(llvm::dbgs()
246 << "\nNot fusable, not an extract_slice op: " << inputTensor);
247 return failure();
248 }
249
250 // If producer is already in the same block as consumer, we are done.
251 if (consumerOpOperand.get().getParentBlock() ==
252 producerOpResult.getParentBlock())
253 return failure();
254
255 // Insert fused `producer` just before `consumer`.
257 b.setInsertionPoint(consumerOp);
258 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
259 OpOperand *opOperand =
260 producerOp.getDpsInitOperand(producerOpResult.getResultNumber());
261 LinalgOp fusedProducer =
262 fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand),
263 consumerOpOperand);
264
265 // Replace use.
266 Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
267 Type consumerType = consumerOpOperand.get().getType();
268 // Check if rank-reduction occurred as part of the extract_slice. If yes,
269 // collapse the dropped dimensions.
270 if (cast<ShapedType>(consumerType).getRank() !=
271 cast<ShapedType>(def.getType()).getRank()) {
272 llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
273 def =
274 tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
275 }
276 // Canonicalizations are not guaranteed to have happened before constructing
277 // `fusedProducer`. In the tensor case this can result in temporary type
278 // mismatches. Insert a `tensor.cast` op to propagate the transformation
279 // invariant that types are compatible.
280 if (consumerType != def.getType())
281 def = tensor::CastOp::create(b, fusedProducer.getLoc(), consumerType, def);
282 consumerOpOperand.set(def);
283 return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
284}
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, const DenseMap< unsigned, Range > &fusedLoopsAndRanges)
Fuses the producer by cloning the producer.
Definition Fusion.cpp:104
static void getProducerOfTensor(Value tensor, OpResult &opResult)
Walk back use-def chain through scf::For yields.
Definition Fusion.cpp:193
static SmallVector< Value > getTiledOperands(LinalgOp producer)
Definition Fusion.cpp:97
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, bool fromSubViewOpOnly=false)
Definition Fusion.cpp:59
static Range getRangeFromOperandShape(OpBuilder &b, Location loc, Value shapedOperand, unsigned dim)
Get the loop range for a dimension dim based on the shapedOperand.
Definition Fusion.cpp:161
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition Value.cpp:46
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
FailureOr< FusionInfo > fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand)
This implements the fusion part of the "tileAndFuse on tensors" transformation and thus requires the ...
Definition Fusion.cpp:218
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition Utils.cpp:2850
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Definition Utils.cpp:2872
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src, const llvm::SmallBitVector &dropDims)
Create tensor.collapse_shape to drop unit dimensions in dropDims in tensor src.
Definition Utils.cpp:95
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
Implements a simple high-level fusion pass on linalg structured operations.
Definition Fusion.cpp:48
unsigned dimension
Definition Fusion.cpp:50
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult offset
A struct containing the Linalg producer before and after fusion.
Definition Utils.h:250