//===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H

#include <utility>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // namespace bufferization

class FrozenRewritePatternSet;

namespace linalg {

struct LinalgElementwiseFusionOptions;
struct LinalgFusionOptions;
struct LinalgTilingOptions;

//===----------------------------------------------------------------------===//
// Transformations exposed as function calls.
//===----------------------------------------------------------------------===//
using LinalgLoops = SmallVector<Operation *, 4>;

void populatePadTensorTilingPatterns(RewritePatternSet &patterns,
                                     const LinalgTilingOptions &options);

/// Populate patterns for splitting a `LinalgOp` with multiple statements within
/// its payload into multiple `GenericOp` ops that each have a single statement.
void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns);

/// Populate patterns for vectorizing low-D convolution ops. This is a step in
/// progressive lowering for convolution ops: it assumes that high-D convolution
/// ops were decomposed previously.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);

/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns);

/// Populate patterns that are only useful in the context of sparse tensors.
void populateSparseTensorRewriting(RewritePatternSet &patterns, bool enableRT);

/// Function type used to control when to stop fusion. It is expected that the
/// callback does not modify the OpOperand. The OpOperand is not marked as
/// const only to allow callers to use non-const methods.
using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>;

/// Patterns for fusing linalg operations on tensors.

/// Pattern to fuse `linalg.generic` -> `linalg.generic` operations
/// when both operations are fusable elementwise operations.
void populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns,
    const ControlFusionFn &controlElementwiseOpFusion);
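
// Illustrative usage sketch (not part of the upstream header): populating the
// elementwise fusion patterns and applying them greedily over a function.
// `funcOp` and the always-fuse control function are caller-side assumptions.
//
//   RewritePatternSet fusionPatterns(funcOp.getContext());
//   ControlFusionFn fuseAll = [](OpOperand *fusedOperand) { return true; };
//   populateElementwiseOpsFusionPatterns(fusionPatterns, fuseAll);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));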

/// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);

/// Patterns to fold an expanding tensor.expand_shape operation with its
/// producer generic operation by collapsing the dimensions of the generic op.
void populateFoldReshapeOpsByCollapsingPatterns(
    RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes);

/// Patterns to constant fold Linalg operations.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns,
                                          const ControlFusionFn &controlFn);

/// Pattern to fuse a `tensor.pad` operation with the producer of its source,
/// if the producer is a `linalg` operation with all parallel iterator types.
void populateFuseTensorPadWithProducerLinalgOpPatterns(
    RewritePatternSet &patterns);

/// Patterns to convert from one named op to another. These can be seen as
/// canonicalizations of named ops into another named op.
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);

/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);

/// Patterns that are used to inline constant operands into linalg generic ops.
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);

/// Patterns that are used to bubble up extract slice ops above linalg ops.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);

/// Adds patterns that swap tensor.extract_slice(linalg.fill(%cst, %init)) into
/// linalg.fill(%cst, tensor.extract_slice(%init)).
void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);

/// Return true if two `linalg.generic` operations with producer/consumer
/// relationship through `fusedOperand` can be fused using elementwise op
/// fusion.
bool areElementwiseOpsFusable(OpOperand *fusedOperand);

/// Fuse two `linalg.generic` operations that have a producer-consumer
/// relationship captured through `fusedOperand`. The method expects
/// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
                                          OpOperand *fusedOperand);
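
// Illustrative usage sketch (not part of the upstream header): fusing a single
// producer-consumer pair directly, guarded by the fusability check. `rewriter`
// and `fusedOperand` are assumed to be provided by the surrounding pass.
//
//   if (areElementwiseOpsFusable(fusedOperand)) {
//     FailureOr<Operation *> fused = fuseElementwiseOps(rewriter, fusedOperand);
//     if (failed(fused))
//       return failure();
//   }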

/// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
/// one-to-one and the induction variables of `scf.foreach_thread` are rewritten
/// to gpu.block_id according to the thread_dim_mapping attribute. Dynamic
/// `scf.foreach_thread` trip counts are currently not supported. Dynamic block
/// dim sizes are currently not supported.
LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
                      SmallVector<Value> &)>
        blockIdGenerator,
    SmallVector<int64_t> &gridDims);

/// Finds the top level scf::ForeachThreadOp of given target.
FailureOr<scf::ForeachThreadOp> findTopLevelForeachThreadOp(Operation *target);

/// Searches `scf.foreach_thread` ops nested under `target` and maps each such
/// op to GPU threads. Mapping is one-to-one and the induction variables of
/// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
/// thread_dim_mapping attribute. Sibling `scf.foreach_thread` ops are
/// supported, in which case the union of the number of threads is computed and
/// may result in predication. Dynamic `scf.foreach_thread` trip counts are
/// currently not supported. Dynamic block dim sizes are currently not
/// supported.
LogicalResult mapNestedForeachThreadToGpuThreads(
    RewriterBase &rewriter, Operation *target,
    const SmallVector<int64_t> &blockDim, bool syncAfterDistribute);

/// Split the given `op` into two parts along the given iteration space
/// `dimension` at the specified `splitPoint`, and return the two parts.
///
/// For example, the following op:
///
///   linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
///                 outs(%2 : tensor<128x64xf32>)
///
/// split along the first dimension at position 42 will result in:
///
///   %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
///   %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
///   %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
///                      outs(%4 : tensor<42x64xf32>)
///   %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
///
///   %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
///   %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
///   %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
///                      outs(%8 : tensor<86x64xf32>)
///   tensor.insert_slice %9 into %6[42, 0][86, 64][1, 1]
///
/// Note that there is no simplification other than constant propagation applied
/// to slice extraction and insertion.
std::pair<TilingInterface, TilingInterface> splitOp(RewriterBase &rewriter,
                                                    TilingInterface op,
                                                    unsigned dimension,
                                                    OpFoldResult splitPoint);
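
// Illustrative usage sketch (not part of the upstream header): splitting the
// matmul from the example above at a static split point. `rewriter` and `op`
// are assumed to exist in the calling transformation.
//
//   auto matmul = cast<TilingInterface>(op);
//   OpFoldResult splitPoint = rewriter.getIndexAttr(42);
//   std::pair<TilingInterface, TilingInterface> parts =
//       splitOp(rewriter, matmul, /*dimension=*/0, splitPoint);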

/// Perform standalone tiling of a single LinalgOp by `tileSizes` and permute
/// the loop nest according to `interchangeVector`. The permutation is
/// expressed as a list of integers that specify the new ordering of the loop
/// nest. The length of `interchangeVector` must be equal to the length of
/// `tileSizes`. An empty vector is interpreted as the identity permutation and
/// the transformation returns early.
///
/// Return a struct containing the tiled loops in the specified order
/// and the cloned op if successful, llvm::None otherwise.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`tileSizes.size()` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
struct TiledLinalgOp {
  LinalgOp op;
  SmallVector<Operation *, 8> loops;
  SmallVector<Value, 4> tensorResults;
};
FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
                                      const LinalgTilingOptions &options);

/// Peel and canonicalize 'loops'.
void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);

/// Peel the loops of a TiledLinalgOp.
void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
                       ArrayRef<int64_t> peeledLoops,
                       LinalgTilingLoopType loopType);

/// Interchange the `iterator_types` and `indexing_maps` dimensions and adapt
/// the index accesses of `op`. This is an in-place transformation controlled
/// by `interchangeVector`. An empty vector is interpreted as the identity
/// permutation and the transformation returns early.
///
/// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with
/// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be
/// integers, in the range 0..`op.rank` without duplications
/// (i.e. `[1,1,2]` is an invalid permutation).
///
/// Return failure if the permutation is not valid.
FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter,
                                          GenericOp genericOp,
                                          ArrayRef<unsigned> interchangeVector);

/// Create a GenericOp from the given named operation `namedOp` and replace
/// `namedOp`.
/// Return failure if `namedOp` is a GenericOp or lacks a region builder.
FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter,
                                       LinalgOp namedOp);
234 
235 /// Callback function type used to perform the allocation for the promoted
236 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the
237 /// smallest constant value for the size of the buffer needed for each
238 /// dimension. If that is not possible, contains the dynamic size of the
239 /// subview. The call back should return the buffer to use.
240 using AllocBufferCallbackFn = std::function<Optional<Value>(
241  OpBuilder &b, memref::SubViewOp subView,
242  ArrayRef<Value> boundingSubViewSize, DataLayout &layout)>;
243 
244 /// Callback function type used to deallocate the buffers used to hold the
245 /// promoted subview.
247  std::function<LogicalResult(OpBuilder &b, Value buffer)>;
248 
249 /// Callback function type used to insert copy from original subview to subview
250 /// of the promoted region for the read operands/subview of promoted region to
251 /// original subview for the results. The copy has to happen from `src` to
252 /// `dst`.
253 using CopyCallbackFn =
254  std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>;
255 

struct LinalgPromotionOptions {
  /// Indices of subViews to promote. If `None`, try to promote all operands.
  Optional<DenseSet<unsigned>> operandsToPromote = None;
  LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
    operandsToPromote = DenseSet<unsigned>();
    operandsToPromote->insert(operands.begin(), operands.end());
    return *this;
  }
  /// If the ith element of `useFullTiles` is true, the full view should be
  /// used for the promoted buffer of the ith operand in `operandsToPromote`.
  /// Otherwise the partial view will be used. The decision is defaulted to
  /// `useFullTileBuffersDefault` when `useFullTileBuffers` is None and for
  /// operands missing from `useFullTileBuffers`.
  Optional<llvm::SmallBitVector> useFullTileBuffers = None;
  LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
    unsigned size = useFullTiles.size();
    llvm::SmallBitVector tmp(size, false);
    for (unsigned i = 0; i < size; ++i)
      tmp[i] = useFullTiles[i];
    useFullTileBuffers = tmp;
    return *this;
  }
  /// If true, all operands unspecified by `useFullTileBuffers` will use the
  /// full view; otherwise the partial view.
  bool useFullTileBuffersDefault = false;
  LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) {
    useFullTileBuffersDefault = use;
    return *this;
  }
  /// Alignment of promoted buffer. If `None` do not specify alignment.
  Optional<unsigned> alignment = None;
  LinalgPromotionOptions &setAlignment(unsigned align) {
    alignment = align;
    return *this;
  }
  /// Use alloca with the default allocation scheme.
  bool useAlloca = false;
  LinalgPromotionOptions &setUseAlloca(bool use) {
    useAlloca = use;
    return *this;
  }
  /// Callback function to do the allocation of the promoted buffer. If None,
  /// then the default allocation scheme of allocating a memref<?xi8> buffer
  /// followed by a view operation is used.
  Optional<AllocBufferCallbackFn> allocationFn = None;
  Optional<DeallocBufferCallbackFn> deallocationFn = None;
  LinalgPromotionOptions &
  setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn,
                               DeallocBufferCallbackFn const &deallocFn) {
    allocationFn = allocFn;
    deallocationFn = deallocFn;
    return *this;
  }
  /// Callback function to do the copy of data to and from the promoted
  /// subview. If None then a memref.copy is used.
  Optional<CopyCallbackFn> copyInFn = None;
  Optional<CopyCallbackFn> copyOutFn = None;
  LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const &copyIn,
                                          CopyCallbackFn const &copyOut) {
    copyInFn = copyIn;
    copyOutFn = copyOut;
    return *this;
  }
};

/// Create a new buffer using the `allocationFn` provided. The size of this
/// buffer is the smallest constant bounding size along each dimension that can
/// be computed for the size of the result of `subView`. Returns the allocated
/// buffer as `fullLocalView` and the view that matches the size of the result
/// of the subview operation as `partialLocalView`.
struct PromotionInfo {
  Value fullLocalView;
  Value partialLocalView;
};
FailureOr<PromotionInfo>
promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
                          const AllocBufferCallbackFn &allocationFn,
                          DataLayout &layout);

/// Promote the `subViews` into a new buffer allocated at the insertion point
/// `b`. Promotion occurs in 3 steps:
///   1. Create a new buffer for a full tile (i.e. not clipped at the boundary).
///   2. Take a full view on the buffer.
///   3. Take a partial slice of the full view in step 2. and copy into it.
///
/// Return the modified linalg op (the modification happens in place) as well
/// as all the copy ops created.
FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                    const LinalgPromotionOptions &options);
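
// Illustrative usage sketch (not part of the upstream header): promoting the
// first two operands into aligned, full-tile buffers. `b` and `linalgOp` are
// assumed to be provided by the caller.
//
//   LinalgPromotionOptions promotionOptions =
//       LinalgPromotionOptions()
//           .setOperandsToPromote({0, 1})
//           .setUseFullTileBuffersByDefault(true)
//           .setAlignment(16);
//   FailureOr<LinalgOp> promoted =
//       promoteSubViews(b, linalgOp, promotionOptions);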

/// Emit a suitable vector form for a Linalg op with fully static shape.
LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);

/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

/// Emit a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
                                       LinalgOp linalgOp);

/// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
                                               LinalgOp linalgOp);

/// Emit a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
                                             LinalgOp linalgOp);

//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
// be applied as a rewrite pattern.
//===----------------------------------------------------------------------===//
/// Promote memref.subviews feeding linalg-on-buffers operations.
LogicalResult promoteSubviewsPrecondition(Operation *op,
                                          LinalgPromotionOptions options);

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
struct LinalgTransforms {
  static const StringLiteral kLinalgTransformMarker;
};

/// Helper class to control application of linalg transformation patterns.
/// Control comes in 2 forms:
///   1. attribute matching and setting behavior using the attribute named
///      `kLinalgTransformMarker`. This can be used to build a state machine
///      using attributes and incrementally applying patterns to advance states.
///   2. filter function, which is a simple lambda on the Operation* that
///      returns a LogicalResult.
struct LinalgTransformationFilter {
  using FilterFunction = std::function<LogicalResult(Operation *)>;

  explicit LinalgTransformationFilter(
      ArrayRef<StringAttr> matchDisjunction = {},
      Optional<StringAttr> replacement = None);

  explicit LinalgTransformationFilter(
      const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
      Optional<StringAttr> replacement = None);

  LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
  LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
  LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
  void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                         Operation *op) const;
  bool hasReplacementFilter(Operation *op) const;

  LinalgTransformationFilter &addFilter(const FilterFunction &f) {
    if (f)
      filters.push_back(f);
    return *this;
  }

  template <typename... OpTypes>
  LinalgTransformationFilter &addOpFilter() {
    return addFilter(
        [](Operation *op) { return success(isa<OpTypes...>(op)); });
  }

  LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
    return addFilter([opName](Operation *op) {
      return success(op->getName().getStringRef() == opName);
    });
  }

  LinalgTransformationFilter &setMatchByDefaultFilter() {
    matchByDefault = true;
    return *this;
  }

private:
  SmallVector<FilterFunction> filters;
  SmallVector<StringAttr> matchDisjunction;
  Optional<StringAttr> replacement;
  /// When set to true, if the attribute is not set, it will be treated as
  /// a match. Default is false.
  bool matchByDefault;
};

using TileSizeComputationFunction =
    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;

/// Creates a number of ranges equal to the number of non-zero entries in
/// `tileSizes`, one for each loop of the LinalgOp that is tiled. The
/// `tileSizes` argument has one entry per surrounding loop. A zero entry, by
/// convention, means that the corresponding loop is not tiled. This convention
/// simplifies implementations by avoiding affine map manipulations. The
/// returned ranges correspond to the loop ranges, in the proper order, that
/// are tiled and for which new loops will be created. The function also
/// returns a map from loop indices of the LinalgOp to the corresponding
/// non-empty range indices of newly created loops.
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                    ArrayRef<OpFoldResult> allShapeSizes,
                    ArrayRef<OpFoldResult> allTileSizes);

/// A description of a multi-size tiling comprising tile sizes and numbers of
/// tiles, expressed as Values which may or may not be constant. Multi-size
/// currently means two-size.
struct MultiSizeSpecification {
  /// Tile sizes.
  Value lowTileSize, highTileSize;
  /// Number of tiles associated with each size.
  Value lowTripCount, highTripCount;
};

/// Emits the IR computing the multi-sized tiling specification with two tile
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such that
/// there exist numbers of tiles with these sizes that fully cover the given
/// iteration space `dimension` of the structured `op`.
///
/// The computation is as follows:
///
///   b = originalTripCount floordiv sizeDivisor
///   t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor
///   d = (b + t - 1) floordiv t
///   s = (b floordiv d) * sizeDivisor
///   v = b % d
///   u = d - v
///
/// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of
/// the corresponding tiles are `u` and `v`, respectively. Alternatively,
///
///   s * u + (s + sizeDivisor) * v == original size,
///   where s mod sizeDivisor = 0.
///
/// Expects all values to be positive. In some cases with the target tile size
/// sufficiently close to the dimension shape and a non-unit divisor, it is
/// impossible to compute such sizes. If `emitAssertions` is set, also emit the
/// assertion that the size computation succeeded.
///
/// Returns the specification consisting of both tile values and the number of
/// tiles of each size.
FailureOr<MultiSizeSpecification>
computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
                      OpFoldResult targetSize, OpFoldResult divisor,
                      bool emitAssertions = true);
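
// A worked instance of the formulas above (illustrative, not part of the
// upstream header): for originalTripCount = 52, targetSize = 10,
// sizeDivisor = 2:
//   b = 52 floordiv 2 = 26
//   t = (10 + 2 - 1) floordiv 2 = 5
//   d = (26 + 5 - 1) floordiv 5 = 6
//   s = (26 floordiv 6) * 2 = 8
//   v = 26 mod 6 = 2
//   u = 6 - 2 = 4
// yielding 4 tiles of size 8 and 2 tiles of size 10, with
// 8 * 4 + 10 * 2 == 52, and both sizes divisible by 2 and not exceeding 10.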

/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
/// tiling by `numThreads`.
/// If non-empty, the `threadDimMapping` is added as an attribute to the
/// resulting `scf.foreach_thread`.
/// Zero tile sizes indicate that the dimension is not tiled and can be thought
/// of as tiling by the full size of data.
/// It is the user's responsibility to ensure that `numThreads` is a valid
/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
/// Linalg case).
struct ForeachThreadTilingResult {
  Operation *tileOp;
  Operation *tiledOp;
};
FailureOr<ForeachThreadTilingResult>
tileToForeachThreadOp(RewriterBase &builder, TilingInterface op,
                      ArrayRef<OpFoldResult> numThreads,
                      ArrayRef<int64_t> threadDimMapping = {});

/// Same as `tileToForeachThreadOp`, but calculate the number of threads
/// required using the given tileSizes.
FailureOr<ForeachThreadTilingResult>
tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
                                    ArrayRef<OpFoldResult> tileSizes,
                                    ArrayRef<int64_t> threadDimMapping = {});
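
// Illustrative usage sketch (not part of the upstream header): tiling a
// TilingInterface op so that two parallel dimensions are distributed over an
// 8x4 thread grid. `rewriter` and `op` are caller-side assumptions.
//
//   SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(8),
//                                           rewriter.getIndexAttr(4)};
//   FailureOr<ForeachThreadTilingResult> tiled =
//       tileToForeachThreadOp(rewriter, op, numThreads,
//                             /*threadDimMapping=*/{0, 1});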

/// All indices returned by IndexOp should be invariant with respect to tiling.
/// Therefore, if an operation is tiled, we have to transform the indices
/// accordingly, i.e. offset them by the values of the corresponding induction
/// variables that are captured implicitly in the body of the op.
///
/// Example. `linalg.generic` before tiling:
///
///   #id_2d = (i, j) -> (i, j)
///   #pointwise_2d_trait = {
///     indexing_maps = [#id_2d, #id_2d],
///     iterator_types = ["parallel", "parallel"]
///   }
///   linalg.generic #pointwise_2d_trait %operand, %result {
///   ^bb0(%operand_in: f32, %result_in: f32):
///     %i = linalg.index 0 : index
///     %j = linalg.index 1 : index
///     <some operations that use %i, %j>
///   }: memref<50x100xf32>, memref<50x100xf32>
///
/// After tiling pass with tile sizes 10 and 25:
///
///   #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2)
///
///   %c1 = arith.constant 1 : index
///   %c0 = arith.constant 0 : index
///   %c25 = arith.constant 25 : index
///   %c10 = arith.constant 10 : index
///   operand_dim_0 = dim %operand, 0 : memref<50x100xf32>
///   operand_dim_1 = dim %operand, 1 : memref<50x100xf32>
///   scf.for %k = %c0 to operand_dim_0 step %c10 {
///     scf.for %l = %c0 to operand_dim_1 step %c25 {
///       %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1]
///         : memref<50x100xf32> to memref<?x?xf32, #strided>
///       %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1]
///         : memref<50x100xf32> to memref<?x?xf32, #strided>
///       linalg.generic pointwise_2d_trait %4, %5 {
///       ^bb0(%operand_in: f32, %result_in: f32):
///         %i = linalg.index 0 : index
///         %j = linalg.index 1 : index
///         // Indices `k` and `l` are implicitly captured in the body.
///         // Index `i` is offset by %k and index `j` is offset by %l.
///         %transformed_i = arith.addi %i, %k : index
///         %transformed_j = arith.addi %j, %l : index
///         // Every use of %i, %j is replaced with %transformed_i,
///         // %transformed_j.
///         <some operations that use %transformed_i, %transformed_j>
///       }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided>
///     }
///   }
///
/// TODO: Investigate whether mixing implicit and explicit indices
/// does not lead to losing information.
void transformIndexOps(RewriterBase &b, LinalgOp op,
                       SmallVectorImpl<Value> &ivs,
                       const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);

struct LinalgPaddingOptions {
  /// A padding value for every operand.
  SmallVector<Attribute> paddingValues;
  LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) {
    paddingValues.assign(pv.begin(), pv.end());
    return *this;
  }
  /// A list of iterator dimensions to pad.
  SmallVector<int64_t> paddingDimensions;
  LinalgPaddingOptions &setPaddingDimensions(ArrayRef<int64_t> pd) {
    paddingDimensions.assign(pd.begin(), pd.end());
    return *this;
  }
  /// A flag for every operand to mark the PadOp as nofold which enables
  /// packing for statically shaped operands.
  SmallVector<bool> packPaddings;
  LinalgPaddingOptions &setPackPaddings(ArrayRef<bool> pp) {
    packPaddings.assign(pp.begin(), pp.end());
    return *this;
  }
  /// A number of loops to hoist the PadOp out for every operand.
  SmallVector<int64_t> hoistPaddings;
  LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) {
    hoistPaddings.assign(hp.begin(), hp.end());
    return *this;
  }
  /// A permutation vector for every operand used to transpose the packed
  /// PadOp results.
  SmallVector<SmallVector<int64_t>> transposePaddings;
  LinalgPaddingOptions &
  setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) {
    transposePaddings.assign(tp.begin(), tp.end());
    return *this;
  }
};

struct LinalgTilingAndFusionOptions {
  /// Tile sizes used to tile the root operation.
  SmallVector<int64_t> tileSizes;
  LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
    tileSizes.assign(ts.begin(), ts.end());
    return *this;
  }
  /// Tile interchange used to permute the tile loops.
  SmallVector<int64_t> tileInterchange;
  /// When specified, specifies distribution of generated tile loops to
  /// processors.
  Optional<LinalgLoopDistributionOptions> tileDistribution = None;
  LinalgTilingAndFusionOptions &
  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
    tileDistribution = std::move(distributionOptions);
    return *this;
  }
};

struct LinalgTilingOptions {
  /// Computation function that returns the tile sizes for each operation.
  /// Delayed construction of constant tile sizes should occur to interoperate
  /// with folding.
  TileSizeComputationFunction tileSizeComputationFunction = nullptr;

  LinalgTilingOptions &
  setTileSizeComputationFunction(TileSizeComputationFunction fun) {
    tileSizeComputationFunction = std::move(fun);
    return *this;
  }
  /// Set the `tileSizeComputationFunction` to return the values `ts`. The
  /// values must not fold away when tiling. Otherwise, use a more robust
  /// `tileSizeComputationFunction`.
  LinalgTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
    tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
    return *this;
  }
  /// Convenience function to set the `tileSizeComputationFunction` to a
  /// function that computes tile sizes at the point they are needed. Allows
  /// proper interaction with folding.
  LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts);

  /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions.
  /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together.
  LinalgTilingOptions &scalarizeDynamicDims();

  /// The interchange vector to reorder the tiled loops.
  SmallVector<unsigned, 4> interchangeVector = {};

  LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
    interchangeVector.assign(interchange.begin(), interchange.end());
    return *this;
  }

  /// The type of tile loops to generate.
  LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops;

  LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) {
    loopType = lt;
    return *this;
  }

  /// When specified, specifies distribution of generated tile loops to
  /// processors.
  Optional<LinalgLoopDistributionOptions> distribution = None;

  LinalgTilingOptions &
  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
    distribution = std::move(distributionOptions);
    return *this;
  }

  /// Specification markers of how to distribute the `linalg.tiled_loop`.
  SmallVector<StringRef, 2> distributionTypes = {};

  LinalgTilingOptions &setDistributionTypes(ArrayRef<StringRef> types) {
    distributionTypes.assign(types.begin(), types.end());
    return *this;
  }

  /// Peel the specified loops.
  SmallVector<int64_t> peeledLoops;

  LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) {
    peeledLoops.clear();
    peeledLoops.append(loops.begin(), loops.end());
    return *this;
  }
};
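
// Illustrative usage sketch (not part of the upstream header): configuring
// tiling with static tile sizes, an interchanged loop order, and peeling of
// the outermost tiled loop.
//
//   LinalgTilingOptions tilingOptions =
//       LinalgTilingOptions()
//           .setTileSizes({8, 16, 4})
//           .setInterchange({1, 0, 2})
//           .setLoopType(LinalgTilingLoopType::Loops)
//           .setPeeledLoops({0});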

/// Canonicalization patterns relevant to apply after tiling patterns. These
/// are applied automatically by the tiling pass but need to be applied
/// manually when tiling is called programmatically.
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);

///
/// Linalg tiling pattern.
///
/// Apply the `tiling` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `tiling` for more details.
// TODO: TiledOpInterface
struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgTilingPattern(
      MLIRContext *context, LinalgTilingOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgTilingPattern(
      StringRef opName, MLIRContext *context, LinalgTilingOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant transformed
  /// pieces of IR.
  FailureOr<TiledLinalgOp>
  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// Options to control tiling.
  LinalgTilingOptions options;
};

///
/// Linalg padding pattern.
///
/// Apply the `padding` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `padding` for more details.
struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgPaddingPattern(
      MLIRContext *context,
      LinalgPaddingOptions options = LinalgPaddingOptions(),
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgPaddingPattern(
      StringRef opName, MLIRContext *context,
      LinalgPaddingOptions options = LinalgPaddingOptions(),
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant transformed
  /// pieces of IR.
  FailureOr<LinalgOp>
  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// Options to control padding and hoisting.
  LinalgPaddingOptions options;
};

/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.
template <typename Conv2DOp, typename Conv1DOp>
struct DownscaleSizeOneWindowed2DConvolution final
    : public OpRewritePattern<Conv2DOp> {
  DownscaleSizeOneWindowed2DConvolution(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<Conv2DOp>(context, benefit), filter(std::move(f)) {}

  FailureOr<Conv1DOp>
  returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(Conv2DOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
                                                             Conv1DNwcWcfOp>;
extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
                                                             Conv1DNcwFcwOp>;

/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
struct DownscaleDepthwiseConv2DNhwcHwcOp final
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
  DownscaleDepthwiseConv2DNhwcHwcOp(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1)
      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
        filter(std::move(f)) {}

  FailureOr<DepthwiseConv1DNwcWcOp>
  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                           PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(convOp, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

///
/// Linalg tile and fuse tensor ops pattern.
///
/// Apply tiling and fusion as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `tileConsumerAndFuseProducers` for more details.
struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
  // Entry point to match any LinalgOp.
  LinalgTileAndFuseTensorOpsPattern(
      MLIRContext *context, LinalgTilingAndFusionOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);
  // Entry point to match a specific LinalgOp.
  LinalgTileAndFuseTensorOpsPattern(
      StringRef opName, MLIRContext *context,
      LinalgTilingAndFusionOptions options,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant transformed
  /// pieces of IR.
  FailureOr<TileLoopNest>
  returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
  /// Tile sizes and interchange used to tile the root operation.
  LinalgTilingAndFusionOptions options;
};

///
/// Linalg generalization pattern.
///
/// Apply the `generalization` transformation as a pattern.
/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `generalization` for more details.
struct LinalgGeneralizationPattern
    : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgGeneralizationPattern(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgGeneralizationPattern(
      StringRef opName, MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      PatternBenefit benefit = 1);

  /// `matchAndRewrite` implementation that returns the significant transformed
  /// pieces of IR.
  FailureOr<GenericOp>
  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    return returningMatchAndRewrite(op, rewriter);
  }

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  LinalgTransformationFilter filter;
};

///
/// Linalg peeling patterns.
///

/// Compute the loops to peel and return them in a SmallVector. Loops will be
/// peeled in order of appearance in the SmallVector. This order will impact the
/// output IR. If an inner-to-outer order is provided, the peeled iterations of
/// the outer loops will also contain the peeled inner loops. If an
/// outer-to-inner order is provided, the peeled iterations of the outer loops
/// will not contain any peeled inner loops.
using LoopsToPeelComputationFunction = std::function<void(
    OpBuilder &, Operation *, SmallVectorImpl<scf::ForOp> &)>;

struct LinalgPeelOptions {
  LoopsToPeelComputationFunction loopsToPeelComputationFunction = nullptr;
};

/// `filter` controls LinalgTransformMarker matching and update when specified.
struct LinalgPeelingPattern : public OpInterfaceRewritePattern<LinalgOp> {
  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
  LinalgPeelingPattern(
      MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      LinalgPeelOptions options = LinalgPeelOptions(),
      PatternBenefit benefit = 1);

  /// Construct a pattern specifically applied to `opName`.
  LinalgPeelingPattern(
      StringRef opName, MLIRContext *context,
      LinalgTransformationFilter f = LinalgTransformationFilter(),
      LinalgPeelOptions options = LinalgPeelOptions(),
      PatternBenefit benefit = 1);

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override;

private:
  /// LinalgTransformMarker handles special attribute manipulations.
  const LinalgTransformationFilter filter;
  /// Peeling options.
  const LinalgPeelOptions options;
};

///
/// Linalg vectorization patterns.
///
/// Empty for now, used for SFINAE purposes only.
struct LinalgVectorizationOptions {};

/// `filter` controls LinalgTransformMarker matching and update when specified.
/// See `vectorizeLinalgOp` for more details.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
  using OpRewritePattern<memref::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

/// Return vector::CombiningKind for the given op.
Optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);

//===----------------------------------------------------------------------===//
// Transformation and lowering options exposed as auxiliary structs.
//===----------------------------------------------------------------------===//
/// Options to control the application of enabling transformations.
/// Hoisting transformations are always deemed beneficial and must be disabled
/// explicitly.
struct LinalgEnablingOptions {
  /// Enable LICM.
  bool licm = true;
  LinalgEnablingOptions &enableLICM(bool val = true) {
    licm = val;
    return *this;
  }
  /// Enable hoisting of redundant vector transfer ops.
  bool hoistRedundantVectorTransfers = true;
  LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) {
    hoistRedundantVectorTransfers = val;
    return *this;
  }
  /// Enable hoisting of redundant vector transfer ops on tensor.
  bool hoistRedundantVectorTransfersOnTensor = true;
  LinalgEnablingOptions &
  enableHoistRedundantVectorTransfersOnTensor(bool val = true) {
    hoistRedundantVectorTransfersOnTensor = val;
    return *this;
  }
};

/// Vector lowering options control how ops are lowered down to 1-D and scf.for
/// form.
struct LinalgVectorLoweringOptions {
  /// Enable lowering of vector.contract.
  /// In a progressive lowering of vectors, this would be the 1st step.
  bool contractionLowering = false;
  LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
    contractionLowering = val;
    return *this;
  }
  /// Enable lowering of vector.multi_reduce.
  /// In a progressive lowering of vectors, this would be the 2nd step.
  bool multiReductionLowering = false;
  LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
    multiReductionLowering = val;
    return *this;
  }
  /// Trigger full / partial vector.transfer splits.
  /// In a progressive lowering of vectors, this would be the 3rd step.
  bool transferPartialRewrite = false;
  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
    transferPartialRewrite = val;
    return *this;
  }
  /// Enable lowering of vector.transfer to scf.
  /// In a progressive lowering of vectors, this would be the 4th step.
  bool transferToSCFConversion = false;
  LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
    transferToSCFConversion = val;
    return *this;
  }
  /// Maximal transfer rank under which we do not lower further.
  int64_t maxTransferRank = 1;
  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
    maxTransferRank = val;
    return *this;
  }
  /// Vector lowering operations may result in surprising behavior when
  /// composing multiple codegen strategies and must be enabled explicitly.
  /// In a progressive lowering of vectors, this would be the 5th step.
  bool transferLowering = true;
  LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
    transferLowering = val;
    return *this;
  }
  /// Enable lowering of vector.shape_cast to insert/extract.
  /// In a progressive lowering of vectors, this would be the 6th step.
  bool shapeCastLowering = true;
  LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) {
    shapeCastLowering = val;
    return *this;
  }
  /// Enable lowering of vector.transpose.
  /// In a progressive lowering of vectors, this would be the 7th step.
  bool transposeLowering = false;
  LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
    transposeLowering = val;
    return *this;
  }
  /// Enable AVX2-specific lowerings.
  bool avx2Lowering = false;
  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
    avx2Lowering = val;
    return *this;
  }

  /// Configure the post staged-patterns late vector.transfer to scf
  /// conversion.
  VectorTransferToSCFOptions vectorTransferToSCFOptions;
  LinalgVectorLoweringOptions &
  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
    vectorTransferToSCFOptions = options;
    return *this;
  }
  /// Configure late vector transformations.
  vector::VectorTransformsOptions vectorTransformOptions;
  LinalgVectorLoweringOptions &
  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
    vectorTransformOptions = options;
    return *this;
  }
  /// Configure specialized vector lowerings.
  x86vector::avx2::LoweringOptions avx2LoweringOptions;
  LinalgVectorLoweringOptions &
  lowerVectorTransposeToAVX2(x86vector::avx2::LoweringOptions options) {
    avx2LoweringOptions = options;
    return *this;
  }
};
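
// Illustrative usage sketch (not part of the upstream header): enabling the
// first four progressive vector lowering steps while keeping transfers of
// rank <= 1 untouched.
//
//   LinalgVectorLoweringOptions loweringOptions =
//       LinalgVectorLoweringOptions()
//           .enableContractionLowering()
//           .enableMultiReductionLowering()
//           .enableTransferPartialRewrite()
//           .enableTransferToSCFConversion()
//           .setMaxTransferRank(1);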
1088 
1089 //===----------------------------------------------------------------------===//
1090 // Transformations exposed as rewrite patterns.
1091 //===----------------------------------------------------------------------===//
1092 ///
1093 /// Linalg lowering patterns.
1094 ///
1095 /// Apply the `linalgLowerOpToLoops` transformation as a pattern.
1096 /// `filter` controls LinalgTransformMarker matching and update when specified.
1097 /// See `linalgLowerOpToLoops` for more details.
1099  LibraryCall = 0,
1100  Loops = 1,
1101  AffineLoops = 2,
1102  ParallelLoops = 3
1103 };
1104 
1105 template <typename OpTy>
1108  MLIRContext *context, LinalgLoweringType loweringType,
1110  PatternBenefit benefit = 1)
1111  : RewritePattern(OpTy::getOperationName(), benefit, context),
1112  filter(std::move(f)), loweringType(loweringType) {}
1113 
1114  // TODO: Move implementation to .cpp once named ops are auto-generated.
1116  PatternRewriter &rewriter) const override {
1117  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
1118  if (!linalgOp)
1119  return failure();
1120  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
1121  return failure();
1122 
1123  switch (loweringType) {
1124  case LinalgLoweringType::LibraryCall:
1125  // TODO: Move lowering to library calls here.
1126  return failure();
1128  if (failed(linalgOpToLoops(rewriter, op)))
1129  return failure();
1130  break;
1131  case LinalgLoweringType::AffineLoops:
1132  if (failed(linalgOpToAffineLoops(rewriter, op)))
1133  return failure();
1134  break;
1135  case LinalgLoweringType::ParallelLoops:
1136  if (failed(linalgOpToParallelLoops(rewriter, op)))
1137  return failure();
1138  break;
1139  }
1140 
1141  rewriter.eraseOp(op);
1142  return success();
1143  }
1144 
1145 private:
1146  /// LinalgTransformMarker handles special attribute manipulations.
1148  /// Controls whether the pattern lowers to library calls, scf.for, affine.for
1149  /// or scf.parallel.
1150  LinalgLoweringType loweringType;
1151 };

/// Linalg generalization patterns

/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
    RewritePatternSet &patterns,
    const LinalgTransformationFilter &filter = LinalgTransformationFilter());

/// Linalg decompose convolutions patterns

/// Populates patterns to decompose high-D convolution ops into low-D ones.
/// This is a step in progressive lowering for convolution ops; afterwards we
/// can vectorize the low-D convolution ops.
void populateDecomposeConvolutionPatterns(
    RewritePatternSet &patterns,
    const LinalgTransformationFilter &filter = LinalgTransformationFilter(),
    PatternBenefit benefit = 1);
1170 
1171 //===----------------------------------------------------------------------===//
1172 // Op-specific patterns.
1173 //===----------------------------------------------------------------------===//
1174 
1175 /// tensor::PadOp is not canonicalized away yet, so we provide a transformation
1176 /// to `linalg.generic`.
1177 struct PadOpTransformationPattern : public OpRewritePattern<tensor::PadOp> {
1179 
1180  LogicalResult matchAndRewrite(tensor::PadOp padOp,
1181  PatternRewriter &rewriter) const override;
1182 };
1183 
1184 /// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands to
1185 /// a static bounding box. Use `paddingValues` and `packPaddings` to set padding
1186 /// value and nofold attribute of the created tensor::PadOps, respectively.
1187 /// Update `paddedOp` to the cloned operation with statically shaped
1188 /// `paddingDimensions` and return the extracted dynamically shaped results. If
1189 /// padding fails, return failure.
1191 rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
1192  ArrayRef<int64_t> paddingDimensions,
1193  ArrayRef<Attribute> paddingValues,
1194  ArrayRef<bool> packPaddings, LinalgOp &paddedOp);

using OptimizeCopyFn =
    std::function<LogicalResult(PatternRewriter &, tensor::PadOp, Value)>;

/// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp and
/// InsertSliceOp. For now, only constant padding values are supported.
/// `OptimizeCopyFn` can be used to customize the copying step optimization.
struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
  GeneralizePadOpPattern(MLIRContext *context,
                         OptimizeCopyFn optimizeCopyFn = nullptr,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        optimizeCopyFn(std::move(optimizeCopyFn)) {}
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override;

protected:
  OptimizeCopyFn optimizeCopyFn;
  Value createFillOrGenerateOp(PatternRewriter &rewriter, tensor::PadOp padOp,
                               Value dest,
                               const SmallVector<Value> &dynSizes) const;
};

/// Populates `patterns` with patterns that vectorize tensor.pad.
/// These patterns are meant to apply in a complementary fashion. Benefits
/// are used to encode a certain ordering of pattern application. To avoid
/// scattering magic constants throughout the code base, the patterns must be
/// added with this function. `baseBenefit` can be used to offset the benefit
/// of all tensor::PadOp vectorization patterns by a certain value.
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
                                        PatternBenefit baseBenefit = 1);

/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = memref.view %alloc ...
///    %subView = subview %allocOrView ...
///    [optional] linalg.fill(%allocOrView, %cst) ...
///    ...
///    memref.copy(%in, %subView) ...
///    vector.transfer_read %allocOrView[...], %cst ...
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = memref.view %alloc ...
///    [unchanged] %subView = subview %allocOrView ...
///    ...
///    vector.transfer_read %in[...], %cst ...
/// ```
/// Where there is no interleaved use between memref.copy and transfer_read as
/// well as no interleaved use between linalg.fill and memref.copy (if
/// linalg.fill is specified).
/// This is a custom rewrite to forward partial reads (with optional fills) to
/// vector.transfer_read.
struct LinalgCopyVTRForwardingPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override;
};

/// Match and rewrite for the pattern:
/// ```
///    %alloc = ...
///    [optional] %view = memref.view %alloc ...
///    %subView = subview %allocOrView...
///    ...
///    vector.transfer_write %..., %allocOrView[...]
///    memref.copy(%subView, %out)
/// ```
/// into
/// ```
///    [unchanged] %alloc = ...
///    [unchanged] [optional] %view = memref.view %alloc ...
///    [unchanged] %subView = subview %allocOrView...
///    ...
///    vector.transfer_write %..., %out[...]
/// ```
/// Where there is no interleaved use between transfer_write and memref.copy.
/// This is a custom rewrite to forward partial writes to vector.transfer_write.
struct LinalgCopyVTWForwardingPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override;
};

//===----------------------------------------------------------------------===//
// Support for staged pattern application.
//===----------------------------------------------------------------------===//
/// Helper function to allow applying rewrite patterns, interleaved with more
/// global transformations, in a staged fashion:
///   1. the first stage consists of a list of FrozenRewritePatternSet. Each
///      FrozenRewritePatternSet in this list is applied once, in order.
///   2. the second stage consists of a single RewritePattern that is applied
///      greedily until convergence.
///   3. the third stage consists of applying a lambda, generally used for
///      non-local transformation effects. This allows creating custom fused
///      transformations where patterns can be ordered and applied at a finer
///      granularity than a sequence of traditional compiler passes.
LogicalResult applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
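
// Illustrative usage sketch (not part of the upstream header): two frozen
// stage-1 sets applied once each, a stage-2 set applied to a fixed point, and
// a stage-3 lambda for a non-local cleanup. The populate* choices are
// arbitrary examples.
//
//   RewritePatternSet stage1a(ctx), stage1b(ctx), stage2(ctx);
//   populateLinalgTilingCanonicalizationPatterns(stage1a);
//   populatePadOpVectorizationPatterns(stage1b);
//   populateElementwiseToLinalgConversionPatterns(stage2);
//   SmallVector<FrozenRewritePatternSet> stage1;
//   stage1.emplace_back(std::move(stage1a));
//   stage1.emplace_back(std::move(stage1b));
//   LogicalResult res = applyStagedPatterns(
//       op, stage1, FrozenRewritePatternSet(std::move(stage2)),
//       [](Operation *op) { return success(); });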

/// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)).
struct ExtractSliceOfPadTensorSwapPattern
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  /// A function to control pattern application and rewrite logic.
  ///
  /// The function will be given the slice op and should return:
  ///   - None: to fail the match and not apply the pattern;
  ///   - true: to apply the pattern with zero slice guard;
  ///   - false: to apply the pattern without zero slice guard.
  ///
  /// See the documentation for tensor::bubbleUpPadSlice regarding zero slice
  /// guard.
  using ControlFn = std::function<llvm::Optional<bool>(tensor::ExtractSliceOp)>;

  ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
                                     ControlFn controlFn = nullptr,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override;

private:
  ControlFn controlFn;
};

//===----------------------------------------------------------------------===//
// Helper classes for type list expansion.
//===----------------------------------------------------------------------===//
template <typename... OpTypes>
class TilingPatterns;

template <>
class TilingPatterns<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options,
                     const LinalgTransformationFilter &f) {}
};

template <typename... OpTypes>
class VectorizationPatterns;

template <>
class VectorizationPatterns<> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgVectorizationOptions &options,
                     const LinalgTransformationFilter &f) {}
};

template <typename OpTy, typename... OpTypes>
class TilingPatterns<OpTy, OpTypes...> {
public:
  static void insert(RewritePatternSet &patterns,
                     const LinalgTilingOptions &options,
                     const LinalgTransformationFilter &f) {
    patterns.add<LinalgTilingPattern>(OpTy::getOperationName(),
                                      patterns.getContext(), options, f);
    TilingPatterns<OpTypes...>::insert(patterns, options, f);
  }
};
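
// Illustrative usage sketch (not part of the upstream header): registering the
// tiling pattern for a fixed list of Linalg named ops via the recursive
// helper. The tile sizes are arbitrary example values.
//
//   RewritePatternSet patterns(ctx);
//   TilingPatterns<MatmulOp, FillOp>::insert(
//       patterns, LinalgTilingOptions().setTileSizes({8, 16, 4}),
//       LinalgTransformationFilter());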

/// Function signature to control reduction splitting. This returns a pair
/// containing a ratio and a dimension index. The ratio is used to split the
/// reduction dimension. The dimension index is used to control where the extra
/// dimension is added to the intermediate tensor shape. If the ratio value is
/// less than or equal to 1 then nothing will be done.
// TODO: don't use unsigned unless doing bit manipulation.
using ControlSplitReductionFn =
    std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;

/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
    RewritePatternSet &patterns,
    const ControlSplitReductionFn &controlSplitReductionFn,
    const LinalgTransformationFilter &f = LinalgTransformationFilter(),
    bool useAlloc = false);
1381 
1382 /// Apply transformation to split the single linalg op reduction into a parallel
1383 /// and reduction dimension. Then create a new linalg.generic op doing the rest
1384 /// of the reduction. Return the new linalg op with an extra parallel dimension
1385 /// or failure if the transformation didn't happen.
1386 /// Example:
1387 /// ```
1388 /// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
1389 /// affine_map<(d0) -> ()>],
1390 /// iterator_types = ["reduction"]}
1391 /// ins(%in : tensor<32xf32>)
1392 /// outs(%out : tensor<f32>) {
1393 /// ^bb0(%arg1: f32, %arg2: f32):
1394 /// %y = arith.addf %arg1, %arg2 : f32
1395 /// linalg.yield %y : f32
1396 /// } -> tensor<f32>
1397 /// ```
1398 /// To:
1399 /// ```
1400 /// %cst = arith.constant 0.000000e+00 : f32
1401 /// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
1402 /// %1 = linalg.init_tensor [4] : tensor<4xf32>
1403 /// %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
1404 /// %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
1405 /// affine_map<(d0, d1) -> (d0)>],
1406 /// iterator_types = ["parallel", "reduction"]}
1407 /// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
1408 /// ^bb0(%arg3: f32, %arg5: f32):
1409 /// %5 = arith.addf %arg3, %arg4 : f32
1410 /// linalg.yield %5 : f32
1411 /// } -> tensor<4xf32>
1412 /// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
1413 /// affine_map<(d0) -> ()>],
1414 /// iterator_types = ["reduction"]}
1415 /// ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
1416 /// ^bb0(%arg3: f32, %arg4: f32):
1417 /// %5 = arith.addf %arg3, %arg4 : f32
1418 /// linalg.yield %5 : f32
1419 /// } -> tensor<f32>
1420 /// ```
1422 splitReduction(PatternRewriter &b, LinalgOp op,
1423  const ControlSplitReductionFn &controlSplitReductionFn,
1424  const LinalgTransformationFilter &f, bool useAlloc = false);
1425 
1426 /// Filterless version of the above.
1427 /// Returns both the new linalg ops as well as the fillOp needed to initialize
1428 /// the temporary expanded tensor with the proper neutral element.
1431  FillOp fillOp;
1432  LinalgOp splitLinalgOp;
1434 };
1436 splitReduction(PatternRewriter &b, LinalgOp op,
1437  const ControlSplitReductionFn &controlSplitReductionFn,
1438  bool useAlloc = false);
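
// Illustrative usage sketch (not part of the upstream header): splitting a
// reduction by a factor of 4, inserting the extra parallel dimension at
// index 0. `rewriter` and `linalgOp` are caller-side assumptions.
//
//   ControlSplitReductionFn control = [](LinalgOp) {
//     return std::pair<int64_t, unsigned>(4, 0);
//   };
//   FailureOr<SplitReductionResult> result =
//       splitReduction(rewriter, linalgOp, control, /*useAlloc=*/false);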
1439 
1440 /// Scaling-based implementation of the split reduction transformation.
1441 /// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension
1442 /// `k` into `k * scale + kk`.
1443 ///
1444 /// Example:
1445 /// ```
1446 /// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
1447 /// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
1448 /// ```
1449 ///
1450 /// Is transformed to:
1451 ///
1452 /// ```
1453 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)>
1454 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)>
1455 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
1456 /// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
1457 /// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
1458 /// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
1459 /// %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32>
1460 /// %cst = arith.constant 0.000000e+00 : f32
1461 /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) ->
1462 /// tensor<16x32x64xf32>
1463 /// %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1>
1464 ///
1465 /// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3],
1466 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
1467 /// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
1468 /// outs(%1 : tensor<16x32x64xf32>) {
1469 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32):
1470 /// %5 = arith.mulf %arg3, %arg4 : f32
1471 /// %6 = arith.addf %arg6, %5 : f32
1472 /// linalg.yield %6 : f32
1473 /// } -> tensor<16x32x64xf32>
1474 ///
1475 /// %4 = linalg.generic {indexing_maps = [#map4, #map5],
1476 /// iterator_types = ["parallel", "parallel", "reduction"]}
1477 // ins(%3 : tensor<16x32x64xf32>)
1478 /// outs(%C : tensor<16x32xf32>) {
1479 /// ^bb0(%arg3: f32, %arg4: f32):
1480 /// %5 = arith.addf %arg3, %arg4 : f32
1481 /// linalg.yield %5 : f32
1482 /// } -> tensor<16x32xf32>
1483 ///
1484 /// return %4 : tensor<16x32xf32>
1485 /// ```
1488  const ControlSplitReductionFn &controlSplitReductionFn,
1489  bool useAlloc = false);
1490 

} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op...
Definition: Interchange.cpp:51
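For instance (a sketch; the permutation is illustrative):

  // {1, 2, 0} expresses the loop permutation (i, j, k) -> (j, k, i).
  FailureOr<GenericOp> permuted =
      interchangeGenericOp(rewriter, genericOp, /*interchangeVector=*/{1, 2, 0});
  if (failed(permuted))
    return failure();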
void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns)
Populate patterns that convert ElementwiseMappable ops to linalg parallel loops.
void populateFuseTensorPadWithProducerLinalgOpPatterns(RewritePatternSet &patterns)
Pattern to fuse a tensor.pad operation with the producer of its source, if the producer is a linalg o...
std::function< Optional< Value >(OpBuilder &b, memref::SubViewOp subView, ArrayRef< Value > boundingSubViewSize, DataLayout &layout)> AllocBufferCallbackFn
Callback function type used to perform the allocation for the promoted subView.
Definition: Transforms.h:242
SmallVector< scf::ForOp, 8 > Loops
Tile a nest of standard for loops rooted at rootForOp by finding such parametric tile sizes that the ...
Definition: Utils.h:142
filter controls LinalgTransformMarker matching and update when specified.
Definition: Transforms.h:961
bool areElementwiseOpsFusable(OpOperand *fusedOperand)
Return true if two linalg.generic operations with producer/consumer relationship through fusedOperand...
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
DownscaleSizeOneWindowed2DConvolution(MLIRContext *context, LinalgTransformationFilter f=LinalgTransformationFilter(), PatternBenefit benefit=1)
Definition: Transforms.h:790
FailureOr< SmallVector< Value > > rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, ArrayRef< int64_t > paddingDimensions, ArrayRef< Attribute > paddingValues, ArrayRef< bool > packPaddings, LinalgOp &paddedOp)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box...
Definition: Transforms.cpp:260
Match and rewrite for the pattern: ``` alloc = ...
Definition: Transforms.h:1277
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LinalgTransformationFilter & addFilter(const FilterFunction &f)
Definition: Transforms.h:406
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static void insert(RewritePatternSet &patterns, const LinalgTilingOptions &options, const LinalgTransformationFilter &f)
Definition: Transforms.h:1357
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::function< llvm::Optional< bool >(tensor::ExtractSliceOp)> ControlFn
A function to control pattern application and rewrite logic.
Definition: Transforms.h:1315
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:259
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
llvm::Optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
tensor::PadOp is not canonicalized away yet, so we provide a transformation to linalg.generic.
Definition: Transforms.h:1177
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
LinalgVectorLoweringOptions & setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options)
Definition: Transforms.h:1083
FailureOr< SplitReductionResult > splitReductionByScaling(PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
LinalgPaddingOptions & setPaddingValues(ArrayRef< Attribute > pv)
Definition: Transforms.h:581
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:367
filter controls LinalgTransformMarker matching and update when specified.
Definition: Transforms.h:928
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns)
Patterns that are used to inline constant operands into linalg generic ops.
LinalgVectorLoweringOptions & enableVectorTransposeLowering(bool val=true)
Definition: Transforms.h:1054
x86vector::avx2::LoweringOptions avx2LoweringOptions
Configure specialized vector lowerings.
Definition: Transforms.h:1081
SmallVector< int64_t > tileSizes
Tile sizes used to tile the root operation.
Definition: Transforms.h:616
std::function< LogicalResult(PatternRewriter &, tensor::PadOp, Value)> OptimizeCopyFn
Definition: Transforms.h:1197
ExtractSliceOfPadTensorSwapPattern(MLIRContext *context, ControlFn controlFn=nullptr, PatternBenefit benefit=1)
Definition: Transforms.h:1317
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:389
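A hedged sketch of a typical use after tiling, assuming operands 0 and 1 of op are backed by memref.subview ops:

  LinalgPromotionOptions promotionOptions;
  promotionOptions.setOperandsToPromote({0, 1}) // operand indices are illustrative
      .setAlignment(16);                        // alignment of the promoted buffers
  FailureOr<LinalgOp> promoted = promoteSubViews(b, op, promotionOptions);
  if (failed(promoted))
    return failure();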
LinalgTransformationFilter & addOpFilter()
Definition: Transforms.h:413
Value lowTripCount
Number of tiles associated with each size.
Definition: Transforms.h:463
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel and canonicalize 'loops'.
Definition: Transforms.cpp:332
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Definition: Transforms.h:773
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors.
SmallVector< int64_t > hoistPaddings
The number of loops to hoist the PadOp out of, for every operand.
Definition: Transforms.h:599
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued...
Definition: AffineMap.h:42
static void insert(RewritePatternSet &patterns, const LinalgVectorizationOptions &options, const LinalgTransformationFilter &f)
Definition: Transforms.h:1338
LinalgVectorLoweringOptions & enableTransferPartialRewrite(bool val=true)
Definition: Transforms.h:1019
FailureOr< TiledLinalgOp > tileLinalgOp(RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options)
Definition: Tiling.cpp:579
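A sketch with illustrative tile sizes, assuming the ArrayRef<int64_t> overload of setTileSizes that fixes the sizes as constants:

  LinalgTilingOptions tilingOptions;
  tilingOptions.setTileSizes({8, 16})
      .setLoopType(LinalgTilingLoopType::Loops); // generate scf.for loops
  FailureOr<TiledLinalgOp> tiled = tileLinalgOp(b, op, tilingOptions);
  if (failed(tiled))
    return failure();
  // tiled->loops holds the generated loop nest; in the tensor case
  // tiled->tensorResults holds the values replacing the original results.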
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
LinalgTransformationFilter & addOpNameFilter(StringRef opName)
Definition: Transforms.h:418
SmallVector< int64_t > peeledLoops
Peel the specified loops.
Definition: Transforms.h:695
LinalgTilingOptions & setDistributionTypes(ArrayRef< StringRef > types)
Definition: Transforms.h:689
Create a new buffer using the allocationFn provided.
Definition: Transforms.h:327
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
LinalgEnablingOptions & enableHoistRedundantVectorTransfers(bool val=true)
Definition: Transforms.h:986
LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override
Definition: Transforms.h:900
LinalgVectorLoweringOptions & setVectorTransformsOptions(vector::VectorTransformsOptions options)
Definition: Transforms.h:1076
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:708
LinalgTransformationFilter & setMatchByDefault()
Definition: Transforms.h:424
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns)
Patterns to convert from one named op to another.
SmallVector< Attribute > paddingValues
A padding value for every operand.
Definition: Transforms.h:580
std::function< LogicalResult(OpBuilder &b, Value src, Value dst)> CopyCallbackFn
Callback function type used to insert copy from original subview to subview of the promoted region fo...
Definition: Transforms.h:254
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
FailureOr< LinalgLoops > linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.for with the proper body for linalgOp.
Definition: Loops.cpp:371
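All three loop-lowering helpers are called the same way; a minimal sketch for the scf.for variant:

  // Assumes linalgOp has buffer (memref) semantics.
  FailureOr<LinalgLoops> loops = linalgOpToLoops(rewriter, linalgOp);
  if (failed(loops))
    return failure();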
void transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl< Value > &ivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex)
All indices returned by IndexOp should be invariant with respect to tiling.
Definition: Tiling.cpp:84
static llvm::ManagedStatic< PassManagerOptions > options
LinalgVectorLoweringOptions & enableShapeCastLowering(bool val=true)
Definition: Transforms.h:1047
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void hoistRedundantVectorTransfers(func::FuncOp func)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:401
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:294
std::function< LogicalResult(Operation *)> FilterFunction
Definition: Transforms.h:389
LinalgTilingLoopType
The type of loops to be generated during tiling.
Definition: Utils.h:149
DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context, LinalgTransformationFilter f=LinalgTransformationFilter(), PatternBenefit benefit=1)
Definition: Transforms.h:818
mlir::WalkResult rewriteMapNestedForeachThreadToGpuThreads(RewriterBase &rewriter, Operation *target, const SmallVector< int64_t > &blockDim, bool syncAfterDistribute)
Searches scf.foreach_thread ops nested under target and maps each such op to GPU threads.
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const override
Definition: Transforms.h:829
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:288
static const StringLiteral kLinalgTransformMarker
Definition: Transforms.h:378
LinalgVectorLoweringOptions & enableTransferLowering(bool val=true)
Definition: Transforms.h:1040
VectorTransferToSCFOptions vectorTransferToSCFOptions
Configure the post staged-patterns late vector.transfer to scf conversion.
Definition: Transforms.h:1067
LinalgTilingOptions & setLoopType(LinalgTilingLoopType lt)
Definition: Transforms.h:671
LinalgTilingAndFusionOptions & setDistributionOptions(LinalgLoopDistributionOptions distributionOptions)
Definition: Transforms.h:627
Linalg generalization pattern.
Definition: Transforms.h:881
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:117
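A sketch with illustrative static parameters, passing index attributes as OpFoldResults:

  // Compute two tile sizes <= 64, both divisible by 8, together with trip
  // counts that exactly cover dimension 0 of op.
  FailureOr<MultiSizeSpecification> spec =
      computeMultiTileSizes(builder, op, /*dimension=*/0,
                            /*targetSize=*/builder.getIndexAttr(64),
                            /*divisor=*/builder.getIndexAttr(8));
  if (failed(spec))
    return failure();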
LinalgVectorLoweringOptions & setVectorTransferToSCFOptions(VectorTransferToSCFOptions options)
Definition: Transforms.h:1069
Linalg tile and fuse tensor ops pattern.
Definition: Transforms.h:845
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
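For example, at an illustrative static split point:

  // Yields two ops covering [0, 32) and [32, dim size) along dimension 0.
  auto [firstPart, secondPart] = splitOp(
      rewriter, op, /*dimension=*/0, /*splitPoint=*/rewriter.getIndexAttr(32));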
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Patterns that are used to bubble up an extract slice op above its producing linalg op.
Perform standalone tiling of a single LinalgOp by tileSizes.
Definition: Transforms.h:198
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:271
LinalgTilingOptions & setPeeledLoops(ArrayRef< int64_t > loops)
Definition: Transforms.h:697
GeneralizePadOpPattern(MLIRContext *context, OptimizeCopyFn optimizeCopyFn=nullptr, PatternBenefit benefit=1)
Definition: Transforms.h:1203
Structure to control the behavior of vector transform patterns.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class represents an operand of an operation.
Definition: Value.h:251
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
Definition: Transforms.h:69
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns, const LinalgTransformationFilter &filter=LinalgTransformationFilter())
Linalg generalization patterns.
Vector lowering options control how ops are lowered down to 1-D and scf.for form. ...
Definition: Transforms.h:1001
LogicalResult matchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const override
Definition: Transforms.h:799
LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp)
Emit a suitable vector form for a Linalg op with fully static shape.
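A minimal sketch, assuming linalgOp has fully static shapes:

  // Rewrites the body into vector ops; fails when preconditions do not hold.
  if (failed(vectorize(rewriter, linalgOp)))
    return failure();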
void populateSplitReductionPattern(RewritePatternSet &patterns, const ControlSplitReductionFn &controlSplitReductionFn, const LinalgTransformationFilter &f=LinalgTransformationFilter(), bool useAlloc=false)
Patterns to apply splitReduction below.
void hoistRedundantVectorTransfersOnTensor(func::FuncOp func)
Same behavior as hoistRedundantVectorTransfers but works on tensors instead of buffers.
Definition: Hoisting.cpp:349
void populateElementwiseOpsFusionPatterns(RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion)
Patterns for fusing linalg operations on tensors.
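A sketch wiring these patterns into a greedy rewrite with a ControlFusionFn that always fuses; funcOp and ctx are assumed to come from the enclosing pass, and applyPatternsAndFoldGreedily from mlir/Transforms/GreedyPatternRewriteDriver.h:

  RewritePatternSet fusionPatterns(ctx);
  populateElementwiseOpsFusionPatterns(
      fusionPatterns, [](OpOperand *fusedOperand) { return true; });
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));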
Linalg vectorization patterns.
Definition: Transforms.h:957
Filterless version of the above.
Definition: Transforms.h:1429
LogicalResult applyStagedPatterns(Operation *op, ArrayRef< FrozenRewritePatternSet > stage1Patterns, const FrozenRewritePatternSet &stage2Patterns, function_ref< LogicalResult(Operation *)> stage3Lambda=nullptr)
Helper function to allow applying rewrite patterns, interleaved with more global transformations, in a staged fashion:
Definition: Transforms.cpp:598
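A hedged sketch of the staging, with the pattern-set contents elided (stage1A and cleanup are assumed RewritePatternSets):

  SmallVector<FrozenRewritePatternSet> stage1;
  stage1.emplace_back(std::move(stage1A));            // stage 1: applied once, in order
  FrozenRewritePatternSet stage2(std::move(cleanup)); // stage 2: applied to convergence
  if (failed(applyStagedPatterns(op, stage1, stage2)))
    return failure();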
FailureOr< scf::ForeachThreadOp > findTopLevelForeachThreadOp(Operation *target)
Finds the top level scf::ForeachThreadOp of given target.
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:282
void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res, ArrayRef< int64_t > peeledLoops, LinalgTilingLoopType loopType)
Peel the loops of a TiledLinalgOp.
Definition: Transforms.cpp:341
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
std::function< std::pair< int64_t, unsigned >(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:1373
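A sketch of a control function together with a direct splitReduction call (ratio and insertion index are illustrative; rewriter is an assumed PatternRewriter):

  // Return {ratio, index}: split the reduction by a factor of 4 and insert
  // the new parallel dimension at position 0 of the intermediate shape.
  ControlSplitReductionFn control = [](LinalgOp op) {
    return std::pair<int64_t, unsigned>(4, 0);
  };
  FailureOr<SplitReductionResult> split = splitReduction(rewriter, op, control);
  if (failed(split))
    return failure();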
SmallVector< Operation *, 8 > loops
Definition: Transforms.h:200
LinalgTilingOptions & setTileSizeComputationFunction(TileSizeComputationFunction fun)
Definition: Transforms.h:640
FailureOr< ForeachThreadTilingResult > tileToForeachThreadOp(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< int64_t > threadDimMapping={})
Definition: Tiling.cpp:360
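A sketch distributing an op over an illustrative 8x4 thread grid:

  SmallVector<OpFoldResult> numThreads = {builder.getIndexAttr(8),
                                          builder.getIndexAttr(4)};
  FailureOr<ForeachThreadTilingResult> distributed =
      tileToForeachThreadOp(builder, op, numThreads);
  if (failed(distributed))
    return failure();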
FailureOr< LinalgLoops > linalgOpToParallelLoops(PatternRewriter &rewriter, LinalgOp linalgOp)
Emit a loop nest of scf.parallel with the proper body for linalgOp.
Definition: Loops.cpp:378
RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx)
Canonicalization patterns relevant to apply after tiling patterns.
Definition: Tiling.cpp:702
Options that allow distribution of loops generated in Linalg transforms to processors while generatin...
Definition: Utils.h:370
This class helps build Operations.
Definition: Builders.h:196
FailureOr< ForeachThreadTilingResult > tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< int64_t > threadDimMapping={})
Same as tileToForeachThreadOp, but calculate the number of threads required using the given tileSizes...
Definition: Tiling.cpp:369
A description of a multi-size tiling comprising tile sizes and numbers of tiles, expressed as Values ...
Definition: Transforms.h:459
LinalgEnablingOptions & enableHoistRedundantVectorTransfersOnTensor(bool val=true)
Definition: Transforms.h:993
LinalgTilingAndFusionOptions & setTileSizes(ArrayRef< int64_t > ts)
Definition: Transforms.h:617
MLIRContext * getContext() const
When lowering an N-d vector transfer op to an (N-1)-d vector transfer op, a temporary buffer is creat...
Definition: VectorToSCF.h:52
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Definition: Transforms.h:647
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:370
The main mechanism for performing data layout queries.
vector::VectorTransformsOptions vectorTransformOptions
Configure late vector transformations.
Definition: Transforms.h:1074
LinalgPaddingOptions & setTransposePaddings(ArrayRef< SmallVector< int64_t >> tp)
Definition: Transforms.h:608