MLIR  22.0.0git
TileUsingInterface.h
Go to the documentation of this file.
1 //===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
10 #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
11 
14 #include "mlir/IR/PatternMatch.h"
19 
20 #include <deque>
21 
22 namespace mlir {
23 class Operation;
24 class RewriterBase;
25 class TilingInterface;
26 } // namespace mlir
27 
28 namespace mlir {
29 namespace scf {
30 
32  std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
33 
34 /// Options to use to control tiling.
36  /// Computation function that returns the tile sizes to use for each loop.
37  /// Returning a tile size of zero implies no tiling for that loop. If the
38  /// size of the returned vector is smaller than the number of loops, the inner
39  /// loops are not tiled. If the size of the returned vector is larger, then
40  /// the vector is truncated to number of loops.
42 
45  tileSizeComputationFunction = std::move(fun);
46  return *this;
47  }
48  /// Convenience function to set the `tileSizeComputationFunction` to a
49  /// function that computes tile sizes at the point they are needed. Allows
50  /// proper interaction with folding.
52 
53  /// Computation function that returns the number of threads to use for
54  /// each loop. Returning a num threads of zero implies no tiling for that
55  /// loop. If the size of the returned vector is smaller than the number of
56  /// loops, the inner loops are not tiled. If the size of the returned vector
57  /// is larger, then the vector is truncated to number of loops. Note: This
58  /// option is only supported with loopType set to `LoopType::ForallOp`. If the
59  /// tile size function is not specified while the num threads computation is,
60  /// then the tile size is determined automatically to map at most one tile per
61  /// thread.
63 
66  numThreadsComputationFunction = std::move(fun);
67  return *this;
68  }
69  /// Convenience function to set the `numThreadsComputationFunction` to a
70  /// function that computes num threads at the point they are needed.
72 
73  /// The interchange vector to reorder the tiled loops.
76  interchangeVector = llvm::to_vector(interchange);
77  return *this;
78  }
79 
80  /// Specify which loop construct to use for tile and fuse.
81  enum class LoopType { ForOp, ForallOp };
84  loopType = type;
85  return *this;
86  }
87 
88  /// Specify mapping of loops to devices. This is only respected when the loop
89  /// constructs support such a mapping (like `scf.forall`). Will be ignored
90  /// when using loop constructs that don't support such a mapping (like
91  /// `scf.for`)
94  mappingVector = llvm::to_vector(mapping);
95  return *this;
96  }
97 
98  //-------------------------------------------------------------------------//
99  // Options related to reduction tiling
100  //-------------------------------------------------------------------------//
101 
102  /// Specify how reduction dimensions should be tiled.
107  reductionStrategy = strategy;
108  return *this;
109  }
110 
111  /// Specify the reduction dimensions to be tiled. Note that this needs to be
112  /// specified. If left unspecified, then none of the reduction dimensions are
113  /// tiled.
116  reductionDims.clear();
117  reductionDims.insert(dims.begin(), dims.end());
118  return *this;
119  }
120 };
121 
122 /// Transformation information returned after tiling.
124  /// Tiled operations that are generated during tiling. The order does not
125  /// matter except the last op. The replacements are expected to be the results
126  /// of the last op.
128  /// The initial destination values passed to the tiled operations.
130  /// The `scf.for` operations that iterate over the tiles.
132  /// Values to use as replacements for the untiled op. Is the same size as the
133  /// number of results of the untiled op.
135  /// Slices generated after tiling that can be used for fusing with the tiled
136  /// producer.
138  /// In cases where there is an additional merge step after tiling
139  /// return the merged ops after tiling. This list is empty when reduction
140  /// tiling strategy is
141  /// `scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction`.
143 };
144 
145 /// Method to tile an op that implements the `TilingInterface` using
146 /// `scf.for` for iterating over the tiles.
147 FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
148  TilingInterface op,
149  const SCFTilingOptions &options);
150 
151 /// Options used to control tile + fuse.
153  /// The tiling options used to control the tiling of the consumer.
157  return *this;
158  }
159 
160  /// Control function to check if a slice needs to be fused or not,
161  /// The control function receives
162  /// 1) the slice along which fusion is to be done,
163  /// 2) the producer value that is to be fused
164  /// 3) a boolean value set to `true` if the fusion is from
165  /// a destination operand.
166  /// The control function returns an `std::optional<ControlFnResult>`.
167  /// If the return value is `std::nullopt`, that implies no fusion
168  /// is to be performed along that slice.
170  /// Set to true if the loop nest has to return a replacement value
171  /// for the fused producer.
173  };
174  using ControlFnTy = std::function<std::optional<ControlFnResult>(
175  tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
176  bool isDestinationOperand)>;
177  /// The default control function implements greedy fusion without yielding
178  /// a replacement for any of the fused results.
179  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
180  bool) -> std::optional<ControlFnResult> {
181  return ControlFnResult{};
182  };
184  fusionControlFn = controlFn;
185  return *this;
186  }
187 
188  /// An optional set of rewrite patterns to apply to the results of tiling
189  /// before fusion. This will track deleted and newly inserted
190  /// `tensor.extract_slice` ops and update the worklist.
191  std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
192 };
193 
194 /// Fuse the producer of the source of `candidateSliceOp` by computing the
195 /// required slice of the producer in-place. Note that the method
196 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer
197 /// value but does not delete the slice operation.
199  OpResult origProducer; // Original untiled producer.
200  Value tiledAndFusedProducer; // Tile and fused producer value.
203 };
204 std::optional<SCFFuseProducerOfSliceResult>
206  tensor::ExtractSliceOp candidateSliceOp,
208 
209 /// Reconstruct the fused producer from within the tiled-and-fused code. Based
210 /// on the slice of the producer computed in place it is possible that within
211 /// the loop nest same slice of the producer is computed multiple times. It is
212 /// in general not possible to recompute the value of the fused producer from
213 /// the tiled loop code in such cases. For the cases where no slice of the
214 /// producer is computed in a redundant fashion it is possible to reconstruct
215 /// the value of the original producer from within the tiled loop. It is up to
216 /// the caller to ensure that the producer is not computed redundantly within
217 /// the tiled loop nest. For example, consider
218 ///
219 /// ```mlir
220 /// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
221 /// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?xf32>
222 /// ```
223 ///
224 /// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
225 /// is,
226 ///
227 /// ```mlir
228 /// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
229 /// %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
230 /// ...
231 /// %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
232 /// %t1_3 = linalg.matmul ins(%t1_2, ...)
233 /// %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
234 /// scf.yield %t1_4
235 /// }
236 /// scf.yield %t1_1
237 /// }
238 /// ```
239 ///
240 /// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
241 /// if `%1` were tiled only along the rows, the resultant code would be
242 ///
243 /// ```mlir
244 /// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
245 /// ...
246 /// %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
247 /// %t2_2 = linalg.matmul ins(%t2_1, ...)
248 /// %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
249 /// scf.yield %t2_3
250 /// }
251 /// ```
252 ///
253 /// Here there is no intersection in the different slices of `%t2_1` computed
254 /// across iterations of the `scf.for`. In such cases, the value of the original
255 /// `%0` can be reconstructed from within the loop body. This is useful in cases
256 /// where `%0` had other uses as well. If not reconstructed from within the loop
257 /// body, uses of `%0` could not be replaced, making it still live and the
258 /// fusion immaterial.
259 ///
260 /// The @param `yieldResultNumber` decides which result is yielded. If not
261 /// given, yield all `opResult` of fused producer.
262 ///
263 /// The method returns the list of new slices added during the process (which
264 /// can be used to fuse along).
265 FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
266  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
267  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
269  ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
270 
271 /// Transformation information returned after tile and fuse.
273  /// List of untiled operations that were fused with the tiled consumer.
275  /// List of tiled and fused operations generated. The first element is always
276  /// the tiled version of the original consumer operation processed by
277  /// `tileConsumerAndFuseProducersUsingSCF`, followed by any operations that
278  /// were fused with it.
280  /// The `scf.for` operations that iterate over the tiles.
282  /// The replacement values to use for the tiled and fused operations.
284 };
285 
286 /// Method to tile and fuse a sequence of operations, by tiling the consumer
287 /// and fusing its producers. Note that this assumes that it is valid to
288 /// tile+fuse the producer into the innermost tiled loop. It is up to the caller
289 /// to ensure that the tile sizes provided make this fusion valid.
290 ///
291 /// For example, for the following sequence
292 ///
293 /// ```mlir
294 /// %0 =
295 /// %1 = linalg.fill ... outs(%0 : ... )
296 /// %2 = linalg.matmul ... outs(%1 : ...) ...
297 /// ```
298 ///
299 /// it is legal to fuse the fill with the matmul only if the matmul is tiled
300 /// along the parallel dimensions and not the reduction dimension, i.e. the tile
301 /// size for the reduction dimension should be 0. The resulting fused
302 /// transformation is
303 ///
304 /// ```mlir
305 /// %1 = scf.for ... iter_args(%arg0 = %0)
306 /// %2 = tensor.extract_slice %arg0
307 /// %3 = linalg.fill .. outs(%2 : ... )
308 /// %4 = linalg.matmul .. outs(%3 : ...)
309 /// }
310 /// ```
311 FailureOr<SCFTileAndFuseResult>
313  TilingInterface consumer,
315 
316 /// Fuse the consumer `candidateSlices` by computing the required slice of the
317 /// consumer in-place. All the entries of `candidateSlices` are expected to map
318 /// to the same consumer. The method returns an error if the consumer cannot be
319 /// tiled in a manner that is consistent for all the passed slices. Note that
320 /// the method replaces the uses of `candidateSlices` with the tiled and fused
321 /// consumer value but does not delete the slice operations.
323  // Original untiled consumer operands.
325  // Tiled and fused consumer operands.
328 };
329 FailureOr<scf::SCFFuseConsumerOfSliceResult>
331  ArrayRef<Operation *> candidateSlices,
333 
334 /// Method to lower an `op` that implements the `TilingInterface` to
335 /// loops/scalars.
336 FailureOr<SmallVector<scf::ForOp>>
337 lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
338 
339 /// Method to tile a reduction and generate a parallel op within a serial loop.
340 /// Each of the partial reductions are calculated in parallel. Then after the
341 /// loop all the partial reductions are merged into a final reduction.
342 /// For example for the following sequence
343 ///
344 /// ```mlir
345 /// %0 = linalg.generic %in ["parallel", "reduction"]
346 /// : tensor<7x9xf32> -> tensor<7xf32>
347 /// ```
348 ///
349 /// into:
350 ///
351 /// ```mlir
352 /// %0 = linalg.fill ... : tensor<7x4xf32>
353 /// %1 = scf.for ... iter_args(%arg0 = %0)
354 /// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
355 /// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
356 /// %4 = linalg.generic %2, %3 ["parallel", "parallel"]
357 /// : tensor<7x?xf32> -> tensor<7x?xf32>
358 /// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32>
359 /// }
360 /// %6 = linalg.generic %1 ["parallel", "reduction"]
361 /// : tensor<7x4xf32> -> tensor<7xf32>
362 /// ```
363 FailureOr<scf::SCFTilingResult>
364 tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
365  ArrayRef<OpFoldResult> tileSizes);
366 
367 } // namespace scf
368 } // namespace mlir
369 
370 #endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
static llvm::ManagedStatic< PassManagerOptions > options
This class helps build Operations.
Definition: Builders.h:205
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::function< SmallVector< OpFoldResult >(OpBuilder &, Operation *)> SCFTileSizeComputationFunction
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
Include the generated interface declarations.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
Fuse the consumer candidateSlices by computing the required slice of the consumer in-place.
SmallVector< OpOperand * > tiledAndFusedConsumerOperands
SmallVector< OpOperand * > origConsumerOperands
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
SmallVector< Operation * > generatedSlices
Control function to check if a slice needs to be fused or not, The control function receives 1) the s...
bool yieldProducerReplacement
Set to true if the loop nest has to return a replacement value for the fused producer.
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
SCFTileAndFuseOptions & setTilingOptions(SCFTilingOptions options)
ControlFnTy fusionControlFn
The default control function implements greedy fusion without yielding a replacement for any of the f...
std::function< std::optional< ControlFnResult >(tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, bool isDestinationOperand)> ControlFnTy
SCFTileAndFuseOptions & setFusionControlFn(ControlFnTy controlFn)
Transformation information returned after tile and fuse.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.
llvm::SetVector< Operation * > fusedProducers
List of untiled operations that were fused with the tiled consumer.
llvm::DenseMap< Value, Value > replacements
The replacement values to use for the tiled and fused operations.
llvm::SetVector< Operation * > tiledAndFusedOps
List of tiled and fused operations generated.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setReductionTilingStrategy(ReductionTilingStrategy strategy)
SCFTilingOptions & setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SCFTilingOptions & setReductionDims(ArrayRef< unsigned > dims)
SetVector< unsigned > reductionDims
Specify the reduction dimensions to be tiled.
SCFTileSizeComputationFunction numThreadsComputationFunction
Computation function that returns the number of threads to use for each loop.
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
LoopType
Specify which loop construct to use for tile and fuse.
SmallVector< Attribute > mappingVector
Specify mapping of loops to devices.
SCFTilingOptions & setLoopType(LoopType type)
SCFTilingOptions & setMapping(ArrayRef< Attribute > mapping)
ReductionTilingStrategy reductionStrategy
Specify how reduction dimensions should be tiled.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< Value > initialValues
The initial destination values passed to the tiled operations.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.
SmallVector< Operation * > generatedSlices
Slices generated after tiling that can be used for fusing with the tiled producer.
SmallVector< Value > replacements
Values to use as replacements for the untiled op.
SmallVector< Operation * > mergeOps
In cases where there is an additional merge step after tiling return the merged ops after tiling.