9#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H 
   10#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H 
   31using SCFTileSizeComputationFunction =
 
   32    std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
 
   35struct SCFTilingOptions {
 
   37  enum class LoopType { ForOp, ForallOp, CustomOp };
 
   38  LoopType loopType = LoopType::ForOp;
 
   39  SCFTilingOptions &setLoopType(LoopType type) {
 
   49  SCFTileSizeComputationFunction tileSizeComputationFunction = 
nullptr;
 
   52  setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
 
   53    tileSizeComputationFunction = std::move(fun);
 
   59  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
 
   62  SmallVector<int64_t> interchangeVector = {};
 
   63  SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
 
   64    interchangeVector = llvm::to_vector(interchange);
 
   81  SCFTileSizeComputationFunction numThreadsComputationFunction = 
nullptr;
 
   84  setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
 
   85    numThreadsComputationFunction = std::move(fun);
 
   90  SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
   96  SmallVector<Attribute> mappingVector = {};
 
   97  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
 
   98    mappingVector = llvm::to_vector(mapping);
 
  108      ReductionTilingStrategy::FullReduction;
 
  110  setReductionTilingStrategy(ReductionTilingStrategy strategy) {
 
  111    reductionStrategy = strategy;
 
  118  SetVector<unsigned> reductionDims;
 
  119  SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
 
  120    reductionDims.clear();
 
  121    reductionDims.insert(dims.begin(), dims.end());
 
  163  struct CustomLoopHeaderInfo {
 
  164    SmallVector<LoopLikeOpInterface> loops;
 
  165    SmallVector<OpFoldResult> tileOffset;
 
  166    SmallVector<OpFoldResult> tileSizes;
 
  167    SmallVector<Value> destinationTensors;
 
  181  using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
 
  182      RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
 
  183      ArrayRef<OpFoldResult> givenTileSizes, 
ValueRange destinationTensors)>;
 
  196  using GenerateLoopTerminatorFn = std::function<LogicalResult(
 
  197      RewriterBase &rewriter, Location loc, ArrayRef<LoopLikeOpInterface> loops,
 
  199      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
 
  200      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
 
  204  GenerateLoopHeaderFn generateLoopHeaderFn = 
nullptr;
 
  206  GenerateLoopTerminatorFn generateLoopTerminatorFn = 
nullptr;
 
  210  setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn,
 
  211                             GenerateLoopTerminatorFn terminatorFn) {
 
  212    generateLoopHeaderFn = std::move(headerFn);
 
  213    generateLoopTerminatorFn = std::move(terminatorFn);
 
  219struct SCFTilingResult {
 
  223  SmallVector<Operation *> tiledOps;
 
  225  SmallVector<Value> initialValues;
 
  227  SmallVector<LoopLikeOpInterface> loops;
 
  230  SmallVector<Value> replacements;
 
  233  SmallVector<Operation *> generatedSlices;
 
  238  SmallVector<Operation *> mergeOps;
 
  243FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
 
  245                                        const SCFTilingOptions &
options);
 
  248struct SCFTileAndFuseOptions {
 
  250  SCFTilingOptions tilingOptions;
 
  251  SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions 
options) {
 
  265  struct ControlFnResult {
 
  268    bool yieldProducerReplacement = 
false;
 
  270  using ControlFnTy = std::function<std::optional<ControlFnResult>(
 
  271      tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
 
  272      bool isDestinationOperand)>;
 
  275  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
 
  276                                   bool) -> std::optional<ControlFnResult> {
 
  277    return ControlFnResult{};
 
  279  SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
 
  280    fusionControlFn = controlFn;
 
  287  std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
 
  294struct SCFFuseProducerOfSliceResult {
 
  295  OpResult origProducer;       
 
  296  Value tiledAndFusedProducer; 
 
  297  SmallVector<Operation *> tiledOps;
 
  298  SmallVector<Operation *> generatedSlices;
 
  300std::optional<SCFFuseProducerOfSliceResult>
 
  301tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 
  302                           tensor::ExtractSliceOp candidateSliceOp,
 
  303                           MutableArrayRef<LoopLikeOpInterface> loops);
 
  361FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
 
  362    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
 
  363    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
 
  364    MutableArrayRef<LoopLikeOpInterface> loops,
 
  365    ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
 
  368struct SCFTileAndFuseResult {
 
  370  llvm::SetVector<Operation *> fusedProducers;
 
  375  llvm::SetVector<Operation *> tiledAndFusedOps;
 
  377  SmallVector<LoopLikeOpInterface> loops;
 
  379  llvm::DenseMap<Value, Value> replacements;
 
  407FailureOr<SCFTileAndFuseResult>
 
  408tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
 
  409                                     TilingInterface consumer,
 
  410                                     const SCFTileAndFuseOptions &
options);
 
  418struct SCFFuseConsumerOfSliceResult {
 
  420  SmallVector<OpOperand *> origConsumerOperands;
 
  422  SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
 
  423  SmallVector<Operation *> tiledOps;
 
  425FailureOr<scf::SCFFuseConsumerOfSliceResult>
 
  426tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
 
  427                            ArrayRef<Operation *> candidateSlices,
 
  428                            MutableArrayRef<LoopLikeOpInterface> loops);
 
  432FailureOr<SmallVector<scf::ForOp>>
 
  433lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
 
  459FailureOr<scf::SCFTilingResult>
 
  460tileReductionUsingScf(RewriterBase &
b, PartialReductionOpInterface op,
 
  461                      ArrayRef<OpFoldResult> tileSizes);
 
static llvm::ManagedStatic< PassManagerOptions > options
 
Operation is the basic unit of execution within MLIR.
 
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
 
Include the generated interface declarations.
 
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a loop.