MLIR  14.0.0git
Transforms.cpp
Go to the documentation of this file.
1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
25 #include "mlir/IR/AffineExpr.h"
26 #include "mlir/IR/Matchers.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LLVM.h"
30 #include "llvm/ADT/ScopeExit.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include "llvm/Support/raw_ostream.h"
34 #include <type_traits>
35 #include <utility>
36 
// Tag for this file's LLVM_DEBUG output (selectable with
// -debug-only=linalg-transforms).
37 #define DEBUG_TYPE "linalg-transforms"
38 
39 using namespace mlir;
40 using namespace mlir::linalg;
41 
// Debug stream prefixed with "[linalg-transforms]: "; used inside LLVM_DEBUG.
42 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
43 
44 //===----------------------------------------------------------------------===//
45 // Transformations exposed as rewrite patterns.
46 //===----------------------------------------------------------------------===//
47 // Marker used as attribute name in generated Linalg rewriting transformations.
// The attribute name itself is the string literal below.
49  "__internal_linalg_transform__";
50 
/// Builds a filter that accepts an op whose StringAttr marker equals one of
/// `matchDisjunction`; `replacement`, if present, is the marker written back by
/// replaceLinalgTransformationFilter. Ops without a marker are accepted only
/// when `matchDisjunction` is empty or `matchByDefault` is set (initially
/// false).
52  ArrayRef<StringAttr> matchDisjunction, Optional<StringAttr> replacement)
53  : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
54  replacement(replacement), matchByDefault(false) {}
55 
/// Same marker-matching rules as the StringAttr-only constructor, plus an
/// extra predicate `f` that must succeed on the op (checked first by
/// checkAndNotify).
57  const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction,
58  Optional<StringAttr> replacement)
59  : filters(),
60  matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
61  replacement(replacement), matchByDefault(false) {
// An empty `f` registers no predicate.
62  if (f)
63  filters.push_back(f);
64 }
65 
/// Returns success if every registered filter function succeeds on `op` and
/// the op's marker attribute is accepted by `matchDisjunction`. Otherwise
/// returns failure; marker mismatches are also reported to `rewriter` through
/// notifyMatchFailure.
67  PatternRewriter &rewriter, Operation *op) const {
// Any failing predicate rejects the op outright (no match-failure message).
68  if (llvm::any_of(filters,
69  [&](const FilterFunction &f) { return failed(f(op)); }))
70  return failure();
71 
72  auto attr = op->template getAttrOfType<StringAttr>(
74 
75  if (!attr) {
76  // 1. No marker: accept if matchDisjunction is empty or matchByDefault is set.
77  if (matchDisjunction.empty() || matchByDefault)
78  return success();
79 
80  // 2. No marker, but a marker from matchDisjunction was required.
81  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
82  diag << " does not have any filter from list: ";
83  interleaveComma(matchDisjunction, diag);
84  });
85  }
86 
87  // 3. Accept if the marker equals one of matchDisjunction.
88  for (auto filter : matchDisjunction)
89  if (attr.getValue() == filter)
90  return success();
91 
92  // 4. The marker matches none of matchDisjunction.
93  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
94  diag << " does not have any filter from list: ";
95  interleaveComma(matchDisjunction, diag);
96  });
97 }
98 
/// If a `replacement` marker is set, writes it onto `op` as its marker
/// attribute (the setter call itself is on a line not shown in this listing);
/// otherwise removes the marker attribute from `op`.
101  Operation *op) const {
102  if (replacement.hasValue())
104  replacement.getValue());
105  else
106  op->removeAttr(
108 }
109 
/// Returns true iff a `replacement` marker is set and `op` already carries
/// exactly that marker as a StringAttr.
111  Operation *op) const {
112  if (!replacement)
113  return false;
115  .dyn_cast<StringAttr>();
116  return attr && attr == replacement.getValue();
117 }
118 
/// Sets fixed tile sizes `ts`. Each time the stored function is invoked, it
/// materializes every size as an `arith.constant` of index type. Asserts that
/// no tile-size function was set before.
121  assert(!tileSizeComputationFunction && "tile sizes already set");
// Copy the sizes so the closure does not reference the caller's storage.
122  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
123  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
// The guard restores the caller's insertion point. The insertion-point setter
// is on a line not shown here; its argument is the entry block of the
// enclosing FuncOp, presumably so the constants dominate all uses.
124  OpBuilder::InsertionGuard guard(b);
126  &op->getParentOfType<FuncOp>().getBody().front());
127  return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
128  Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
129  return v;
130  }));
131  };
132  return *this;
133 }
134 
/// Installs a tile-size function that tiles every loop whose range is not a
/// known constant by 1 and leaves constant-range loops untiled (tile size 0).
/// Asserts that no tile-size function was set before.
136  assert(!tileSizeComputationFunction && "tile sizes already set");
137  tileSizeComputationFunction = [](OpBuilder &b, Operation *op) {
138  SmallVector<Value, 4> tileSizes;
139  auto linalgOp = dyn_cast<LinalgOp>(op);
// Non-Linalg ops get no tile sizes.
140  if (!linalgOp)
141  return tileSizes;
142  Location loc = linalgOp.getLoc();
143  auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
144  AffineMap map = linalgOp.getShapesToLoopsMap();
// Without a shapes-to-loops map the loop ranges cannot be derived.
145  if (!map)
146  return tileSizes;
// Loop ranges, obtained by applying the shapes-to-loops map to all operand
// dimension sizes.
147  auto shapeSizes = applyMapToValues(b, loc, map, allShapeSizes);
148  // If the shape size is dynamic, tile by 1. Otherwise, do not tile (tile
149  // size 0).
// "Dynamic" here means the range does not fold to a constant.
150  for (Value shapeSize : shapeSizes)
151  tileSizes.push_back(getConstantIntValue(shapeSize).hasValue()
152  ? b.create<arith::ConstantIndexOp>(loc, 0)
153  : b.create<arith::ConstantIndexOp>(loc, 1));
154  return tileSizes;
155  };
156  return *this;
157 }
158 
159 /// Helper function that tries to pad `opOperand`. Exit early for scalar
160 /// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
161 /// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
162 /// has a static shape. Set `result` to the result of the created tensor::PadOp
163 /// and return success if the operand either has been padded to a static
164 /// shape or already had a static shape and failure otherwise.
/// `result` is only assigned on the path that creates the pad op.
166  OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
167  const PaddingValueComputationFunction &paddingFunc,
168  const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
169  // Get the shape of the operand and check if it has a dynamic shape. Only
170  // return failure if the operand is not a scalar and has a dynamic shape.
171  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
172  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);
173 
174  // Cannot pad scalar operands.
175  if (shape.empty())
176  return success();
177 
178  // Cannot pad if the padding value is unknown.
// `failure(hasDynamicShape)` only fails for dynamically shaped operands;
// statically shaped operands are left unpadded and still succeed.
179  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
180  if (failed(paddingValue))
181  return failure(hasDynamicShape);
182 
183  // Cannot construct a static bounding box if the operand is not defined by an
184  // ExtractSliceOp.
185  auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
186  if (!sliceOp)
187  return failure(hasDynamicShape);
188 
189  // Compute the dropped dimensions if `sliceOp` is rank-reducing.
190  llvm::SmallDenseSet<unsigned> droppedDims = sliceOp.getDroppedDims();
191 
192  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
193  SmallVector<int64_t> staticSizes;
194  staticSizes.reserve(shape.size());
195  auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
196  for (const auto &en : enumerate(shapedOp.getMixedSizes())) {
197  // Skip dropped dimensions.
198  if (droppedDims.contains(en.index()))
199  continue;
200  // If the size is an attribute add it directly to `staticSizes`.
201  if (en.value().is<Attribute>()) {
// NOTE(review): `dyn_cast` followed by an unchecked `getInt()`; a `cast`
// would state the IntegerAttr assumption explicitly.
202  staticSizes.push_back(
203  en.value().get<Attribute>().dyn_cast<IntegerAttr>().getInt());
204  continue;
205  }
206  // Otherwise, try to compute a constant upper bound for the size value.
207  FailureOr<int64_t> upperBound =
208  getConstantUpperBoundForIndex(en.value().get<Value>());
209  if (failed(upperBound)) {
// NOTE(review): this debug message has no trailing "\n".
210  LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
211  return failure();
212  }
213  staticSizes.push_back(upperBound.getValue());
214  }
215  assert(staticSizes.size() == shape.size() &&
216  "expect the dynamic and static ranks to match");
217 
218  // Pad the operand to the bounding box defined by `staticSizes`.
219  auto staticTensorType = RankedTensorType::get(
220  staticSizes, getElementTypeOrSelf(opOperand->get()));
// Without `nofoldFunc`, the pad is created with nofold = false.
221  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
222  result =
223  makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
224  opOperand->get(), paddingValue.getValue(), nofold);
225  return success();
226 }
227 
/// Pads the input and output operands of `opToPad` to static bounding boxes
/// where possible, clones `opToPad` into `paddedOp` over the padded operands,
/// and returns tensor.extract_slice ops that cut the padded results back to the
/// original (reified) result shapes. Fails if an operand that needs padding
/// cannot be bounded statically, or if the result shapes cannot be reified.
/// NOTE(review): on failure, pad ops already created for earlier operands are
/// not erased.
230  const PaddingValueComputationFunction &paddingFunc,
231  const PaddingNoFoldComputationFunction &nofoldFunc,
232  LinalgOp &paddedOp) {
233  Location loc = opToPad->getLoc();
234 
235  // TODO: there are cases where we may still want to pad to larger sizes.
236  assert(opToPad.hasTensorSemantics() &&
237  "expected operation to have tensor semantics");
238 
240  // Set IP after op because we also take the dims of the original output.
241  b.setInsertionPointAfter(opToPad);
242  // Make a copy of the shaped operands and update it.
243  SmallVector<Value> newOperands;
244  newOperands.reserve(opToPad.getNumInputsAndOutputs());
245  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
246  Value paddedOperand;
247  // If padding was requested but the shape cannot be bounded statically then
248  // the pattern fails to apply.
250  b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
251  return failure();
252  newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
253  }
254 
255  SmallVector<SmallVector<Value>> reifiedResultShapes;
256  if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation())
257  .reifyResultShapes(b, reifiedResultShapes)))
258  return failure();
259  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
260  "expected same number of results");
261 
262  // Clone `opToPad` to operate on the statically padded shapes.
// The result types are the types of the (possibly padded) output operands.
263  auto resultTensorTypes =
264  ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
265  paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands);
266 
267  // Recover the slice out of the new static results. This keeps the original
268  // linalg op around because it uses the dims of the original results.
269  SmallVector<Value> paddedSubviewResults;
270  paddedSubviewResults.reserve(opToPad->getNumResults());
271  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
272  Value paddedResult = en.value();
273  int64_t resultNumber = en.index();
274  int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank();
275  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
277  for (Value v : reifiedResultShapes[resultNumber])
278  sizes.push_back(getAsOpFoldResult(v));
279  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
280  paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>(
281  loc, paddedResult, offsets, sizes, strides));
282  }
283  return paddedSubviewResults;
284 }
285 
286 /// Try to peel a loop `op` and return the new result.
/// For an scf.for, returns the results of the peeled partial iteration when
/// peeling succeeds and the loop's own results otherwise. Any other op is left
/// alone and its own results are returned.
287 // TODO: Add support for scf.parallel and affine.for loops.
290  .Case<scf::ForOp>([&](scf::ForOp forOp) {
291  scf::ForOp partialIteration;
292  if (succeeded(scf::peelAndCanonicalizeForLoop(rewriter, forOp,
293  partialIteration)))
294  return partialIteration->getResults();
295  assert(!partialIteration && "expected that loop was not peeled");
296  return forOp->getResults();
297  })
298  .Default([&](Operation *op) { return op->getResults(); });
299 }
300 
301 /// Try to peel a TiledLoopOp and return the new result.
/// Peels dimension `idx` of `tiledLoop`. Returns the results of the loop
/// produced by peelAndCanonicalizeTiledLoop when peeling succeeds, and the
/// original loop's results otherwise.
303  TiledLoopOp tiledLoop, int64_t idx) {
304  assert(idx < static_cast<int64_t>(tiledLoop.iterator_types().size()) &&
305  "requested peeling of non-existing loop");
306  TiledLoopOp result;
307  if (succeeded(peelAndCanonicalizeTiledLoop(rewriter, tiledLoop, idx, result)))
308  return result->getResults();
309  assert(!result && "expected that loop was not peeled");
310  return tiledLoop->getResults();
311 }
312 
313 /// Peel loops after tiling.
/// Each entry of `peeledLoops` indexes into `res.loops`. For TiledLoops, every
/// entry of `res.loops` is expected to be the same TiledLoopOp, and the index
/// selects the dimension to peel. If `res.tensorResults` were exactly the
/// results of the peeled loop, they are replaced by the post-peeling results.
315  ArrayRef<int64_t> peeledLoops,
316  LinalgTilingLoopType loopType) {
317  for (int64_t loop : peeledLoops) {
318  assert(loop < static_cast<int64_t>(res.loops.size()) &&
319  "requested peeling of non-existing loop");
320  SmallVector<Value, 4> loopResults;
321  Operation *loopOp = res.loops[loop];
322  if (loopType == LinalgTilingLoopType::TiledLoops) {
323  assert(llvm::all_of(
324  res.loops,
325  [&](Operation *op) { return op == res.loops.front(); }) &&
326  "expected that all loop ops are the same TiledLoopOp");
327  auto tiledLoopOp = dyn_cast<TiledLoopOp>(loopOp);
328  assert(tiledLoopOp && "expected TiledLoopOp");
329  loopResults = peelLoop(rewriter, tiledLoopOp, loop);
330  } else {
331  loopResults = peelLoop(rewriter, loopOp);
332  }
333 
334  // The result of the loop nest may change with peeling.
335  if (res.tensorResults.size() == loopOp->getNumResults() &&
336  std::equal(res.tensorResults.begin(), res.tensorResults.end(),
337  loopOp->getResults().begin()))
338  res.tensorResults = loopResults;
339  }
340 }
341 
/// Returns the values that replace the original op after tiling: the results
/// of `tiledOp.loops.front()` (presumably the outermost loop), or of the tiled
/// op itself when no loops were generated.
343  if (tiledOp.loops.empty())
344  return tiledOp.op.getOperation()->getResults();
345  return tiledOp.loops.front()->getResults();
346 }
347 
/// Returns the values that replace the original op after tile-and-fuse: the
/// results of `fusedLoops.front()` (presumably the outermost fused loop), or of
/// the tiled op itself when no loops were fused.
348 static ValueRange
350  if (tiledAndFusedOp.fusedLoops.empty())
351  return tiledAndFusedOp.op.getOperation()->getResults();
352  return tiledAndFusedOp.fusedLoops.front()->getResults();
353 }
354 
/// Stores the dependence graph, the tiling and fusion options, and three
/// markers: `f` (filter checked on the root op), `fusedOpMarker` (applied to
/// the fused producers) and `originalOpMarker` (applied to the original ops
/// after the rewrite).
356  StringRef opName, MLIRContext *context,
357  const LinalgDependenceGraph &dependenceGraph,
358  LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
360  LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
361  : RewritePattern(opName, benefit, context, {}),
362  dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
363  fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
364  fusedOpMarker(std::move(fusedOpMarker)),
365  originalOpMarker(std::move(originalOpMarker)) {}
366 
/// Tiles the root LinalgOp `op` and fuses into the tile loop nest the LinalgOp
/// producers that appear earlier in the same block and that the dependence
/// graph links to operands listed in `fusionOptions.indicesToFuse`. Loops that
/// were not fused are then tiled separately. Finally, uses of `op` are
/// redirected to the tiled-and-fused results and the markers are updated.
368  Operation *op, PatternRewriter &rewriter) const {
369  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
370  // TODO: remove hasIndexSemantics check once index ops are supported.
371  if (!linalgOp || linalgOp.hasIndexSemantics())
372  return failure();
373  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
374  return failure();
375 
376  DenseSet<Operation *> producers;
377  producers.insert(linalgOp);
378  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
379  Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
380  // When looking at dependences into, indexingOp is always OpOperand. We
381  // could assert, but continue if this is not the case.
382  if (!operandNumber)
383  continue;
384  if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
385  continue;
386  if (isa<LinalgOp>(dependence.getDependentOp()))
387  producers.insert(dependence.getDependentOp());
388  }
389 
// Collect the producers in block order, followed by the root op itself.
390  SmallVector<LinalgOp, 1> fusionOps;
391  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
392  ++it) {
393  auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
394  if (producerLinalgOp && producers.count(producerLinalgOp))
395  fusionOps.push_back(producerLinalgOp);
396  }
397  fusionOps.push_back(linalgOp);
398 
399  SmallVector<Value, 4> tileSizes =
400  tilingOptions.tileSizeComputationFunction(rewriter, op);
401  LinalgTilingOptions instanceTilingOptions = tilingOptions;
402  instanceTilingOptions.setTileSizes(tileSizes);
404  rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
405  if (!tiledAndFusedOps)
406  return failure();
407 
408  // Tile the unfused loops;
// Loops that were already fused get tile size 0 so the second tiling skips
// them.
409  SmallVector<Value, 4> unfusedLoopTileSizes;
410  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
411  for (const auto &tileSize : enumerate(tileSizes)) {
412  if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
413  unfusedLoopTileSizes.push_back(zero);
414  else
415  unfusedLoopTileSizes.push_back(tileSize.value());
416  }
417  // Tile the loop only if there is a non-zero tile size.
// A size not defined by a constant op is conservatively treated as non-zero.
418  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
419  unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
420  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
421  if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
422  return cst.value() != 0;
423  return true;
424  })) {
425  LinalgTilingOptions unfusedTilingOptions = tilingOptions;
426  unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
427  FailureOr<TiledLinalgOp> unfusedTiledOp =
428  tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
// NOTE(review): returning failure here happens after the tile-and-fuse step
// (and the `zero` constant) have already created IR, which the pattern
// contract does not allow.
429  if (failed(unfusedTiledOp))
430  return failure();
431  rewriter.replaceOp(tiledAndFusedOps->op,
432  getTiledOpResult(unfusedTiledOp.getValue()));
433  tiledAndFusedOps->op = unfusedTiledOp->op;
434  }
// NOTE(review): this RAUW bypasses the rewriter, so the pattern driver is not
// notified about the replaced uses.
435  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));
436 
437  filter.replaceLinalgTransformationFilter(rewriter,
438  tiledAndFusedOps->op.getOperation());
439  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
440  fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
441  fusedOp.getOperation());
442  }
443  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
444  originalOpMarker.replaceLinalgTransformationFilter(
445  rewriter, origProducerOp.getOperation());
446  }
447  rewriter.updateRootInPlace(op, [&]() {
448  originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
449  });
450  return success();
451 }
452 
453 /// Linalg tiling pattern.
/// Matches any LinalgOp accepted by `f`.
457  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
458  filter(std::move(f)), options(std::move(options)) {}
459 
/// Overload that narrows `f` with `addOpNameFilter(opName)`.
461  StringRef opName, MLIRContext *context, LinalgTilingOptions options,
463  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
464  filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
465 
/// Tiles `op` according to `options`, updates the tiled op's marker via the
/// filter, peels the requested loops, then replaces `op` with the tensor
/// results (or erases it when there are none). Returns the tiling result.
468  LinalgOp op, PatternRewriter &rewriter) const {
469  if (failed(filter.checkAndNotify(rewriter, op)))
470  return failure();
471 
472  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
473  if (failed(res))
474  return failure();
475 
476  // Clear filter to stop recursive pattern application.
477  // This must be done here to properly propagate to peeling branches.
478  filter.replaceLinalgTransformationFilter(rewriter, res->op);
479 
480  // Peel the loops of the TiledLinalgOp.
481  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
482 
// Nothing to forward to users when tiling produced no tensor results.
483  if (res->tensorResults.empty())
484  rewriter.eraseOp(op);
485  else
486  rewriter.replaceOp(op, res->tensorResults);
487 
488  return res;
489 }
490 
491 /// Linalg padding pattern.
/// Matches any LinalgOp accepted by `f`.
493  MLIRContext *context, LinalgPaddingOptions options,
495  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
496  filter(std::move(f)), options(std::move(options)) {}
497 
/// Overload that narrows `f` with `addOpNameFilter(opName)`.
499  StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
501  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
502  filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
503 
/// Pads the operands of a tensor-semantics `linalgOp` and optionally hoists the
/// resulting tensor.pad ops by the per-operand depths returned from
/// `paddingHoistComputationFunction` (depth 0 means no hoisting). It then
/// replaces `linalgOp` with the values produced by the padding step and
/// returns the padded op.
506  LinalgOp linalgOp, PatternRewriter &rewriter) const {
507  if (!linalgOp.hasTensorSemantics())
508  return failure();
509  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
510  return failure();
511 
512  // Pad the operation.
513  LinalgOp paddedOp;
515  rewriter, linalgOp, options.paddingValueComputationFunction,
516  options.paddingNoFoldComputationFunction, paddedOp);
517  if (failed(newResults))
518  return failure();
519 
520  // Compute the desired hoisting depths.
521  SmallVector<int64_t> depths;
522  if (options.paddingHoistComputationFunction) {
523  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
524  depths.push_back(options.paddingHoistComputationFunction(*opOperand));
525  }
526 
527  // Hoist the padding.
// `depths` is indexed like the operands of `linalgOp`; this relies on
// `paddedOp` keeping the same operand order.
528  for (const auto &en : enumerate(depths)) {
529  OpOperand &opOperand = paddedOp->getOpOperand(en.index());
530  auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
531  if (!padOp || en.value() == 0)
532  continue;
533  tensor::PadOp hoistedOp;
534  SmallVector<GenericOp> transposeOps;
// NOTE(review): called without a null check; this function must be set
// whenever paddingHoistComputationFunction is.
535  SmallVector<int64_t> transposeVector =
536  options.paddingTransposeComputationFunction(opOperand);
537 
539  padOp, en.value(), transposeVector, hoistedOp, transposeOps);
540  if (failed(newResult))
541  continue;
542  rewriter.replaceOp(padOp, newResult.getValue());
543 
544  // Do not apply hoist padding to the newly introduced transpose operations.
545  for (GenericOp transposeOp : transposeOps)
546  filter.replaceLinalgTransformationFilter(rewriter, transposeOp);
547  }
548 
549  // Replace the original operation to pad.
550  rewriter.replaceOp(linalgOp, newResults.getValue());
551  filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
552 
553  return paddedOp;
554 }
555 
556 /// Linalg tile and fuse tensor ops pattern.
/// This overload is rooted on any op type (MatchAnyOpTypeTag); non-Linalg ops
/// are rejected in matchAndRewrite.
561  PatternBenefit benefit)
562  : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
563  filter(std::move(f)), options(std::move(options)) {}
564 
/// Overload rooted on ops named `opName`.
569  PatternBenefit benefit)
570  : RewritePattern(opName, benefit, context), filter(std::move(f)),
571  options(std::move(options)) {}
572 
574  Operation *op, PatternRewriter &rewriter) const {
575  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
576  if (!rootOp)
577  return failure();
578  if (failed(filter.checkAndNotify(rewriter, op)))
579  return failure();
580 
581  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
582  if (options.tileSizes.size() < rootOp.getNumLoops())
583  return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
584 
585  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
586  if (!options.tileInterchange.empty() &&
587  options.tileInterchange.size() != options.tileSizes.size())
588  return rewriter.notifyMatchFailure(
589  op, "expect the number of tile sizes and interchange dims to match");
590 
591  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
592  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
593  options.tileSizes.begin() +
594  rootOp.getNumLoops());
595  SmallVector<int64_t> rootInterchange =
596  options.tileInterchange.empty()
597  ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
598  : SmallVector<int64_t>(options.tileInterchange.begin(),
599  options.tileInterchange.begin() +
600  rootOp.getNumLoops());
601 
602  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
603  // It has to be a permutation since the tiling cannot tile the same loop
604  // dimension multiple times.
605  if (!isPermutation(rootInterchange))
606  return rewriter.notifyMatchFailure(
607  op, "expect the tile interchange permutes the root loops");
608 
609  // Tile `rootOp` and fuse its producers.
611  rewriter, rootOp, rootTileSizes, rootInterchange);
612  if (failed(tileLoopNest))
613  return rewriter.notifyMatchFailure(
614  op, "tileConsumerAndFuseProducers failed unexpectedly");
615 
616  // Replace all uses of the tiled loop operation.
617  rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
618 
619  // Apply the filter if specified.
620  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
621  filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
622  return failure();
623 }
624 
625 /// Linalg generic interchange pattern.
/// Stores the loop permutation `interchangeVector` to apply to matched
/// GenericOps.
627  MLIRContext *context, ArrayRef<unsigned> interchangeVector,
629  : OpRewritePattern(context, benefit), filter(std::move(f)),
630  interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
631 
/// Permutes the loops of `genericOp` by `interchangeVector` via
/// interchangeGenericOp, updates the marker, and returns the transformed op.
634  GenericOp genericOp, PatternRewriter &rewriter) const {
635  if (failed(filter.checkAndNotify(rewriter, genericOp)))
636  return failure();
637 
638  FailureOr<GenericOp> transformedOp =
639  interchangeGenericOp(rewriter, genericOp, interchangeVector);
640  if (failed(transformedOp))
641  return failure();
642 
643  // New filter if specified.
// NOTE(review): the marker is updated on `genericOp`, not `*transformedOp`;
// this assumes the interchange rewrites the op in place. Confirm.
644  filter.replaceLinalgTransformationFilter(rewriter, genericOp);
645  return transformedOp;
646 }
647 
648 /// Linalg generalization pattern.
/// Matches any LinalgOp accepted by `f`.
651  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
652  filter(std::move(f)) {}
653 
/// Overload that narrows `f` with `addOpNameFilter(opName)`.
655  StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
656  PatternBenefit benefit)
657  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
658  filter(f.addOpNameFilter(opName)) {}
659 
/// Rewrites `linalgOp` into a GenericOp via generalizeNamedOp and updates the
/// marker on the new op, which is returned.
662  LinalgOp linalgOp, PatternRewriter &rewriter) const {
663  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
664  return failure();
665  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
666  if (failed(genericOp))
667  return failure();
668  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
669  return genericOp;
670 }
671 
/// This overload is rooted on any op type (MatchAnyOpTypeTag).
675  : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
676  filter(std::move(f)), options(std::move(options)) {}
677 
/// Overload rooted on ops named `opName`.
679  StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
681  : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
682  options(std::move(options)) {}
683 
/// Promotes the subviews used by `op` according to `options`, after checking
/// promoteSubviewsPrecondition. The change is wrapped in a root update of
/// `op`, and the marker is updated afterwards.
685  Operation *op, PatternRewriter &rewriter) const {
686  if (failed(filter.checkAndNotify(rewriter, op)))
687  return failure();
688  if (failed(promoteSubviewsPrecondition(op, options)))
689  return failure();
690 
691  // TODO: We cannot use root update here. This pattern is creating other ops,
692  // so if the promotion fails, those need to be cleaned up, which doesn't seem
693  // to be happening here. So to fail properly, we should be cloning the op and
694  // deleting the previous op. This needs more investigation.
695  rewriter.startRootUpdate(op);
696  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
697  if (!promotedOp) {
698  rewriter.cancelRootUpdate(op);
// NOTE(review): this emits an error diagnostic from inside a pattern; the
// returned diagnostic converts to failure.
699  return op->emitError("subview promotion failed");
700  }
701  rewriter.finalizeRootUpdate(op);
702  filter.replaceLinalgTransformationFilter(rewriter, op);
703  return success();
704 }
705 
/// Matches any LinalgOp accepted by `f`.
709  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
710  filter(std::move(f)) {}
711 
/// Overload that narrows `f` with `addOpNameFilter(opName)`.
/// NOTE(review): `options` is accepted but not stored by the initializer list
/// visible here.
713  StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
715  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
716  filter(f.addOpNameFilter(opName)) {}
717 
/// Vectorizes `linalgOp` via `vectorize` if it passes the filter. The marker is
/// not updated here.
719  LinalgOp linalgOp, PatternRewriter &rewriter) const {
720  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
721  return failure();
722  return vectorize(rewriter, linalgOp);
723 }
724 
/// For each pattern set in `stage1Patterns`, in order:
///   1. greedily applies that set to `op`;
///   2. greedily applies `stage2Patterns`;
///   3. runs `stage3Lambda` (if provided).
/// Returns failure as soon as a greedy application does not converge or
/// `stage3Lambda` fails.
726  Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
727  const FrozenRewritePatternSet &stage2Patterns,
728  function_ref<LogicalResult(Operation *)> stage3Lambda) {
// `iteration` is only used inside LLVM_DEBUG; the void cast avoids
// unused-variable warnings when debug output is compiled out.
729  unsigned iteration = 0;
730  (void)iteration;
731  for (const auto &patterns : stage1Patterns) {
// NOTE(review): `iteration` is pre-incremented both here and in the
// "After 1st stage" message, so debug output advances it by 2 per stage-1 set.
732  LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
733  << *op);
// NOTE(review): the two "did not converge" messages have no trailing "\n".
734  if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
735  LLVM_DEBUG(DBGS() << "Underlying first stage rewrite did not converge");
736  return failure();
737  }
738  LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
739  << *op);
740  if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
741  LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge");
742  return failure();
743  }
744  LLVM_DEBUG(DBGS() << "After 2nd stage, iter : " << iteration << "\n"
745  << *op);
746  if (stage3Lambda) {
747  if (failed(stage3Lambda(op)))
748  return failure();
749  LLVM_DEBUG(DBGS() << "After 3rd stage, iter : " << iteration << "\n"
750  << *op);
751  }
752  }
753  return success();
754 }
755 
756 static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
757  return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
758 }
759 
760 /// Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to
761 /// initialize with pad_val) and GenericOp (to copy contents).
764  PatternRewriter &rewriter) const {
765 
766  auto inputShapedType = padOp.source().getType().cast<ShapedType>();
767  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
768 
769  // Bail on non-static shapes.
770  if (!inputShapedType.hasStaticShape())
771  return failure();
772  if (!resultShapedType.hasStaticShape())
773  return failure();
774 
775  // Only support padding with a constant for now, i.e. either:
776  // 1. A BBarg from a different block.
777  // 2. A value defined outside of the current block.
778  Block &block = padOp.region().front();
779  auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
780  Value padValue = yieldOp.value();
781  Operation *definingOp = padValue.getDefiningOp();
782  if (definingOp && definingOp->getBlock() == &block)
783  return failure();
784  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
785  return failure();
786 
787  // Create tensor with the padded shape
788  Location loc = padOp.getLoc();
789  SmallVector<Value> indices(resultShapedType.getRank(),
790  rewriter.create<arith::ConstantIndexOp>(loc, 0));
791  Value initTensor = rewriter.create<InitTensorOp>(
792  loc, resultShapedType.getShape(), resultShapedType.getElementType());
793 
794  // Initialize tensor with the pad value
795  Value tmpTensor =
796  rewriter.create<linalg::FillOp>(loc, padValue, initTensor).result();
797 
798  // Copy original contents into new tensor
799  // Uses linalg.generic, but could be done with tensor.insert_slice
800  SmallVector<AffineExpr, 4> outputExprs;
801  for (unsigned i = 0; i < resultShapedType.getRank(); ++i) {
802  outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) +
803  padOp.static_low()[i].cast<IntegerAttr>().getInt());
804  }
805 
806  SmallVector<AffineMap, 2> transferMaps = {
807  rewriter.getMultiDimIdentityMap(inputShapedType.getRank()),
808  AffineMap::get(resultShapedType.getRank(),
809  /*symbolCount=*/0, outputExprs, rewriter.getContext())};
810 
811  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
812  padOp, resultShapedType, padOp.source(), tmpTensor, transferMaps,
813  getNParallelLoopsAttrs(resultShapedType.getRank()),
814  [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
815  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
816  });
817 
818  return success();
819 }
820 
821 /// Fills `dest` with the constant padding value using a FillOp if possible.
822 /// Otherwise, generates a tensor::GenerateOp whose body is the pad region.
/// NOTE(review): `dest` is only used on the FillOp path; the GenerateOp is
/// built from the pad's result type and `dynSizes` instead.
824  PatternRewriter &rewriter, tensor::PadOp padOp, Value dest,
825  const SmallVector<Value> &dynSizes) const {
826  auto padValue = padOp.getConstantPaddingValue();
827  if (padValue)
828  return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
829 
830  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
831  auto generateOp = rewriter.create<tensor::GenerateOp>(
832  padOp.getLoc(), padOp.getResultType(), dynSizes);
833  // Copy region to new op.
835  padOp.region().cloneInto(&generateOp.getRegion(), bvm);
836  return generateOp;
837 }
838 
/// Lowers `padOp` in three steps:
///   1. an InitTensorOp of the padded shape;
///   2. a fill with the padding value (FillOp or GenerateOp);
///   3. a copy of the source, done by `optimizeCopyFn` when it is set and
///      succeeds, and otherwise by a tensor.insert_slice at the low-padding
///      offsets.
841  PatternRewriter &rewriter) const {
842  // Given an OpFoldResult, return an index-typed value.
843  auto getIdxValue = [&](OpFoldResult ofr) {
844  if (auto val = ofr.dyn_cast<Value>())
845  return val;
846  return rewriter
848  padOp.getLoc(), ofr.get<Attribute>().cast<IntegerAttr>().getInt())
849  .getResult();
850  };
851 
852  auto resultType = padOp.getResultType();
853  // Compute size of InitTensorOp. Any combination of static/dynamic is
854  // supported.
// Dynamic result dims are computed as source dim + low pad + high pad.
855  SmallVector<Value> dynSizes;
856  SmallVector<int64_t> staticSizes;
857  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
858  if (resultType.isDynamicDim(dim)) {
859  auto srcSize = rewriter.createOrFold<tensor::DimOp>(padOp.getLoc(),
860  padOp.source(), dim);
861  // Add low and high padding value.
862  auto plusLow = rewriter.createOrFold<arith::AddIOp>(
863  padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
864  auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
865  padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
866  dynSizes.push_back(plusHigh);
867  }
868  staticSizes.push_back(resultType.getDimSize(dim));
869  }
870 
871  // Init tensor and fill it with padding.
872  Value init = rewriter.create<InitTensorOp>(
873  padOp.getLoc(), dynSizes, staticSizes, resultType.getElementType());
874  Value fill = createFillOrGenerateOp(rewriter, padOp, init, dynSizes);
875 
876  // Try to optimize the copy of the source.
// On success `padOp` is not replaced here; presumably `optimizeCopyFn` is
// responsible for that. Verify against its implementations.
877  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
878  return success();
879 
880  // The copy could not be optimized. Generate an InsertSliceOp instead
881  // for copying the PadOp source.
882  auto sourceType = padOp.getSourceType();
883  // Compute size of source of tensor::PadOp.
884  SmallVector<OpFoldResult> srcSizes;
885  for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
886  if (sourceType.isDynamicDim(dim)) {
887  srcSizes.push_back(rewriter.createOrFold<tensor::DimOp>(
888  padOp.getLoc(), padOp.source(), dim));
889  } else {
890  srcSizes.push_back(rewriter.getIndexAttr(sourceType.getDimSize(dim)));
891  }
892  }
893  // Strides of InsertSliceOp are all 1.
894  SmallVector<OpFoldResult> strides(sourceType.getRank(),
895  rewriter.getIndexAttr(1));
896  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
897  padOp, padOp.source(), fill, padOp.getMixedLowPad(), srcSizes, strides);
898 
899  return success();
900 }
901 
903  tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
// Matches an extract_slice whose source is produced by a tensor::PadOp and,
// for unit-stride slices only, replaces the slice with the pad op's tiled
// implementation (obtained through TilingInterface) for the slice's offsets
// and sizes.
904  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
905  if (!padOp)
906  return failure();
907  // Only unit stride supported.
908  if (!sliceOp.hasUnitStride())
909  return failure();
910 
// NOTE(review): the dyn_cast result is not null-checked; if no TilingInterface
// implementation is available for tensor::PadOp the call below is invalid.
// `.front()` also assumes getTiledImplementation returns at least one op.
// Rank-reducing slices are not excluded either -- confirm the tiled result
// type matches sliceOp's result type in that case.
911  TilingInterface tilingInterface =
912  dyn_cast<TilingInterface>(padOp.getOperation());
913  Operation *tiledPadOp =
914  tilingInterface
915  .getTiledImplementation(
916  rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
917  sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
918  .front();
919  // Replace the slice with the tiled pad, i.e. rewrite extract_slice(pad(x))
920  // into pad(extract_slice(x)) (formerly pad_tensor(subtensor(x))).
// NOTE(review): no static-shape or "source is used" check is performed here.
921  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
922  return success();
923 }
924 
925 namespace {
926 // The following are patterns for downscaling convolution ops with size-1
927 // window dimensions.
928 //
929 // Note that we'd eventually want to write such transformations in a generic
930 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
931 // and then turning back to named ops. But for now it's fine to have a few
932 // patterns matching special ops to get started.
933 
934 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
935 /// convolution ops.
///
/// Matches tensor-semantics Conv2DNhwcHwcfOp ops where, for H or W, both the
/// kernel window size and the output size are 1. The input, kernel and output
/// are rank-reduced by dropping that dimension, a Conv1DNwcWcfOp is built on
/// the reduced operands (with the matching stride/dilation entry removed), and
/// its result is inserted back into the original output tensor. If both H and
/// W qualify, only H is dropped.
936 struct DownscaleSizeOneWindowed2DConvolution final
937  : public OpRewritePattern<Conv2DNhwcHwcfOp> {
938  DownscaleSizeOneWindowed2DConvolution(
939  MLIRContext *context,
941  PatternBenefit benefit = 1)
942  : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
943  filter(std::move(f)) {}
944 
945  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
946  PatternRewriter &rewriter) const override {
947  if (failed(filter.checkAndNotify(rewriter, convOp)))
948  return failure();
949  if (convOp.hasBufferSemantics())
950  return failure(); // To be implemented
951 
952  Value input = convOp.inputs().front();
953  Value kernel = convOp.inputs().back();
954  Value output = convOp.outputs().front();
955 
// NOTE(review): these dyn_casts are not null-checked; getShape() below would be
// invalid on a non-ranked-tensor operand. Tensor semantics is checked above,
// which presumably implies ranked tensors -- confirm.
956  auto inputType = input.getType().dyn_cast<RankedTensorType>();
957  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
958  auto outputType = output.getType().dyn_cast<RankedTensorType>();
959 
960  auto kernelShape = kernelType.getShape();
961  auto outputShape = outputType.getShape();
962 
963  // Only handle the case where, for H or W, both the kernel window size and
964  // the output size are 1. Other cases can rely on tiling to reduce to such cases.
// Kernel dims 0/1 are (kh, kw); output dims 1/2 are (oh, ow).
965  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
966  int64_t ohSize = outputShape[1], owSize = outputShape[2];
967  bool removeH = (khSize == 1 && ohSize == 1);
968  bool removeW = (kwSize == 1 && owSize == 1);
969  if (!removeH && !removeW)
970  return failure();
971 
972  // Get new shapes and types for all operands by removing the size-1
973  // dimension.
// H is dim 1 of input/output and dim 0 of the kernel; W is the next dim in
// each. When removeH is true, H is dropped even if removeW is also true.
974  using RTTBuilder = RankedTensorType::Builder;
975  RankedTensorType newInputType =
976  RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
977  RankedTensorType newKernelType =
978  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
979  RankedTensorType newOutputType =
980  RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
981 
982  // Rank-reduce operands.
983  Location loc = convOp.getLoc();
985  rewriter, loc, input, newInputType);
987  rewriter, loc, kernel, newKernelType);
989  rewriter, loc, output, newOutputType);
990 
991  // Rank-reduce strides and dilations too.
// Entry 0 corresponds to H and entry 1 to W.
992  // TODO: dropDim 1-liner helper.
993  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
994  strides.erase(strides.begin() + (removeH ? 0 : 1));
995  auto stridesAttr = rewriter.getI64VectorAttr(strides);
996 
997  auto dilations =
998  llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
999  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1000  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1001 
1002  auto conv1DOp = rewriter.create<linalg::Conv1DNwcWcfOp>(
1003  loc, newOutputType, ValueRange{newInput, newKernel},
1004  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1005 
1006  // Insert back.
// The 1-D result is inserted into the original (higher-rank) output tensor,
// and that value replaces the 2-D convolution.
1008  rewriter, loc, conv1DOp.getResult(0), output);
1009  rewriter.replaceOp(convOp, inserted);
1010 
// Mark the new op with the filter's replacement marker.
1011  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1012  return success();
// NOTE(review): the ';' after this method body is a stray empty declaration.
1013  };
1014 
1015 private:
1016  /// Filter deciding whether this pattern applies (checkAndNotify) and which
/// marker is set on the replacement op (replaceLinalgTransformationFilter).
1018 };
1019 
1020 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
1021 /// dimensions into 1-D depthwise convolution ops.
///
/// Matches tensor-semantics DepthwiseConv2DNhwcHwcOp ops where, for H or W,
/// both the kernel window size and the output size are 1. The operands are
/// rank-reduced by dropping that dimension, a DepthwiseConv1DNwcWcOp is built
/// on them (with the matching stride/dilation entry removed), and its result is
/// inserted back into the original output tensor. If both H and W qualify,
/// only H is dropped.
/// NOTE(review): the body duplicates DownscaleSizeOneWindowed2DConvolution
/// except for the op types; a shared helper would remove the duplication.
1022 struct DownscaleDepthwiseConv2DNhwcHwcOp final
1023  : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
1024  DownscaleDepthwiseConv2DNhwcHwcOp(
1025  MLIRContext *context,
1027  PatternBenefit benefit = 1)
1028  : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
1029  filter(std::move(f)) {}
1030 
1031  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
1032  PatternRewriter &rewriter) const override {
1033  if (failed(filter.checkAndNotify(rewriter, convOp)))
1034  return failure();
1035  if (convOp.hasBufferSemantics())
1036  return failure(); // To be implemented
1037 
1038  Value input = convOp.inputs().front();
1039  Value kernel = convOp.inputs().back();
1040  Value output = convOp.outputs().front();
1041 
// NOTE(review): these dyn_casts are not null-checked; getShape() below would be
// invalid on a non-ranked-tensor operand. Tensor semantics is checked above,
// which presumably implies ranked tensors -- confirm.
1042  auto inputType = input.getType().dyn_cast<RankedTensorType>();
1043  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
1044  auto outputType = output.getType().dyn_cast<RankedTensorType>();
1045 
1046  auto kernelShape = kernelType.getShape();
1047  auto outputShape = outputType.getShape();
1048 
1049  // Only handle the case where, for H or W, both the kernel window size and
1050  // the output size are 1. Other cases can rely on tiling to reduce to such cases.
// Kernel dims 0/1 are (kh, kw); output dims 1/2 are (oh, ow).
1051  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1052  int64_t ohSize = outputShape[1], owSize = outputShape[2];
1053  bool removeH = (khSize == 1 && ohSize == 1);
1054  bool removeW = (kwSize == 1 && owSize == 1);
1055  if (!removeH && !removeW)
1056  return failure();
1057 
1058  // Get new shapes and types for all operands by removing the size-1
1059  // dimension.
// H is dim 1 of input/output and dim 0 of the kernel; W is the next dim in
// each. When removeH is true, H is dropped even if removeW is also true.
1060  using RTTBuilder = RankedTensorType::Builder;
1061  RankedTensorType newInputType =
1062  RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1063  RankedTensorType newKernelType =
1064  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1065  RankedTensorType newOutputType =
1066  RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1067 
1068  // Rank-reduce operands.
1069  Location loc = convOp.getLoc();
1071  rewriter, loc, input, newInputType);
1073  rewriter, loc, kernel, newKernelType);
1075  rewriter, loc, output, newOutputType);
1076 
1077  // Rank-reduce strides and dilations too.
// Entry 0 corresponds to H and entry 1 to W.
1078  // TODO: dropDim 1-liner helper.
1079  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
1080  strides.erase(strides.begin() + (removeH ? 0 : 1));
1081  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1082 
1083  auto dilations =
1084  llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
1085  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1086  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1087 
1088  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1089  loc, newOutputType, ValueRange{newInput, newKernel},
1090  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1091 
1092  // Insert back.
// The 1-D result is inserted into the original (higher-rank) output tensor,
// and that value replaces the 2-D depthwise convolution.
1094  rewriter, loc, conv1DOp.getResult(0), output);
1095  rewriter.replaceOp(convOp, inserted);
1096 
// Mark the new op with the filter's replacement marker.
1097  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
1098  return success();
// NOTE(review): the ';' after this method body is a stray empty declaration.
1099  };
1100 
1101 private:
1102  /// Filter deciding whether this pattern applies (checkAndNotify) and which
/// marker is set on the replacement op (replaceLinalgTransformationFilter).
1104 };
1105 
1106 } // namespace
1107 
1109  RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
1110  PatternBenefit benefit) {
// Adds the size-1-window convolution decomposition patterns (2-D conv and
// 2-D depthwise conv to their 1-D forms) to `patterns`, each constructed with
// the set's context, `filter`, and `benefit`.
1111  patterns.add<DownscaleSizeOneWindowed2DConvolution,
1112  DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
1113  benefit);
1114 }
Include the generated interface declarations.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: Transforms.cpp:573
Helper class to control application of linalg transformation patterns.
Definition: Transforms.h:428
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
void replaceLinalgTransformationFilter(PatternRewriter &rewriter, Operation *op) const
Definition: Transforms.cpp:100
iterator begin()
Definition: Block.h:134
static std::string diag(llvm::Value &v)
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:115
MLIRContext * getContext() const
Definition: Builders.h:54
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
static ValueRange getTiledOpResult(TiledLinalgOp tiledOp)
Definition: Transforms.cpp:342
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:444
LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const
Definition: Transforms.cpp:66
#define DBGS()
Definition: Transforms.cpp:42
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override
Definition: Transforms.cpp:902
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:308
LinalgTileAndFuseTensorOpsPattern(MLIRContext *context, LinalgTilingAndFusionOptions options, LinalgTransformationFilter f=LinalgTransformationFilter(), PatternBenefit benefit=1)
Linalg tile and fuse tensor ops pattern.
Definition: Transforms.cpp:558
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
FailureOr< GenericOp > returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
Definition: Transforms.cpp:661
SmallVector< Value, 4 > tensorResults
Definition: Transforms.h:176
LogicalResult peelAndCanonicalizeTiledLoop(RewriterBase &rewriter, TiledLoopOp loopOp, int64_t idx, TiledLoopOp &result)
Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly into a TiledLoopOp where...
Definition: Loops.cpp:499
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, const LinalgTransformationFilter &filter=LinalgTransformationFilter(), PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
SmallVector< LinalgDependenceGraphElem, 2 > getDependentOperationsInto(LinalgOp linalgOp, ArrayRef< DependenceType > depTypes={ DependenceType::RAW, DependenceType::WAW}) const
Returns all operations that have a dependence into linalgOp of types listed in depTypes.
This class represents a frozen set of patterns that can be processed by a pattern applicator...
Block represents an ordered list of Operations.
Definition: Block.h:29
TileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Definition: Transforms.h:559
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:244
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: Transforms.cpp:684
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of InitTensorOp, FillOp (to initialize with pad_val) and Gene...
Definition: Transforms.cpp:763
SmallVector< int64_t > tileInterchange
Tile interchange used to permute the tile loops.
Definition: Transforms.h:552
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
Operation & front()
Definition: Block.h:144
LinalgVectorizationPattern(MLIRContext *context, LinalgTransformationFilter f=LinalgTransformationFilter(), LinalgVectorizationOptions options=LinalgVectorizationOptions(), PatternBenefit benefit=1)
Construct a generic pattern applied to all LinalgOp that verify filter.
Definition: Transforms.cpp:706
static ValueRange getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp)
Definition: Transforms.cpp:349
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:774
LinalgTilingPattern(MLIRContext *context, LinalgTilingOptions options, LinalgTransformationFilter f=LinalgTransformationFilter(), PatternBenefit benefit=1)
Construct a generic pattern applied to all LinalgOp that verify filter.
Definition: Transforms.cpp:454
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:96
LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context, const LinalgDependenceGraph &dependenceGraph, LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, LinalgTransformationFilter f=LinalgTransformationFilter(), LinalgTransformationFilter fusedOpMarker=LinalgTransformationFilter(), LinalgTransformationFilter originalOpMarker=LinalgTransformationFilter(), PatternBenefit benefit=1)
Definition: Transforms.cpp:355
FailureOr< SmallVector< Value > > rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, const PaddingValueComputationFunction &paddingFunc, const PaddingNoFoldComputationFunction &nofoldFunc, LinalgOp &paddedOp)
Pad the operands of opToPad to a static bounding box.
Definition: Transforms.cpp:229
PaddingValueComputationFunction paddingValueComputationFunction
Callback returning the padding value to use for a given OpOperand or failure for no padding...
Definition: Transforms.h:506
FailureOr< TileLoopNest > tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, ArrayRef< int64_t > tileSizes, ArrayRef< int64_t > tileInterchange)
Tiles consumerOp and fuses its dependencies if possible.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op...
Definition: Interchange.cpp:49
LinalgTilingLoopType loopType
The type of tile loops to generate.
Definition: Transforms.h:591
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:343
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:307
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Definition: Transforms.cpp:840
SmallVector< Operation *, 4 > fusedLoops
The fused loop generated.
Definition: Transforms.h:244
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LinalgGeneralizationPattern(MLIRContext *context, LinalgTransformationFilter f=LinalgTransformationFilter(), PatternBenefit benefit=1)
Construct a generic pattern applied to all LinalgOp that verify filter.
Definition: Transforms.cpp:649
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LinalgTransformationFilter(ArrayRef< StringAttr > matchDisjunction={}, Optional< StringAttr > replacement=None)
Definition: Transforms.cpp:51
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpListType::iterator iterator
Definition: Block.h:131
GenericOpInterchangePattern(MLIRContext *context, ArrayRef< unsigned > interchangeVector, LinalgTransformationFilter f=LinalgTransformationFilter(), PatternBenefit benefit=1)
GenericOp-specific constructor with an optional filter.
Definition: Transforms.cpp:626
Value createFillOrGenerateOp(PatternRewriter &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
Definition: Transforms.cpp:823
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:77
FailureOr< TiledAndFusedLinalgOps > tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef< LinalgOp > ops, const LinalgDependenceGraph &dependenceGraph, const LinalgTilingOptions &tilingOptions)
Definition: Fusion.cpp:942
FailureOr< GenericOp > returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
Definition: Transforms.cpp:633
PaddingHoistComputationFunction paddingHoistComputationFunction
Callback returning the number of loops to hoist the PadOp defining the given OpOperand.
Definition: Transforms.h:528
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static LogicalResult padOperandToSmallestStaticBoundingBox(OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand, const PaddingValueComputationFunction &paddingFunc, const PaddingNoFoldComputationFunction &nofoldFunc, Value &result)
Helper function that tries to pad opOperand.
Definition: Transforms.cpp:165
std::function< bool(OpOperand &)> PaddingNoFoldComputationFunction
Callback returning true if the PadOp defining the given OpOperand shall be marked as nofold to enable...
Definition: Transforms.h:489
llvm::SmallSet< unsigned, 1 > indicesToFuse
List of operands indices to use for fusion.
Definition: Transforms.h:709
U dyn_cast() const
Definition: Types.h:244
std::function< FailureOr< Value >(OpBuilder &, OpOperand &)> PaddingValueComputationFunction
Callback returning the padding value to use for a given OpOperand or failure for no padding...
Definition: Transforms.h:485
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:359
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:154
SmallVector< int64_t > tileSizes
Tile sizes used to tile the root operation.
Definition: Transforms.h:550
Data structure for holding a dependence graph that operates on LinalgOp and views as SSA values...
LinalgOp op
Operation obtained by tiling the last operation in sequence of ops passed to tileAndFuseLinalgOps.
Definition: Transforms.h:238
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:359
LinalgTilingOptions & scalarizeDynamicDims()
Tile all dynamic dimensions by 1.
Definition: Transforms.cpp:135
Builder & dropDim(unsigned pos)
Erase the dim at position `pos` from the shape.
Definition: BuiltinTypes.h:239
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:381
SmallVector< Value, 4 > applyMapToValues(OpBuilder &b, Location loc, AffineMap map, ValueRange values)
Returns the values obtained by applying map to the list of values.
Definition: AffineOps.cpp:742
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:789
LinalgPaddingPattern(MLIRContext *context, LinalgPaddingOptions options=LinalgPaddingOptions(), LinalgTransformationFilter f=LinalgTransformationFilter(), PatternBenefit benefit=1)
Construct a generic pattern applied to all LinalgOp that verify filter.
Definition: Transforms.cpp:492
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold)
Create a tensor::PadOp that pads source to the size of the statically sized type whose static sizes a...
Definition: Utils.cpp:325
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:133
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] that reduces the rank of `tensor` to that of `targetType`.
Definition: TensorOps.cpp:1248
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:38
This class acts as a special tag that makes the desire to match "any" operation type explicit...
Definition: PatternMatch.h:157
U cast() const
Definition: AffineExpr.h:291
FailureOr< TiledLinalgOp > tileLinalgOp(RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options)
Definition: Tiling.cpp:327
SmallVector< int64_t > peeledLoops
Peel the specified loops.
Definition: Transforms.h:617
This class represents an argument of a Block.
Definition: Value.h:298
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:491
FailureOr< Value > hoistPaddingOnTensors(tensor::PadOp opToHoist, int numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< GenericOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor...
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
Definition: PatternMatch.h:779
FailureOr< TiledLinalgOp > returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
Definition: Transforms.cpp:467
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: Transforms.cpp:367
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
static SmallVector< StringRef > getNParallelLoopsAttrs(unsigned nParallelLoops)
Definition: Transforms.cpp:756
bool isPermutation(ArrayRef< int64_t > permutation)
Check if permutation is a permutation of the range [0, permutation.size()).
Definition: Utils.cpp:145
static llvm::ManagedStatic< PassManagerOptions > options
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:212
std::function< LogicalResult(Operation *)> FilterFunction
Definition: Transforms.h:429
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:347
LinalgTilingLoopType
The type of loops to be generated during tiling.
Definition: Utils.h:143
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
Type getType() const
Return the type of this value.
Definition: Value.h:117
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
static const StringLiteral kLinalgTransformMarker
Definition: Transforms.h:418
LogicalResult matchAndRewrite(LinalgOp linalgOp, PatternRewriter &rewriter) const override
Definition: Transforms.cpp:718
Fuse a sequence of linalg operations (ops) using tile-and-fuse.
Definition: Transforms.h:235
U dyn_cast() const
Definition: Attributes.h:117
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:78
static SmallVector< Value, 4 > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel a loop op and return the new result.
Definition: Transforms.cpp:288
Perform standalone tiling of a single LinalgOp by tileSizes.
Definition: Transforms.h:173
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This class represents an operand of an operation.
Definition: Value.h:249
LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp)
Emit a suitable vector form for a Linalg op with fully static shape.
PaddingTransposeComputationFunction paddingTransposeComputationFunction
Callback returning the transpose vector used to permute the result tensor dimensions of the PadOp def...
Definition: Transforms.h:538
bool hasReplacementFilter(Operation *op) const
Definition: Transforms.cpp:110
Linalg vectorization patterns.
Definition: Transforms.h:910
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
LogicalResult applyStagedPatterns(Operation *op, ArrayRef< FrozenRewritePatternSet > stage1Patterns, const FrozenRewritePatternSet &stage2Patterns, function_ref< LogicalResult(Operation *)> stage3Lambda=nullptr)
Helper function to allow applying rewrite patterns, interleaved with more global transformations, in a staged fashion:
Definition: Transforms.cpp:725
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Operation *op, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
Definition: PatternMatch.h:802
LinalgBasePromotionPattern(MLIRContext *context, LinalgTransformationFilter f, LinalgPromotionOptions options=LinalgPromotionOptions(), PatternBenefit benefit=1)
Entry point to match any LinalgOp OpInterface.
Definition: Transforms.cpp:672
FailureOr< int64_t > getConstantUpperBoundForIndex(Value value)
Returns a constant upper bound for the result value of an index computation.
Definition: Utils.cpp:265
void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res, ArrayRef< int64_t > peeledLoops, LinalgTilingLoopType loopType)
Peel the loops of a TiledLinalgOp.
Definition: Transforms.cpp:314
SmallVector< Operation *, 8 > loops
Definition: Transforms.h:175
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:231
Optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
result_range getResults()
Definition: Operation.h:284
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
virtual void cancelRootUpdate(Operation *op)
This method cancels a pending root update.
Definition: PatternMatch.h:783
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:323
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:201
MLIRContext * getContext() const
Definition: PatternMatch.h:906
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] that inserts `tensor` into `dest`, restoring the rank of `dest`.
Definition: TensorOps.cpp:1528
PaddingNoFoldComputationFunction paddingNoFoldComputationFunction
Callback returning true if the PadOp defining the given OpOperand shall be marked as nofold to enable...
Definition: Transforms.h:518
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:688
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Definition: Transforms.h:569
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:370
U cast() const
Definition: Types.h:250
FailureOr< LinalgOp > returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
Definition: Transforms.cpp:505