MLIR 22.0.0git
TileUsingInterface.h
Go to the documentation of this file.
1//===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
10#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
11
19
20#include <deque>
21
// Forward declarations of core IR types referenced by the declarations below.
namespace mlir {
class Operation;
class RewriterBase;
class TilingInterface;
} // namespace mlir
27
28namespace mlir {
29namespace scf {
30
/// Function type used to compute the tile sizes for an operation: given a
/// builder and the operation being tiled, returns one `OpFoldResult` per
/// loop. The same type is reused for computing the number of threads (see
/// `SCFTilingOptions::numThreadsComputationFunction`).
using SCFTileSizeComputationFunction =
    std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
33
/// Options to use to control tiling.
struct SCFTilingOptions {
  /// Specify which loop construct to use for tile and fuse.
  enum class LoopType { ForOp, ForallOp, CustomOp };
  LoopType loopType = LoopType::ForOp;
  SCFTilingOptions &setLoopType(LoopType type) {
    loopType = type;
    return *this;
  }

  /// Computation function that returns the tile sizes to use for each loop.
  /// Returning a tile size of zero implies no tiling for that loop. If the
  /// size of the returned vector is smaller than the number of loops, the inner
  /// loops are not tiled. If the size of the returned vector is larger, then
  /// the vector is truncated to number of loops.
  SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;

  SCFTilingOptions &
  setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
    tileSizeComputationFunction = std::move(fun);
    return *this;
  }
  /// Convenience function to set the `tileSizeComputationFunction` to a
  /// function that computes tile sizes at the point they are needed. Allows
  /// proper interaction with folding.
  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);

  /// The interchange vector to reorder the tiled loops.
  SmallVector<int64_t> interchangeVector = {};
  SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
    interchangeVector = llvm::to_vector(interchange);
    return *this;
  }

  //-------------------------------------------------------------------------//
  // Options related to tiling using `scf.forall`.
  //-------------------------------------------------------------------------//

  /// Computation function that returns the number of threads to use for
  /// each loop. Returning a num threads of zero implies no tiling for that
  /// loop. If the size of the returned vector is smaller than the number of
  /// loops, the inner loops are not tiled. If the size of the returned vector
  /// is larger, then the vector is truncated to number of loops. Note: This
  /// option is only supported with loopType set to `LoopType::ForallOp`. If the
  /// tile size function is not specified while the num threads computation is,
  /// then the tile size is determined automatically to map at most one tile per
  /// thread.
  SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;

  SCFTilingOptions &
  setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
    numThreadsComputationFunction = std::move(fun);
    return *this;
  }
  /// Convenience function to set the `numThreadsComputationFunction` to a
  /// function that computes num threads at the point they are needed.
  SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);

  /// Specify mapping of loops to devices. This is only respected when the loop
  /// constructs support such a mapping (like `scf.forall`). Will be ignored
  /// when using loop constructs that don't support such a mapping (like
  /// `scf.for`).
  SmallVector<Attribute> mappingVector = {};
  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
    mappingVector = llvm::to_vector(mapping);
    return *this;
  }

  //-------------------------------------------------------------------------//
  // Options related to reduction tiling.
  //-------------------------------------------------------------------------//

  /// Specify how reduction dimensions should be tiled.
  ReductionTilingStrategy reductionStrategy =
      ReductionTilingStrategy::FullReduction;
  SCFTilingOptions &
  setReductionTilingStrategy(ReductionTilingStrategy strategy) {
    reductionStrategy = strategy;
    return *this;
  }

  /// Specify the reduction dimensions to be tiled. Note that this needs to be
  /// specified. If left unspecified, then none of the reduction dimensions are
  /// tiled.
  SetVector<unsigned> reductionDims;
  SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
    reductionDims.clear();
    reductionDims.insert(dims.begin(), dims.end());
    return *this;
  }

  //-------------------------------------------------------------------------//
  // Options related to tiling using custom loop.
  //-------------------------------------------------------------------------//

  // For generating the inter-tile loops using a custom loop, two callback
  // functions are needed:
  // 1. One that generates the "loop header", i.e. the loop that iterates over
  //    the different tiles.
  // 2. One that generates the loop terminator.
  //
  // For the `scf.forall` case the callback to generate the loop header would
  // generate
  //
  // ```mlir
  // scf.forall (...) = ... {
  //   ..
  // }
  // ```
  //
  // and the callback to generate the loop terminator would generate the
  // `scf.in_parallel` region
  //
  // ```mlir
  // scf.forall (...) = ... {
  //   scf.in_parallel {
  //     tensor.parallel_insert_slice ...
  //   }
  // }
  // ```

  // Information that is to be returned by the loop header callback, needed for
  // the rest of the tiled code generation.
  // - `loops`: The generated loops.
  // - `tileOffset`: The values that represent the offset of the iteration space
  //   tile.
  // - `tileSizes` : The values that represent the size of the iteration space
  //   tile.
  // - `destinationTensors` : The tensors to use as destinations during tiling.
  struct CustomLoopHeaderInfo {
    SmallVector<LoopLikeOpInterface> loops;
    SmallVector<OpFoldResult> tileOffset;
    SmallVector<OpFoldResult> tileSizes;
    SmallVector<Value> destinationTensors;
  };

  // Type of the callback function that generates the loop headers.
  // - `loopRanges` : Values that represent the full size of the iteration space
  //   being tiled.
  // - `givenTileSizes` : The tile sizes that are to be used to tile the
  //   iteration space.
  // - `destinationTensors` : The tensors to use as destinations for the results
  //   of the tiled loop for loops that implement
  //   `DestinationStyleOpInterface`.
  // Returns the `CustomLoopHeaderInfo` object (described above). It is expected
  // that this function sets the insertion point of `rewriter` to the program
  // point where the intra-tile loop computation is to be generated.
  using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
      RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
      ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;

  // Type of the callback function that generates the loop terminator.
  // - `loops` : generated loops from the `GenerateLoopHeaderFn` callback.
  // - `tiledResults` : Tiles of the result computed for the iteration space
  //   tile.
  // - `resultOffsets` : For each of the `tiledResults`, the offset at which
  //   the result tile is to be "inserted" back into the destination tensor.
  // - `resultSizes` : For each of the `tiledResults`, the size of the result
  //   tile that is to be "inserted" back into the destination tensor.
  // - `destinationTensors` : The tensors the result tiles are inserted into.
  // Returns success if the loop terminator was generated, failure otherwise.
  using GenerateLoopTerminatorFn = std::function<LogicalResult(
      RewriterBase &rewriter, Location loc, ArrayRef<LoopLikeOpInterface> loops,
      ValueRange tiledResults,
      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
      ValueRange destinationTensors)>;

  // Callback function to generate the inter-tile loop header.
  GenerateLoopHeaderFn generateLoopHeaderFn = nullptr;
  // Callback function to generate the inter-tile loop terminator.
  GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr;
  // Helper function to set the callbacks for inter-tile loop header and
  // terminator functions when using a custom operation for the loop.
  SCFTilingOptions &
  setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn,
                             GenerateLoopTerminatorFn terminatorFn) {
    generateLoopHeaderFn = std::move(headerFn);
    generateLoopTerminatorFn = std::move(terminatorFn);
    return *this;
  }
};
217
/// Transformation information returned after tiling.
struct SCFTilingResult {
  /// Tiled operations that are generated during tiling. The order does not
  /// matter except the last op. The replacements are expected to be the results
  /// of the last op.
  SmallVector<Operation *> tiledOps;
  /// The initial destination values passed to the tiled operations.
  SmallVector<Value> initialValues;
  /// The `scf.for` operations that iterate over the tiles.
  SmallVector<LoopLikeOpInterface> loops;
  /// Values to use as replacements for the untiled op. Is the same size as the
  /// number of results of the untiled op.
  SmallVector<Value> replacements;
  /// Slices generated after tiling that can be used for fusing with the tiled
  /// producer.
  SmallVector<Operation *> generatedSlices;
  /// In cases where there is an additional merge step after tiling, the
  /// merged ops generated after tiling. This list is empty when the reduction
  /// tiling strategy is `ReductionTilingStrategy::FullReduction`.
  SmallVector<Operation *> mergeOps;
};
240
/// Method to tile an op that implements the `TilingInterface` using
/// `scf.for` for iterating over the tiles. The loop construct actually
/// generated is controlled by `SCFTilingOptions::loopType`.
FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
                                        TilingInterface op,
                                        const SCFTilingOptions &options);
246
247/// Options used to control tile + fuse.
248struct SCFTileAndFuseOptions {
249 /// The tiling options used to control the tiling of the consumer.
250 SCFTilingOptions tilingOptions;
251 SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) {
252 tilingOptions = options;
253 return *this;
254 }
255
256 /// Control function to check if a slice needs to be fused or not,
257 /// The control function receives
258 /// 1) the slice along which fusion is to be done,
259 /// 2) the producer value that is to be fused
260 /// 3) a boolean value set to `true` if the fusion is from
261 /// a destination operand.
262 /// The control function returns an `std::optiona<ControlFnResult>`.
263 /// If the return value is `std::nullopt`, that implies no fusion
264 /// is to be performed along that slice.
265 struct ControlFnResult {
266 /// Set to true if the loop nest has to return a replacement value
267 /// for the fused producer.
268 bool yieldProducerReplacement = false;
269 };
270 using ControlFnTy = std::function<std::optional<ControlFnResult>(
271 tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
272 bool isDestinationOperand)>;
273 /// The default control function implements greedy fusion without yielding
274 /// a replacement for any of the fused results.
275 ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
276 bool) -> std::optional<ControlFnResult> {
277 return ControlFnResult{};
278 };
279 SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
280 fusionControlFn = controlFn;
281 return *this;
282 }
283
284 /// An optional set of rewrite patterns to apply to the results of tiling
285 /// before fusion. This will track deleted and newly inserted
286 /// `tensor.extract_slice` ops and update the worklist.
287 std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
288};
289
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
/// value but does not delete the slice operation.
struct SCFFuseProducerOfSliceResult {
  OpResult origProducer;       // Original untiled producer.
  Value tiledAndFusedProducer; // Tiled and fused producer value.
  // Tiled operations created while fusing the producer.
  SmallVector<Operation *> tiledOps;
  // Slices generated during fusion that can in turn be fused along.
  SmallVector<Operation *> generatedSlices;
};
// Returns std::nullopt if the producer could not be fused.
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                           tensor::ExtractSliceOp candidateSliceOp,
                           MutableArrayRef<LoopLikeOpInterface> loops);
304
/// Reconstruct the fused producer from within the tiled-and-fused code. Based
/// on the slice of the producer computed in place it is possible that within
/// the loop nest same slice of the producer is computed multiple times. It is
/// in general not possible to recompute the value of the fused producer from
/// the tiled loop code in such cases. For the cases where no slice of the
/// producer is computed in a redundant fashion it is possible to reconstruct
/// the value of the original producer from within the tiled loop. It is up to
/// the caller to ensure that the producer is not computed redundantly within
/// the tiled loop nest. For example, consider
///
/// ```mlir
/// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
/// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32>
/// ```
///
/// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
/// is,
///
/// ```mlir
/// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
///   %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
///     ...
///     %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
///     %t1_3 = linalg.matmul ins(%t1_2, ...)
///     %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
///     scf.yield %t1_4
///   }
///   scf.yield %t1_1
/// }
/// ```
///
/// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
/// if `%1` were tiled only along the rows, the resultant code would be
///
/// ```mlir
/// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
///   ...
///   %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
///   %t2_2 = linalg.matmul ins(%t2_1, ...)
///   %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
///   scf.yield %t2_3
/// }
/// ```
///
/// Here there is no intersection in the different slices of `%t2_1` computed
/// across iterations of the `scf.for`. In such cases, the value of the original
/// `%0` can be reconstructed from within the loop body. This is useful in cases
/// where `%0` had other uses as well. If not reconstructed from within the loop
/// body, uses of `%0` could not be replaced, making it still live and the
/// fusion immaterial.
///
/// The `yieldResultNumber` parameter decides which results are yielded. If not
/// given, all `opResult`s of the fused producer are yielded.
///
/// The method returns the list of new slices added during the process (which
/// can be used to fuse along).
FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
    MutableArrayRef<LoopLikeOpInterface> loops,
    ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
366
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
  /// List of untiled operations that were fused with the tiled consumer.
  llvm::SetVector<Operation *> fusedProducers;
  /// List of tiled and fused operations generated. The first element is always
  /// the tiled version of the original consumer operation processed by
  /// `tileConsumerAndFuseProducersUsingSCF`, followed by any operations that
  /// were fused with it.
  llvm::SetVector<Operation *> tiledAndFusedOps;
  /// The `scf.for` operations that iterate over the tiles.
  SmallVector<LoopLikeOpInterface> loops;
  /// The replacement values to use for the tiled and fused operations; maps
  /// each original value to its replacement.
  llvm::DenseMap<Value, Value> replacements;
};
381
/// Method to tile and fuse a sequence of operations, by tiling the consumer
/// and fusing its producers. Note that this assumes that it is valid to
/// tile+fuse the producer into the innermost tiled loop. It's up to the caller
/// to ensure that the tile sizes provided make this fusion valid.
///
/// For example, for the following sequence
///
/// ```mlir
/// %0 =
/// %1 = linalg.fill ... outs(%0 : ... )
/// %2 = linalg.matmul ... outs(%1 : ...) ...
/// ```
///
/// it is legal to fuse the fill with the matmul only if the matmul is tiled
/// along the parallel dimensions and not the reduction dimension, i.e. the tile
/// size for the reduction dimension should be 0. The resulting fused
/// transformation is
///
/// ```mlir
/// %1 = scf.for ... iter_args(%arg0 = %0)
///   %2 = tensor.extract_slice %arg0
///   %3 = linalg.fill .. outs(%2 : ... )
///   %4 = linalg.matmul .. outs(%3 : ...)
/// }
/// ```
FailureOr<SCFTileAndFuseResult>
tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                     TilingInterface consumer,
                                     const SCFTileAndFuseOptions &options);
411
/// Fuse the consumer `candidateSlices` by computing the required slice of the
/// consumer in-place. All the entries of `candidateSlices` are expected to map
/// to the same consumer. The method returns an error if the consumer cannot be
/// tiled in a manner that is consistent for all the passed slices. Note that
/// the method replaces the uses of `candidateSlices` with the tiled and fused
/// consumer value but does not delete the slice operations.
struct SCFFuseConsumerOfSliceResult {
  // Original untiled consumer operands.
  SmallVector<OpOperand *> origConsumerOperands;
  // Tiled and fused consumer operands.
  SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
  // Tiled operations created while fusing the consumer.
  SmallVector<Operation *> tiledOps;
};
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
                            ArrayRef<Operation *> candidateSlices,
                            MutableArrayRef<LoopLikeOpInterface> loops);
429
/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars. Returns the generated `scf.for` loop nest.
FailureOr<SmallVector<scf::ForOp>>
lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
434
/// Method to tile a reduction and generate a parallel op within a serial loop.
/// Each of the partial reductions are calculated in parallel. Then after the
/// loop all the partial reductions are merged into a final reduction.
/// For example, this transforms the following sequence
///
/// ```mlir
/// %0 = linalg.generic %in ["parallel", "reduction"]
///   : tensor<7x9xf32> -> tensor<7xf32>
/// ```
///
/// into:
///
/// ```mlir
/// %0 = linalg.fill ... : tensor<7x4xf32>
/// %1 = scf.for ... iter_args(%arg0 = %0)
///   %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
///   %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
///   %4 = linalg.generic %2, %3 ["parallel", "parallel"]
///     : tensor<7x?xf32> -> tensor<7x?xf32>
///   %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32>
/// }
/// %6 = linalg.generic %1 ["parallel", "reduction"]
///   : tensor<7x4xf32> -> tensor<7xf32>
/// ```
FailureOr<scf::SCFTilingResult>
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
                      ArrayRef<OpFoldResult> tileSizes);
462
463} // namespace scf
464} // namespace mlir
465
466#endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Include the generated interface declarations.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...