9#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
10#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
31using SCFTileSizeComputationFunction =
32 std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
35struct SCFTilingOptions {
37 enum class LoopType { ForOp, ForallOp, CustomOp };
38 LoopType loopType = LoopType::ForOp;
39 SCFTilingOptions &setLoopType(LoopType type) {
49 SCFTileSizeComputationFunction tileSizeComputationFunction =
nullptr;
52 setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
53 tileSizeComputationFunction = std::move(fun);
59 SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
62 SmallVector<int64_t> interchangeVector = {};
63 SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
64 interchangeVector = llvm::to_vector(interchange);
81 SCFTileSizeComputationFunction numThreadsComputationFunction =
nullptr;
84 setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
85 numThreadsComputationFunction = std::move(fun);
90 SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
96 SmallVector<Attribute> mappingVector = {};
97 SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
98 mappingVector = llvm::to_vector(mapping);
108 ReductionTilingStrategy::FullReduction;
110 setReductionTilingStrategy(ReductionTilingStrategy strategy) {
111 reductionStrategy = strategy;
118 SetVector<unsigned> reductionDims;
119 SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
120 reductionDims.clear();
121 reductionDims.insert(dims.begin(), dims.end());
163 struct CustomLoopHeaderInfo {
164 SmallVector<LoopLikeOpInterface> loops;
165 SmallVector<OpFoldResult> tileOffset;
166 SmallVector<OpFoldResult> tileSizes;
167 SmallVector<Value> destinationTensors;
181 using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
182 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
183 ArrayRef<OpFoldResult> givenTileSizes,
ValueRange destinationTensors)>;
196 using GenerateLoopTerminatorFn = std::function<LogicalResult(
197 RewriterBase &rewriter, Location loc, ArrayRef<LoopLikeOpInterface> loops,
199 ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
200 ArrayRef<SmallVector<OpFoldResult>> resultSizes,
204 GenerateLoopHeaderFn generateLoopHeaderFn =
nullptr;
206 GenerateLoopTerminatorFn generateLoopTerminatorFn =
nullptr;
210 setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn,
211 GenerateLoopTerminatorFn terminatorFn) {
212 generateLoopHeaderFn = std::move(headerFn);
213 generateLoopTerminatorFn = std::move(terminatorFn);
219struct SCFTilingResult {
223 SmallVector<Operation *> tiledOps;
225 SmallVector<Value> initialValues;
227 SmallVector<LoopLikeOpInterface> loops;
230 SmallVector<Value> replacements;
233 SmallVector<Operation *> generatedSlices;
238 SmallVector<Operation *> mergeOps;
/// Tiles an operation as configured by `options`. On success the returned
/// `SCFTilingResult` carries the tiled ops, initial values, generated loops,
/// replacement values, generated slices and merge ops.
/// NOTE(review): this listing jumps from source line 243 to 245, so one
/// parameter line (presumably the operation to tile) appears to be missing
/// here — TODO restore it from the upstream header.
243FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
245 const SCFTilingOptions &
options);
248struct SCFTileAndFuseOptions {
250 SCFTilingOptions tilingOptions;
251 SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions
options) {
265 struct ControlFnResult {
268 bool yieldProducerReplacement =
false;
270 using ControlFnTy = std::function<std::optional<ControlFnResult>(
271 tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
272 bool isDestinationOperand)>;
275 ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
276 bool) -> std::optional<ControlFnResult> {
277 return ControlFnResult{};
279 SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
280 fusionControlFn = controlFn;
287 std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
294struct SCFFuseProducerOfSliceResult {
295 OpResult origProducer;
296 Value tiledAndFusedProducer;
297 SmallVector<Operation *> tiledOps;
298 SmallVector<Operation *> generatedSlices;
300std::optional<SCFFuseProducerOfSliceResult>
301tileAndFuseProducerOfSlice(RewriterBase &rewriter,
302 tensor::ExtractSliceOp candidateSliceOp,
303 MutableArrayRef<LoopLikeOpInterface> loops);
361FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
362 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
363 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
364 MutableArrayRef<LoopLikeOpInterface> loops,
365 ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
368struct SCFTileAndFuseResult {
370 llvm::SetVector<Operation *> fusedProducers;
375 llvm::SetVector<Operation *> tiledAndFusedOps;
377 SmallVector<LoopLikeOpInterface> loops;
379 llvm::DenseMap<Value, Value> replacements;
407FailureOr<SCFTileAndFuseResult>
408tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
409 TilingInterface consumer,
410 const SCFTileAndFuseOptions &
options);
422struct SCFFuseConsumerOfSliceResult {
424 SmallVector<OpOperand *> origConsumerOperands;
426 SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
427 SmallVector<Operation *> tiledOps;
429FailureOr<scf::SCFFuseConsumerOfSliceResult>
430tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
431 ArrayRef<Operation *> candidateSlices,
432 MutableArrayRef<LoopLikeOpInterface> loops);
438FailureOr<scf::SCFFuseConsumerOfSliceResult>
439tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
440 MutableArrayRef<LoopLikeOpInterface> loops);
444FailureOr<SmallVector<scf::ForOp>>
445lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
471FailureOr<scf::SCFTilingResult>
472tileReductionUsingScf(RewriterBase &
b, PartialReductionOpInterface op,
473 ArrayRef<OpFoldResult> tileSizes);
static llvm::ManagedStatic< PassManagerOptions > options
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Include the generated interface declarations.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...