MLIR  22.0.0git
Fusion.cpp
Go to the documentation of this file.
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/Dominance.h"
22 #include "mlir/Support/LLVM.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 #include "llvm/Support/Debug.h"
25 
26 #define DEBUG_TYPE "linalg-fusion"
27 
28 using namespace mlir;
29 using namespace mlir::linalg;
30 
31 // Implements a simple high-level fusion pass on linalg structured operations.
32 //
33 // In each block, linalg ops are processed in reverse textual order.
34 // Given a linalg op `O`, fusion occurs by:
35 // 1. inspecting the linalg ops that write into the views read by `O`. There
36 // are 2 cases:
37 // a) buffer case: use the SSA value of the views and a simple alias
38 // analysis on subview ops to determine producer-consumer dependences;
39 // b) tensor case: use SSA use-def chains on extract_slice ops;
40 // 2. greedily fuse the linalg ops that produce the subview/extract_slice.
41 // 3. inspect the fused ops and determine whether they have other remaining
42 // LinalgOp uses. If not, then erase the original producing linalg op.
43 //
44 // More advanced use cases, analyses, as well as profitability heuristics are
45 // left for future work.
46 
49  unsigned dimension;
50 };
51 
52 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
53 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
54 // guarantees at least one such dimension is found. If multiple candidates exist
55 // they must agree by construction (i.e. have the same size) and we just return
56 // the first one.
57 static ShapeDimension
58 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
59  bool fromSubViewOpOnly = false) {
60  // Iterate over the inputs and outputs in order.
61  // Extract the subranges from the linearized ranges.
62  for (OpOperand &opOperand : op->getOpOperands()) {
63  // The method `getRangeFromOperandShape` requires using SubViewOp or
64  // ExtractSliceOps. If the value isn't defined from there continue.
65  // todo: The method should be adapted to get the values from
66  // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
67  // currently returns a `linalg.range`. The fix here is to move this op to
68  // `std` dialect and add the method to `ViewInterface`.
69  if (fromSubViewOpOnly &&
70  !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
71  opOperand.get().getDefiningOp()))
72  continue;
73 
74  AffineMap map = op.getMatchingIndexingMap(&opOperand);
75  LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
76  << opOperand.getOperandNumber() << "\n");
77  LLVM_DEBUG(llvm::dbgs()
78  << "getShapeDefiningLoopRange map: " << map << "\n");
79  for (const auto &en : llvm::enumerate(map.getResults())) {
80  auto dimExpr = dyn_cast<AffineDimExpr>(en.value());
81  if (!dimExpr)
82  continue;
83  if (loopDepth == cast<AffineDimExpr>(en.value()).getPosition()) {
84  LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
85  << loopDepth << "\n");
86  LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
87  << opOperand.get() << "\n");
88  return ShapeDimension{opOperand.get(),
89  static_cast<unsigned>(en.index())};
90  }
91  }
92  }
93  llvm_unreachable("Expect to be able to extract a shape defining loop range");
94 }
95 
96 static SmallVector<Value> getTiledOperands(LinalgOp producer) {
97  return producer->getOperands();
98 }
99 
100 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
101 /// provides the loop range information for the fused loops. The rest are
102 /// obtained from the producer itself, since they are not tiled + fused.
103 static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
104  const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
105  SmallVector<OpFoldResult> ivs, tileSizes, sizeBounds;
106  SmallVector<Range> loopRanges;
107  Location loc = producer.getLoc();
108 
109  for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
110  auto shapeDim = getShapeDefiningLoopRange(producer, i);
111  OpFoldResult dim =
112  createFoldedDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
113  sizeBounds.push_back(dim);
114  auto it = fusedLoopsAndRanges.find(i);
115  if (it != fusedLoopsAndRanges.end()) {
116  ivs.push_back(it->second.offset);
117  tileSizes.push_back(it->second.size);
118  loopRanges.push_back(it->second);
119  LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
120  << loopRanges.back() << "\n");
121  } else {
122  tileSizes.push_back(b.getIndexAttr(0));
123  loopRanges.push_back(Range{b.getIndexAttr(0), dim, b.getIndexAttr(1)});
124  LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
125  << loopRanges.back() << "\n");
126  }
127  }
128 
129  SmallVector<Value, 8> clonedShapes;
130  clonedShapes.reserve(producer->getNumOperands());
131 
132  // Compute subranges for all tensor input/output operands.
133  clonedShapes.append(makeTiledShapes(
134  b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds,
135  /**omitPartialTileCheck=*/false));
136 
137  // Take result types from the tiled init operands.
138  MutableOperandRange producerDpsInits = producer.getDpsInitsMutable();
139  SmallVector<Type, 4> resultTypes;
140  resultTypes.reserve(producer->getNumResults());
141  int64_t firstInitOperandIdx =
142  producerDpsInits.getAsOperandRange().getBeginOperandIndex();
143  for (int64_t i = 0, e = producer->getNumResults(); i < e; ++i) {
144  resultTypes.push_back(clonedShapes[firstInitOperandIdx + i].getType());
145  }
146 
147  // Clone the producer with new operands and result types.
148  LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);
149 
150  // Shift all IndexOp results by the tile offset.
151  SmallVector<OpFoldResult> allIvs = llvm::to_vector(
152  llvm::map_range(loopRanges, [&](Range range) { return range.offset; }));
153  offsetIndices(b, clonedOp, allIvs);
154 
155  return clonedOp;
156 }
157 
158 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
159 /// expected to be defined by a subview op or an extract_slice op.
161  Value shapedOperand, unsigned dim) {
162  Operation *shapeProducingOp = shapedOperand.getDefiningOp();
163  if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
164  return subViewOp.getOrCreateRanges(b, loc)[dim];
165  if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp))
166  return sliceOp.getOrCreateRanges(b, loc)[dim];
167  llvm_unreachable("SubviewOp or ExtractSliceOp expected");
168 }
169 
170 /// Fuses the producer into the loop immediately enclosing the consumer.
171 /// This is achieved by "recomputing" the producer at the time it
172 /// is needed just before the consumer.
173 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
174  OpOperand &consumerOpOperand) {
175  LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
176  DenseMap<unsigned, Range> fusedLoopsAndRanges;
177  Value shapedOperand = consumerOpOperand.get();
178  for (const auto &en : llvm::enumerate(producerMap.getResults())) {
179  unsigned posInProducerLoop = cast<AffineDimExpr>(en.value()).getPosition();
180  fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
181  b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
182  }
183  return fuse(b, producerOp, fusedLoopsAndRanges);
184 }
185 
186 /// Walk back use-def chain through scf::For yields.
187 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
188 
189 // TODO(ravishankarm, ntv): This can be moved into the dependence graphs
190 // dependence tracking since the dependence tracking is similar to what is done
191 // w.r.t to buffers.
192 static void getProducerOfTensor(Value tensor, OpResult &opResult) {
193  if (!isa<RankedTensorType>(tensor.getType()))
194  return;
195 
196  while (true) {
197  LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
198  if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
199  opResult = cast<OpResult>(tensor);
200  return;
201  }
202  if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
203  tensor = sliceOp.getSource();
204  continue;
205  }
206  if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
207  if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
208  tensor = forOp.getInitArgs()[blockArg.getArgNumber()];
209  continue;
210  }
211  }
212  return;
213  }
214 }
215 
216 FailureOr<FusionInfo>
218  Value inputTensor = consumerOpOperand.get();
219  OpResult producerOpResult;
220  getProducerOfTensor(inputTensor, producerOpResult);
221  if (!producerOpResult) {
222  LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
223  return failure();
224  }
225  return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
226 }
227 
228 FailureOr<FusionInfo>
230  OpOperand &consumerOpOperand) {
231  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
232  if (!producerOp)
233  return failure();
234 
235  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
236  if (!consumerOp)
237  return failure();
238 
239  Value inputTensor = consumerOpOperand.get();
240 
241  // Must be an extract_slice op to guarantee there are loops we can fuse into.
242  auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
243  if (!sliceOp) {
244  LLVM_DEBUG(llvm::dbgs()
245  << "\nNot fusable, not an extract_slice op: " << inputTensor);
246  return failure();
247  }
248 
249  // If producer is already in the same block as consumer, we are done.
250  if (consumerOpOperand.get().getParentBlock() ==
251  producerOpResult.getParentBlock())
252  return failure();
253 
254  // Insert fused `producer` just before `consumer`.
256  b.setInsertionPoint(consumerOp);
257  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
258  OpOperand *opOperand =
259  producerOp.getDpsInitOperand(producerOpResult.getResultNumber());
260  LinalgOp fusedProducer =
261  fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand),
262  consumerOpOperand);
263 
264  // Replace use.
265  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
266  Type consumerType = consumerOpOperand.get().getType();
267  // Check if rank-reduction occurred as part of the extract_slice. If yes,
268  // collapse the dropped dimensions.
269  if (cast<ShapedType>(consumerType).getRank() !=
270  cast<ShapedType>(def.getType()).getRank()) {
271  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
272  def =
273  tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
274  }
275  // Canonicalizations are not guaranteed to have happened before constructing
276  // `fusedProducer`. In the tensor case this can result in temporary type
277  // mismatches. Insert a `tensor.cast` op to propagate the transformation
278  // invariant that types are compatible.
279  if (consumerType != def.getType())
280  def = tensor::CastOp::create(b, fusedProducer.getLoc(), consumerType, def);
281  consumerOpOperand.set(def);
282  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
283 }
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, const DenseMap< unsigned, Range > &fusedLoopsAndRanges)
Fuses the producer by cloning the producer.
Definition: Fusion.cpp:103
static void getProducerOfTensor(Value tensor, OpResult &opResult)
Walk back use-def chain through scf::For yields.
Definition: Fusion.cpp:192
static SmallVector< Value > getTiledOperands(LinalgOp producer)
Definition: Fusion.cpp:96
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, bool fromSubViewOpOnly=false)
Definition: Fusion.cpp:58
static Range getRangeFromOperandShape(OpBuilder &b, Location loc, Value shapedOperand, unsigned dim)
Get the loop range for a dimension dim based on the shapedOperand.
Definition: Fusion.cpp:160
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
OperandRange getAsOperandRange() const
Explicit conversion to an OperandRange.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:456
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
unsigned getBeginOperandIndex() const
Return the operand index of the first element of this range.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< FusionInfo > fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand)
This implements the fusion part of the "tileAndFuse on tensors" transformation and thus requires the ...
Definition: Fusion.cpp:217
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition: Utils.cpp:862
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:103
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Definition: Utils.cpp:884
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src, const llvm::SmallBitVector &dropDims)
Create tensor.collapse_shape to drop unit dimensions in dropDims in tensor src.
Definition: Utils.cpp:95
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Implements a simple high-level fusion pass on linalg structured operations.
Definition: Fusion.cpp:47
unsigned dimension
Definition: Fusion.cpp:49
Value shape
Definition: Fusion.cpp:48
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult offset
A struct containing the Linalg producer before and after fusion.
Definition: Utils.h:233