MLIR  16.0.0git
LinalgTransformOps.cpp
Go to the documentation of this file.
1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
24 #include "llvm/ADT/StringSet.h"
25 #include "llvm/Support/Debug.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 using namespace mlir::transform;
30 
31 #define DEBUG_TYPE "linalg-transforms"
32 
33 /// Extracts a vector of unsigned from an array attribute. Asserts if the
34 /// attribute contains values other than intergers. May truncate.
35 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
36  SmallVector<unsigned> result;
37  result.reserve(attr.size());
38  for (APInt value : attr.getAsValueRange<IntegerAttr>())
39  result.push_back(value.getZExtValue());
40  return result;
41 }
42 
43 namespace {
44 /// A simple pattern rewriter that implements no special logic.
45 class SimpleRewriter : public PatternRewriter {
46 public:
47  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
48 };
49 } // namespace
50 
51 /// Attempts to apply the pattern specified as template argument to the given
52 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
53 /// function that returns the "main" result or failure. Returns failure if the
54 /// pattern failed to apply. Extra arguments are forwarded to the pattern
55 /// constructor.
56 template <typename PatternTy, typename... Args>
57 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
58  // Check if the given operation has the type expected by the pattern.
59  using OpTy = typename llvm::function_traits<
60  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
61  auto op = dyn_cast<OpTy>(operation);
62  if (!op)
63  return failure();
64 
65  // Apply the pattern directly to the op.
66  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
67  SimpleRewriter rewriter(operation->getContext());
68  rewriter.setInsertionPoint(operation);
69  auto result = pattern.returningMatchAndRewrite(op, rewriter);
70  if (failed(result))
71  return failure();
72  return cast<LinalgOp>(result->getOperation());
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // DecomposeOp
77 //===----------------------------------------------------------------------===//
78 
80 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
83  FailureOr<LinalgOp> windowedNhwc =
84  tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
85  Conv1DNwcWcfOp>>(target);
86  if (succeeded(windowedNhwc)) {
87  results.push_back(*windowedNhwc);
89  }
90  FailureOr<LinalgOp> windowedNchw =
91  tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
92  Conv1DNcwFcwOp>>(target);
93  if (succeeded(windowedNchw)) {
94  results.push_back(*windowedNchw);
96  }
97  FailureOr<LinalgOp> depthwise =
98  tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
99  if (succeeded(depthwise)) {
100  results.push_back(*depthwise);
102  }
103  results.assign(1, nullptr);
104  return emitDefaultSilenceableFailure(target);
105 }
106 //===----------------------------------------------------------------------===//
107 // FuseOp
108 //===----------------------------------------------------------------------===//
109 
110 /// Apply a tiling transformation to all payload ops and store both the
111 /// tiled operation as well as the created tile loops.
113  Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
114  transform::TransformResults &transformResults,
116  applyFn) {
117  SmallVector<Operation *> tiledLinalgOps;
118  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
119  for (unsigned int i = 0; i < numLoops; ++i)
120  loopOps[i].reserve(payloadOps.size());
121 
122  for (Operation *target : payloadOps) {
123  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
124  if (!tilingInterfaceOp)
125  return transformOp->emitError("only TilingInterface ops are supported");
126 
127  SimpleRewriter rewriter(target->getContext());
128  rewriter.setInsertionPoint(target);
130  applyFn(tilingInterfaceOp);
131  if (failed(tiledResults))
132  return failure();
133 
134  // Perform the replacement of tiled and fused values.
135  SmallVector<Operation *> opsToReplace{target};
136  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
137  for (Operation *toReplace : opsToReplace) {
138  SmallVector<Value> replacements;
139  replacements.reserve(toReplace->getNumResults());
140  for (OpResult res : toReplace->getResults()) {
141  auto it = tiledResults->replacements.find(res);
142  if (it == tiledResults->replacements.end())
143  replacements.push_back(res);
144  else
145  replacements.push_back(it->getSecond());
146  }
147  rewriter.replaceOp(toReplace, replacements);
148  }
149 
150  // Report back the relevant handles to the transform op.
151  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
152  assert(tiledResults->loops.size() == numLoops &&
153  "Mismatched number of loops, tile and fuse transform should have "
154  "failed");
155  for (unsigned int i = 0; i < numLoops; ++i)
156  loopOps[i].push_back(tiledResults->loops[i]);
157  }
158 
159  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
160  for (unsigned int i = 0; i < numLoops; ++i)
161  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
162 
163  return success();
164 }
165 
166 /// Parse a tiling-like operation that returns the tiled op as well as the
167 /// created tile loops. The function counts the non-zero tile sizes to compute
168 /// the number of results.
170  StringRef sizesAttrName) {
171  OpAsmParser::UnresolvedOperand targetOperand;
172  SMLoc opLoc = parser.getCurrentLocation();
173  if (parser.parseOperand(targetOperand) ||
174  parser.parseOptionalAttrDict(result.attributes))
175  return failure();
176  Attribute sizesAttr = result.attributes.get(sizesAttrName);
177  if (!sizesAttr)
178  return parser.emitError(opLoc)
179  << "expected '" << sizesAttrName << "' attribute";
180  auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
181  if (!sizesArrayAttr)
182  return parser.emitError(opLoc)
183  << "'" << sizesAttrName << "' attribute must be an array";
184  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
185  size_t numExpectedLoops =
186  sizesArrayAttr.size() -
187  llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
188  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
189  if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
190  return failure();
191  return success();
192 }
193 
195 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
197  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
198  SmallVector<int64_t> tileInterchange =
199  extractFromI64ArrayAttr(getTileInterchange());
200 
201  scf::SCFTilingOptions tilingOptions;
202  tilingOptions.interchangeVector = tileInterchange;
203  tilingOptions = tilingOptions.setTileSizes(tileSizes);
204  scf::SCFTileAndFuseOptions tileAndFuseOptions;
205  tileAndFuseOptions.tilingOptions = tilingOptions;
207  getOperation(), state.getPayloadOps(getTarget()),
208  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
209  [&](TilingInterface tilingInterfaceOp)
211  SimpleRewriter rewriter(getContext());
212  return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
213  rewriter, tilingInterfaceOp, tileAndFuseOptions);
214  });
215  return DiagnosedSilenceableFailure(result);
216 }
217 
218 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
219  OperationState &result) {
220  return parseTileLikeOp(
221  parser, result,
222  transform::FuseOp::getTileSizesAttrName(result.name).getValue());
223 }
224 
226  p << ' ';
227  p << getTarget();
228  p.printOptionalAttrDict((*this)->getAttrs());
229 }
230 
232  SmallVector<int64_t> permutation =
233  extractFromI64ArrayAttr(getTileInterchange());
234  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
235  if (!std::is_permutation(sequence.begin(), sequence.end(),
236  permutation.begin(), permutation.end())) {
237  return emitOpError() << "expects interchange to be a permutation, found "
238  << getTileInterchange();
239  }
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // FuseIntoContainingOp
245 //===----------------------------------------------------------------------===//
246 
247 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
248  OperationState &result,
249  Value producerOp,
250  Value containingOp) {
251  result.addOperands({producerOp, containingOp});
252  result.addTypes(pdl::OperationType::get(builder.getContext()));
253 }
254 
255 /// Find the first "extract" user of `producerOp` and tile it right before its
256 /// use. The tiled op is fused under the `containingOp`.
257 /// Return this fused op on success or nullptr if anything fails.
259  Diagnostic &diag,
260  Operation *producerOp,
261  Operation *containingOp) {
262  LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n");
263  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
264  if (!tileableProducer) {
265  diag.attachNote(producerOp->getLoc())
266  << "producer is not a TileableInterface: " << *producerOp;
267  return nullptr;
268  }
269 
270  // Search the producer slices accessed within the containing operation.
271  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
272  // evolve into an interface.
273  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
274  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
275  return sliceOp && containingOp->isProperAncestor(sliceOp);
276  });
277 
278  // Find a fusion opportunity.
279  if (it == tileableProducer->getUsers().end()) {
280  diag.attachNote(tileableProducer->getLoc())
281  << "could not find fusion opportunity for: " << *tileableProducer;
282  return nullptr;
283  }
284  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
285 
286  // Try to fuse the producer in-place.
287  OpBuilder::InsertionGuard guard(rewriter);
288  rewriter.setInsertionPoint(sliceOpToTile);
289 
290  // Tile the producer.
291  int64_t resultNumber =
292  sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
293  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
294 
295  FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
296  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
297  sliceOpToTile.getMixedSizes());
298  if (failed(tiledProducer)) {
299  diag.attachNote(tileableProducer->getLoc())
300  << "failed to tile producer op: " << *tileableProducer;
301  return nullptr;
302  }
303  LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
304 
305  // Replace the extract op.
306  Operation *fusedOp = tiledProducer->getDefiningOp();
307  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
308  return fusedOp;
309 }
310 
311 /// First, find the first "scf::ForeachThreadOp" user of `producerOp` and ensure
312 /// it is exactly the `containingOp`, otherwise bail.
313 /// Then, find the first "extract" user of the tied block argument and tile it
314 /// right before its "extract" use. The tiled op is fused under the
315 /// `containingOp`.
316 /// Return this fused op on success or nullptr if anything fails.
318  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
319  Operation *containingOp) {
320  LLVM_DEBUG(
321  llvm::dbgs() << "Try to fuse an extract use through block argument\n");
322 
323  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
324  if (!tileableProducer) {
325  diag.attachNote(producerOp->getLoc())
326  << "producer is not a TileableInterface: " << *producerOp;
327  return nullptr;
328  }
329 
330  // Search the first use by a "scf::ForeachThreadOp" user.
331  scf::ForeachThreadOp foreachThreadOp;
332  auto itProducerUses =
333  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
334  foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(use.getOwner());
335  return foreachThreadOp;
336  });
337  // If it's not from the containing op, return.
338  if (!foreachThreadOp || foreachThreadOp != containingOp) {
339  diag.attachNote(tileableProducer->getLoc())
340  << "could not find a use by the containing op: " << *tileableProducer;
341  return nullptr;
342  }
343 
344  // Search the producer slices accessed within the containing
345  // operation.
346  // TODO: Generalize to more extract/insert/parallel_insert triples.
347  // Maybe evolve into an interface.
348  OpOperand *pUse = &(*itProducerUses);
349  BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse);
350 
351  // Search the producer slices accessed within the containing operation.
352  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
353  // evolve into an interface.
354  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
355  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
356  return sliceOp && containingOp->isProperAncestor(sliceOp);
357  });
358 
359  // Find a fusion opportunity.
360  if (itBBArgUsers == bbArg.getUsers().end()) {
361  diag.attachNote(containingOp->getLoc())
362  << "could not find fusion opportunity for bbArg: " << bbArg;
363  return nullptr;
364  }
365  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
366 
367  // Try to fuse the producer in-place.
368  OpBuilder::InsertionGuard guard(rewriter);
369  rewriter.setInsertionPoint(sliceOpToTile);
370 
371  // Replace the use in the tileableProducer before tiling: clone, replace and
372  // then tile.
373  int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
374  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
375 
376  // Gather destination tensors.
377  SmallVector<Value> destinationTensors;
379  rewriter, tileableProducer->getLoc(), tileableProducer,
380  destinationTensors))) {
381  diag.attachNote(tileableProducer->getLoc())
382  << "failed to get destination tensors for: " << *tileableProducer;
383  return nullptr;
384  }
385 
387  bvm.map(destinationTensors[resultNumber], bbArg);
388  auto tileableProducerClone =
389  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
390  auto scopeGuard =
391  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
392 
393  // Tile the producer.
394  FailureOr<Value> tiledProducer =
395  tileableProducerClone.generateResultTileValue(
396  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
397  sliceOpToTile.getMixedSizes());
398  if (failed(tiledProducer)) {
399  diag.attachNote(tileableProducer->getLoc())
400  << "failed to tile producer op: " << *tileableProducer;
401  return nullptr;
402  }
403  LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
404 
405  // Replace the extract op.
406  Operation *fusedOp = tiledProducer->getDefiningOp();
407  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
408 
409  // Replace the use in containingOp.
410  rewriter.updateRootInPlace(containingOp, [&]() {
411  containingOp->setOperand(pUse->getOperandNumber(),
412  destinationTensors.front());
413  });
414 
415  return fusedOp;
416 }
417 
419  Operation *producerOp,
420  Operation *containingOp) {
421  LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n");
422 
423  // Gather all uses inside the containing op.
425  for (OpResult result : producerOp->getOpResults()) {
426  for (OpOperand &use : result.getUses()) {
427  if (containingOp->isProperAncestor(use.getOwner())) {
428  uses.push_back(&use);
429  continue;
430  }
431  // Cannot clone and fuse if the use is by the containing op itself: fail
432  // immediately.
433  if (containingOp == use.getOwner()) {
434  diag.attachNote(producerOp->getLoc())
435  << "producer op use by containing op cannot be fused by cloning";
436  return nullptr;
437  }
438  }
439  }
440 
441  // Check for a non-empty list of fusion opportunities.
442  if (uses.empty()) {
443  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
444  return nullptr;
445  }
446 
447  // Clone and fuse inside the containing op.
448  Operation *fusedOp = nullptr;
449  OpOperand *use = uses.front();
450  // Parallel insert slice is not a valid clone destination.
451  // TODO: Generalize to other type of ops.
452  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
453  "Parallel insert slice is not a valid clone destination");
454  unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
455  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
456 
457  OpBuilder::InsertionGuard guard(rewriter);
458  rewriter.setInsertionPoint(use->getOwner());
459  fusedOp = rewriter.clone(*producerOp);
460  rewriter.updateRootInPlace(
461  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
462 
463  return fusedOp;
464 }
465 
467 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
468  transform::TransformState &state) {
469  SmallVector<Operation *> fusedOps;
470  ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
471  // If nothing to fuse, propagate success.
472  if (producerOps.empty()) {
473  results.set(getFusedOp().cast<OpResult>(),
476  }
477  ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
478  if (containingOps.size() != 1) {
479  return emitDefiniteFailure()
480  << "requires exactly one containing_op handle (got "
481  << containingOps.size() << ")";
482  }
483  Operation *containingOp = containingOps.front();
484 
485  // Helper function to find the next producer that should be fused. Take any
486  // producer that has a use inside the containing op.
487  SmallVector<Operation *> remainingProducers(producerOps.begin(),
488  producerOps.end());
489  auto getNextProducer = [&]() -> FailureOr<Operation *> {
490  for (const auto &it : enumerate(remainingProducers)) {
491  Operation *producerOp = it.value();
492  // The containing op may be a user of producerOp: use isAncestor.
493  int64_t numUsesInContainingOp =
494  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
495  return containingOp->isAncestor(op);
496  });
497  // TODO: When resolving the TODO below (no duplicate ops), take an op
498  // that has no use among the remaining producers. This is a topological
499  // sorting.
500  if (numUsesInContainingOp > 0) {
501  if (numUsesInContainingOp == 1)
502  remainingProducers.erase(remainingProducers.begin() + it.index());
503  return producerOp;
504  }
505  }
506  return failure();
507  };
508 
509  IRRewriter rewriter(getContext());
510  while (!remainingProducers.empty()) {
511  auto nextProducer = getNextProducer();
512  if (failed(nextProducer)) {
513  results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
515  diag << "could not find next producer to fuse into container";
517  }
518 
519  Operation *producerOp = *nextProducer;
520 
521  // Default diagnostic, to be complemented with more failure information.
523  diag << "could not fuse " << *producerOp << " into " << *containingOp;
524 
525  // TODO: If there are multiple uses of the producer in the containing op,
526  // we currently tile/clone the op multiple times (once per use). In some
527  // cases, we can tile/clone once and reuse the value for each use.
528  // Futhermore, producers should then be traversed according to a
529  // topological sorting.
530  Operation *tiled =
531  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
532  if (tiled) {
533  LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"
534  << *containingOp);
535  fusedOps.push_back(tiled);
536  continue;
537  }
538 
539  Operation *tiledContainingOpOperand =
541  rewriter, diag, producerOp, containingOp);
542  if (tiledContainingOpOperand) {
543  LLVM_DEBUG(llvm::dbgs()
544  << "\nFused an extract use through block argument\n"
545  << *containingOp);
546  fusedOps.push_back(tiledContainingOpOperand);
547  continue;
548  }
549 
550  Operation *cloned =
551  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
552  if (cloned) {
553  LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"
554  << *containingOp);
555  fusedOps.push_back(cloned);
556  continue;
557  }
558  results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
560  }
561 
562  results.set(getFusedOp().cast<OpResult>(), fusedOps);
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // GeneralizeOp
568 //===----------------------------------------------------------------------===//
569 
571 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
573  transform::TransformState &state) {
574  // Exit early if no transformation is needed.
575  if (isa<GenericOp>(target)) {
576  results.push_back(target);
578  }
579  FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
580  if (succeeded(generic)) {
581  results.push_back(generic->getOperation());
583  }
584  results.assign(1, nullptr);
585  return emitDefaultSilenceableFailure(target);
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // InterchangeOp
590 //===----------------------------------------------------------------------===//
591 
593 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
595  transform::TransformState &state) {
596  SmallVector<unsigned> interchangeVector =
597  extractUIntArray(getIteratorInterchange());
598  // Exit early if no transformation is needed.
599  if (interchangeVector.empty()) {
600  results.push_back(target);
602  }
603  SimpleRewriter rewriter(target->getContext());
605  interchangeGenericOp(rewriter, target, interchangeVector);
606  if (failed(res))
608  results.push_back(res->getOperation());
610 }
611 
613  SmallVector<unsigned> permutation =
614  extractUIntArray(getIteratorInterchange());
615  auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
616  if (!std::is_permutation(sequence.begin(), sequence.end(),
617  permutation.begin(), permutation.end())) {
618  return emitOpError()
619  << "expects iterator_interchange to be a permutation, found "
620  << getIteratorInterchange();
621  }
622  return success();
623 }
624 
625 //===---------------------------------------------------------------------===//
626 // MatchOp
627 //===---------------------------------------------------------------------===//
628 
629 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
630  Value target, ArrayRef<StringRef> opNames) {
631  result.addOperands(target);
632  result.addAttribute(MatchOp::getOpsAttrName(result.name),
633  builder.getStrArrayAttr(opNames));
634  result.addTypes(pdl::OperationType::get(builder.getContext()));
635 }
636 
638 transform::MatchOp::apply(transform::TransformResults &results,
639  transform::TransformState &state) {
640  llvm::StringSet<> strs;
641  if (getOps().has_value())
642  strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
643  getOps()->getAsValueRange<StringAttr>().end());
644 
645  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
646  if (payloadOps.size() != 1) {
647  results.set(getResult().cast<OpResult>(), {});
649  this->emitOpError("requires exactly one target handle"));
650  }
651 
653  auto matchFun = [&](Operation *op) {
654  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
655  return;
656 
657  // Interfaces cannot be matched by name, just by ID.
658  // So we specifically encode the interfaces we care about for this op.
659  if (getInterface().has_value()) {
660  auto iface = getInterface().value();
661  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
662  !isa<linalg::LinalgOp>(op))
663  return;
664  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
665  isa<TilingInterface>(op))
666  return;
667  }
668 
669  // Check if all specified attributes match.
670  if (getOpAttrs().has_value()) {
671  DictionaryAttr opAttrs = getOpAttrs().value();
672  for (NamedAttribute attr : opAttrs) {
673  if (attr.getName() == getInterfaceAttrName() ||
674  attr.getName() == getOpsAttrName())
675  continue;
676  if (!op->hasAttr(attr.getName()))
677  return;
678  if (op->getAttr(attr.getName()) != attr.getValue())
679  return;
680  }
681  }
682 
683  if (getFilterResultType().has_value()) {
684  Type t = getFilterResultType().value();
685  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
686  return;
687  }
688 
689  // All constraints are satisfied.
690  res.push_back(op);
691  return;
692  };
693 
694  payloadOps.front()->walk(matchFun);
695  results.set(getResult().cast<OpResult>(), res);
697 }
698 
699 //===---------------------------------------------------------------------===//
700 // MultiTileSizesOp
701 //===---------------------------------------------------------------------===//
702 
703 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
704  LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
705  OpBuilder builder(target.getContext());
706  builder.setInsertionPoint(target);
707  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
708  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
710  builder, target, getDimension(), targetSize, divisor);
711  if (failed(spec)) {
712  return emitSilenceableError() << "could not generate tile size computation";
713  }
714 
715  AffineExpr s0 = builder.getAffineSymbolExpr(0);
716  AffineExpr s1 = builder.getAffineSymbolExpr(1);
717  Operation *splitPoint =
718  makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
719  {spec->lowTileSize, spec->lowTripCount});
720  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
721  Operation *highTileSize = spec->highTileSize.getDefiningOp();
722  assert(lowTileSize && highTileSize && splitPoint &&
723  "tile sizes are not produced by operations");
724  results.reserve(results.size() + 3);
725  results.push_back(lowTileSize);
726  results.push_back(highTileSize);
727  results.push_back(splitPoint);
729 }
730 
731 void transform::MultiTileSizesOp::getEffects(
733  onlyReadsHandle(getTarget(), effects);
734  producesHandle(getResults(), effects);
735  modifiesPayload(effects);
736 }
737 
738 //===---------------------------------------------------------------------===//
739 // PadOp
740 //===---------------------------------------------------------------------===//
741 
743 transform::PadOp::applyToOne(linalg::LinalgOp target,
745  transform::TransformState &state) {
746  // Convert the integer packing flags to booleans.
747  SmallVector<bool> packPaddings;
748  for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
749  packPaddings.push_back(static_cast<bool>(packPadding));
750 
751  // Convert the padding values to attributes.
752  SmallVector<Attribute> paddingValues;
753  for (auto const &it :
754  llvm::zip(getPaddingValues(), target->getOperandTypes())) {
755  auto attr = std::get<0>(it).dyn_cast<TypedAttr>();
756  if (!attr) {
757  emitOpError("expects padding values to be typed attributes");
759  }
760  Type elementType = getElementTypeOrSelf(std::get<1>(it));
761  // Try to parse string attributes to obtain an attribute of element type.
762  if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
763  paddingValues.push_back(
764  parseAttribute(attr.cast<StringAttr>(), elementType));
765  if (!paddingValues.back()) {
766  auto diag = this->emitOpError("expects a padding that parses to ")
767  << elementType << ", got " << std::get<0>(it);
768  diag.attachNote(target.getLoc()) << "when applied to this op";
770  }
771  continue;
772  }
773  // Otherwise, add the attribute directly.
774  if (attr.getType() != elementType) {
775  auto diag = this->emitOpError("expects a padding value of type ")
776  << elementType << ", got " << attr;
777  diag.attachNote(target.getLoc()) << "when applied to this op";
779  }
780  paddingValues.push_back(attr);
781  }
782 
783  // Extract the transpose vectors.
784  SmallVector<SmallVector<int64_t>> transposePaddings;
785  for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
786  transposePaddings.push_back(
787  extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
788 
789  LinalgPaddingOptions paddingOptions;
790  paddingOptions.setPaddingValues(paddingValues);
791  paddingOptions.setPaddingDimensions(
792  extractFromI64ArrayAttr(getPaddingDimensions()));
793  paddingOptions.setPackPaddings(packPaddings);
794  paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
795  paddingOptions.setTransposePaddings(transposePaddings);
796 
797  FailureOr<LinalgOp> result =
798  tryApply<LinalgPaddingPattern>(target, paddingOptions);
799  if (succeeded(result)) {
800  results.push_back(result->getOperation());
802  }
803 
804  results.assign(1, nullptr);
805  return emitDefaultSilenceableFailure(target);
806 }
807 
809  SmallVector<int64_t> packPaddings =
810  extractFromI64ArrayAttr(getPackPaddings());
811  if (any_of(packPaddings, [](int64_t packPadding) {
812  return packPadding != 0 && packPadding != 1;
813  })) {
814  return emitOpError()
815  << "expects pack_paddings to contain booleans (0/1), found "
816  << getPackPaddings();
817  }
818 
819  SmallVector<int64_t> paddingDimensions =
820  extractFromI64ArrayAttr(getPaddingDimensions());
821  if (any_of(paddingDimensions,
822  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
823  return emitOpError() << "expects padding_dimensions to contain positive "
824  "integers, found "
825  << getPaddingDimensions();
826  }
827 
828  SmallVector<int64_t> hoistPaddings =
829  extractFromI64ArrayAttr(getHoistPaddings());
830  if (any_of(hoistPaddings,
831  [](int64_t hoistPadding) { return hoistPadding < 0; })) {
832  return emitOpError()
833  << "expects hoist_paddings to contain positive integers, found "
834  << getHoistPaddings();
835  }
836 
837  ArrayAttr transposes = getTransposePaddings();
838  for (Attribute attr : transposes) {
840  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
841  if (!std::is_permutation(sequence.begin(), sequence.end(),
842  transpose.begin(), transpose.end())) {
843  return emitOpError()
844  << "expects transpose_paddings to be a permutation, found "
845  << attr;
846  }
847  }
848  return success();
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // PromoteOp
853 //===----------------------------------------------------------------------===//
854 
856 transform::PromoteOp::applyToOne(linalg::LinalgOp target,
858  transform::TransformState &state) {
859  LinalgPromotionOptions promotionOptions;
860  if (!getOperandsToPromote().empty())
861  promotionOptions = promotionOptions.setOperandsToPromote(
862  extractFromI64ArrayAttr(getOperandsToPromote()));
863  if (getUseFullTilesByDefault())
864  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
865  getUseFullTilesByDefault());
866  if (getUseAlloca())
867  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
868  if (!getUseFullTileBuffers().empty())
869  promotionOptions = promotionOptions.setUseFullTileBuffers(
870  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
871  if (getAlignment().has_value())
872  promotionOptions = promotionOptions.setAlignment(*getAlignment());
873 
874  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
875  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
876 
877  SimpleRewriter rewriter(target->getContext());
878  rewriter.setInsertionPoint(target);
879  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
880  if (failed(res))
881  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
882  results.push_back(target);
884 }
885 
886 //===----------------------------------------------------------------------===//
887 // ScalarizeOp
888 //===----------------------------------------------------------------------===//
889 
891 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
893  transform::TransformState &state) {
894  scf::SCFTilingOptions tilingOptions;
895  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
896  SmallVector<Value, 4> tileSizes;
897  Location loc = target.getLoc();
898  SmallVector<OpFoldResult> allShapeSizes =
899  target.createFlatListOfOperandDims(b, loc);
900  AffineMap map = target.getShapesToLoopsMap();
901  if (!map)
902  return tileSizes;
903  IRRewriter rewriter(b);
904  SmallVector<OpFoldResult> shapeSizes =
905  makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
906  allShapeSizes);
907  // If the shape size is dynamic, tile by 1.
908  // Otherwise, do not tile (i.e. tile size 0).
909  for (OpFoldResult shapeSize : shapeSizes) {
910  tileSizes.push_back(getConstantIntValue(shapeSize)
911  ? b.create<arith::ConstantIndexOp>(loc, 0)
912  : b.create<arith::ConstantIndexOp>(loc, 1));
913  }
914  return tileSizes;
915  });
916  SmallVector<int64_t> emptyTileSizes;
917  SimpleRewriter rewriter(getContext());
918  rewriter.setInsertionPoint(target);
920  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
921  if (failed(maybeTilingResult))
922  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
923 
924  results.append(maybeTilingResult->tiledOps);
926 }
927 
928 //===----------------------------------------------------------------------===//
929 // SplitOp
930 //===----------------------------------------------------------------------===//
931 
932 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
933  TransformState &state) {
934  // Collect the dynamic split points if provided.
935  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
936  SimpleRewriter rewriter(getContext());
937  SmallVector<OpFoldResult> splitPoints;
938  splitPoints.reserve(payload.size());
939  if (getDynamicSplitPoint()) {
941  splitPoints = llvm::to_vector(llvm::map_range(
942  state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
943  if (op->getNumResults() != 1 ||
944  !op->getResult(0).getType().isIndex()) {
945  diag = emitSilenceableError()
946  << "expected dynamic split point handle to point to a "
947  "single-result index-typed op";
948  diag.attachNote(op->getLoc()) << "dynamic split point";
949  }
950  return OpFoldResult(op->getResult(0));
951  }));
952  if (diag.isSilenceableFailure()) {
953  results.set(getFirst().cast<OpResult>(), {});
954  results.set(getSecond().cast<OpResult>(), {});
955  return diag;
956  }
957 
958  if (splitPoints.size() != payload.size()) {
959  return emitDefiniteFailure()
960  << "expected the dynamic split point handle to point to as "
961  "many operations ("
962  << splitPoints.size() << ") as the target handle ("
963  << payload.size() << ")";
964  }
965  } else {
966  splitPoints.resize(payload.size(),
967  rewriter.getIndexAttr(getStaticSplitPoint()));
968  }
969 
970  // Split each target operation.
971  SmallVector<Operation *> first, second;
972  for (const auto &pair : llvm::zip(payload, splitPoints)) {
973  Operation *target = std::get<0>(pair);
974  auto linalgOp = dyn_cast<LinalgOp>(target);
975  if (!linalgOp) {
976  auto diag = emitSilenceableError() << "only applies to structured ops";
977  diag.attachNote(target->getLoc()) << "target op";
978  results.set(getFirst().cast<OpResult>(), {});
979  results.set(getSecond().cast<OpResult>(), {});
980  return diag;
981  }
982 
983  if (getDimension() >= linalgOp.getNumLoops()) {
984  auto diag = emitSilenceableError() << "dimension " << getDimension()
985  << " does not exist in target op";
986  diag.attachNote(target->getLoc()) << "target op";
987  results.set(getFirst().cast<OpResult>(), {});
988  results.set(getSecond().cast<OpResult>(), {});
989  return diag;
990  }
991 
992  rewriter.setInsertionPoint(linalgOp);
993  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
994  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
995  getDimension(), std::get<1>(pair));
996  }
997 
998  results.set(getFirst().cast<OpResult>(), first);
999  results.set(getSecond().cast<OpResult>(), second);
1001 }
1002 
1003 void SplitOp::getEffects(
1005  consumesHandle(getTarget(), effects);
1006  if (getDynamicSplitPoint())
1007  onlyReadsHandle(getDynamicSplitPoint(), effects);
1008  producesHandle(getResults(), effects);
1009  modifiesPayload(effects);
1010 }
1011 
1012 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
1013  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
1014  IntegerAttr staticSplitPoint;
1015  auto pdlOperationType =
1016  pdl::OperationType::get(parser.getBuilder().getContext());
1017  if (parser.parseOperand(target) ||
1018  parser.resolveOperand(target, pdlOperationType, result.operands) ||
1019  parser.parseKeyword("after"))
1020  return failure();
1021 
1022  OptionalParseResult dynamicPointParseResult =
1023  parser.parseOptionalOperand(dynamicSplitPoint);
1024  if (!dynamicPointParseResult.has_value()) {
1025  int64_t staticSplitPointValue;
1026  if (failed(parser.parseInteger(staticSplitPointValue)))
1027  return failure();
1028 
1029  staticSplitPoint =
1030  parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
1031  } else {
1032  if (failed(*dynamicPointParseResult) ||
1033  parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
1034  result.operands)) {
1035  return failure();
1036  }
1037 
1038  staticSplitPoint =
1039  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
1040  }
1041 
1042  result.addAttribute(
1043  SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
1044  staticSplitPoint);
1045  if (failed(parser.parseOptionalAttrDict(result.attributes)))
1046  return failure();
1047 
1048  result.addTypes({pdlOperationType, pdlOperationType});
1049  return success();
1050 }
1051 
1052 void SplitOp::print(OpAsmPrinter &printer) {
1053  printer << " " << getTarget() << " after ";
1054  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
1055  if (staticSplitSize != ShapedType::kDynamic)
1056  printer << staticSplitSize;
1057  else
1058  printer << getDynamicSplitPoint();
1059  printer << " ";
1060  printer.printOptionalAttrDict(getOperation()->getAttrs(),
1061  {getStaticSplitPointAttrName()});
1062 }
1063 
1065  if ((static_cast<int64_t>(getStaticSplitPoint()) !=
1066  ShapedType::kDynamic) ^
1067  (getDynamicSplitPoint() == nullptr)) {
1068  return emitOpError() << "expects either a dynamic or a static split "
1069  "point to be provided";
1070  }
1071  return success();
1072 }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // SplitReductionOp
1076 //===----------------------------------------------------------------------===//
1077 
1078 void transform::SplitReductionOp::build(
1079  OpBuilder &builder, OperationState &result, Value target,
1080  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
1081  bool useScalingAlgorithm, bool useAlloc) {
1082  MLIRContext *ctx = builder.getContext();
1083  result.addOperands(target);
1084  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
1085  builder.getI64IntegerAttr(splitFactor));
1086  result.addAttribute(
1087  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
1088  builder.getI64IntegerAttr(insertSplitDimension));
1089  if (innerParallel) {
1090  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
1091  builder.getUnitAttr());
1092  }
1093  if (useScalingAlgorithm) {
1094  result.addAttribute(
1095  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
1096  builder.getUnitAttr());
1097  }
1098  if (useAlloc) {
1099  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
1100  builder.getUnitAttr());
1101  }
1102  auto resultType = pdl::OperationType::get(ctx);
1103  result.addTypes({resultType, resultType, resultType, resultType});
1104 }
1105 
1107 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
1109  transform::TransformState &state) {
1110  ControlSplitReductionFn splitFn = [&](LinalgOp) {
1111  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
1112  unsigned(getInsertSplitDimension()),
1113  bool(getInnerParallel())};
1114  };
1115  SimpleRewriter rewriter(getContext());
1116  rewriter.setInsertionPoint(target);
1117  FailureOr<SplitReductionResult> splitResult =
1118  (getUseScalingAlgorithm())
1119  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
1120  : splitReduction(rewriter, target, splitFn, getUseAlloc());
1121  if (failed(splitResult))
1122  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1123 
1124  results.push_back(splitResult->initOrAlloc);
1125  results.push_back(splitResult->fillOp);
1126  results.push_back(splitResult->splitLinalgOp);
1127  results.push_back(splitResult->resultCombiningLinalgOp);
1129 }
1130 
1131 //===----------------------------------------------------------------------===//
1132 // TileReductionUsingScfOp
1133 //===----------------------------------------------------------------------===//
1134 
1135 DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
1136  linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
1137  transform::TransformState &state) {
1138  SimpleRewriter rewriter(getContext());
1139  rewriter.setInsertionPoint(target);
1140  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
1142  for (int64_t size : tileSizes) {
1143  sizes.push_back(rewriter.getIndexAttr(size));
1144  }
1145 
1147  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
1148  sizes);
1149 
1150  if (failed(result))
1151  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1152  results.push_back(result->initialOp);
1153  results.push_back(result->parallelTiledOp);
1154  results.push_back(result->mergeOp);
1156 }
1157 
1158 //===----------------------------------------------------------------------===//
1159 // TileReductionUsingForeachThreadOp
1160 //===----------------------------------------------------------------------===//
1161 
1163 transform::TileReductionUsingForeachThreadOp::applyToOne(
1164  linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
1165  transform::TransformState &state) {
1166  SimpleRewriter rewriter(getContext());
1167  rewriter.setInsertionPoint(target);
1168  SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
1169  SmallVector<OpFoldResult> numThreadResults;
1170  for (int64_t num : numThreads) {
1171  numThreadResults.push_back(rewriter.getIndexAttr(num));
1172  }
1173 
1176  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
1177  numThreadResults, /*mapping=*/llvm::None);
1178 
1179  if (failed(result)) {
1180  results.assign(3, nullptr);
1181  Diagnostic diag(target->getLoc(), DiagnosticSeverity::Remark);
1182  diag << "could not tile reduction in target.";
1184  }
1185  results.push_back(result->initialOp);
1186  results.push_back(result->parallelTiledOp);
1187  results.push_back(result->mergeOp);
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // TileOp
1193 //===----------------------------------------------------------------------===//
1194 
1196 transform::TileOp::apply(TransformResults &transformResults,
1197  TransformState &state) {
1198  ArrayRef<int64_t> tileSizes = getStaticSizes();
1199 
1200  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1201  SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
1202  dynamicSizeProducers.reserve(getDynamicSizes().size());
1203  for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
1204  dynamicSizeProducers.push_back(
1205  state.getPayloadOps(dynamicSizeProducerHandle));
1206 
1207  if (dynamicSizeProducers.back().size() != targets.size()) {
1209  emitSilenceableError()
1210  << "expected as many dynamic size-producing operations ("
1211  << dynamicSizeProducers.back().size() << ") as target ops ("
1212  << targets.size() << ")";
1213  diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1214  return diag;
1215  }
1216 
1217  for (Operation *op : dynamicSizeProducers.back()) {
1218  if (op->getNumResults() == 1 &&
1219  op->getResult(0).getType().isa<IndexType>())
1220  continue;
1222  emitSilenceableError() << "expected sizes to be produced by ops "
1223  "with a single index-type result";
1224  diag.attachNote(op->getLoc()) << "size producer op";
1225  diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1226  return diag;
1227  }
1228  }
1229 
1232  loops.resize(getLoops().size());
1233  for (auto &en : llvm::enumerate(targets)) {
1234  auto linalgOp = dyn_cast<LinalgOp>(en.value());
1235  if (!linalgOp) {
1236  DiagnosedSilenceableFailure diag = emitSilenceableError()
1237  << "only linalg ops are supported";
1238  diag.attachNote(en.value()->getLoc()) << "target op";
1239  return diag;
1240  }
1241 
1242  scf::SCFTilingOptions tilingOptions;
1243  unsigned index = en.index();
1244  if (!tileSizes.empty()) {
1245  tilingOptions.setTileSizeComputationFunction(
1246  [&, index](OpBuilder &b, Operation *) {
1247  SmallVector<Value, 4> sizes;
1248  sizes.reserve(tileSizes.size());
1249  unsigned dynamicIdx = 0;
1250  for (OpFoldResult ofr : getMixedSizes()) {
1251  if (auto attr = ofr.dyn_cast<Attribute>()) {
1252  sizes.push_back(b.create<arith::ConstantIndexOp>(
1253  getLoc(), attr.cast<IntegerAttr>().getInt()));
1254  } else {
1255  sizes.push_back(
1256  dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
1257  }
1258  }
1259  return sizes;
1260  });
1261  }
1262 
1263  tilingOptions.setInterchange(getInterchange());
1264  SimpleRewriter rewriter(linalgOp.getContext());
1266  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
1267  tilingOptions);
1268  if (failed(maybeTilingResult))
1270 
1271  if (linalgOp.hasBufferSemantics())
1272  rewriter.eraseOp(linalgOp);
1273  else
1274  rewriter.replaceOp(linalgOp,
1275  maybeTilingResult->loops.front()->getResults());
1276 
1277  tiled.append(maybeTilingResult->tiledOps);
1278  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
1279  loops[en2.index()].push_back(en2.value());
1280  }
1281 
1282  transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
1283  for (const auto &en : llvm::enumerate(loops))
1284  transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
1285 
1287 }
1288 
1290  ValueRange dynamic = getDynamicSizes();
1291  ArrayRef<int64_t> tileSizes = getStaticSizes();
1292  SmallVector<OpFoldResult> results;
1293  results.reserve(tileSizes.size());
1294  unsigned dynamicPos = 0;
1295  Builder builder(getContext());
1296  for (int64_t size : tileSizes) {
1297  if (size == ShapedType::kDynamic) {
1298  results.push_back(dynamic[dynamicPos++]);
1299  } else {
1300  results.push_back(builder.getIndexAttr(size));
1301  }
1302  }
1303  return results;
1304 }
1305 
1306 // We want to parse `DenseI64ArrayAttr` using the short form without the
1307 // `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
1309  OperationState &result) {
1310  if (succeeded(parser.parseOptionalLBrace())) {
1311  if (failed(parser.parseKeyword("interchange")))
1312  return parser.emitError(parser.getNameLoc()) << "expect `interchange`";
1313  if (failed(parser.parseEqual()))
1314  return parser.emitError(parser.getNameLoc()) << "expect `=`";
1315  result.addAttribute("interchange",
1316  DenseI64ArrayAttr::parse(parser, Type{}));
1317  if (failed(parser.parseRBrace()))
1318  return parser.emitError(parser.getNameLoc()) << "expect `}`";
1319  }
1320  return success();
1321 }
1322 
1324  ArrayRef<int64_t> interchangeVals) {
1325  if (!interchangeVals.empty()) {
1326  p << " {interchange = [";
1327  llvm::interleaveComma(interchangeVals, p,
1328  [&](int64_t integer) { p << integer; });
1329  p << "]}";
1330  }
1331 }
1332 
1333 ParseResult transform::TileOp::parse(OpAsmParser &parser,
1334  OperationState &result) {
1337  DenseI64ArrayAttr staticSizes;
1338  auto pdlOperationType = pdl::OperationType::get(parser.getContext());
1339  if (parser.parseOperand(target) ||
1340  parser.resolveOperand(target, pdlOperationType, result.operands) ||
1341  parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
1342  parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
1343  return ParseResult::failure();
1344 
1345  // Parse optional interchange.
1346  if (failed(parseOptionalInterchange(parser, result)))
1347  return ParseResult::failure();
1348  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
1349  size_t numExpectedLoops =
1350  staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
1351  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
1352  return success();
1353 }
1354 
1355 void TileOp::print(OpAsmPrinter &p) {
1356  p << ' ' << getTarget();
1357  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
1358  printOptionalInterchange(p, getInterchange());
1359 }
1360 
1361 void transform::TileOp::getEffects(
1363  consumesHandle(getTarget(), effects);
1364  onlyReadsHandle(getDynamicSizes(), effects);
1365  producesHandle(getTiledLinalgOp(), effects);
1366  producesHandle(getLoops(), effects);
1367  modifiesPayload(effects);
1368 }
1369 
1370 //===----------------------------------------------------------------------===//
1371 // TileToForeachThreadOp
1372 //===----------------------------------------------------------------------===//
1373 
1374 void transform::TileToForeachThreadOp::build(OpBuilder &builder,
1375  OperationState &result,
1376  Value target,
1377  ArrayRef<int64_t> staticTileSizes,
1379  ArrayAttr mapping) {
1380  return build(builder, result,
1381  /*target=*/target,
1382  /*mixedTileSizes=*/
1383  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
1384  /*_=*/TileSizesSpec(),
1385  /*mapping=*/mapping);
1386 }
1387 
1388 void transform::TileToForeachThreadOp::build(
1389  OpBuilder &builder, OperationState &result, Value target,
1391  ArrayAttr mapping) {
1392  SmallVector<int64_t> staticTileSizes;
1393  SmallVector<Value> dynamicTileSizes;
1394  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
1395  ShapedType::kDynamic);
1396  // Call the default builder which sets up the proper operands segment sizes
1397  // attributes for multiple variadic operands. In the absence of this, horrible
1398  // bugs ensue.
1399  MLIRContext *ctx = builder.getContext();
1400  auto operationType = pdl::OperationType::get(ctx);
1401  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
1402  build(builder, result,
1403  /*resultTypes=*/TypeRange{operationType, operationType},
1404  /*target=*/target,
1405  /*num_threads=*/ValueRange{},
1406  /*tile_sizes=*/dynamicTileSizes,
1407  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
1408  /*static_tile_sizes=*/staticTileSizesAttr,
1409  /*mapping=*/mapping);
1410 }
1411 
1412 void transform::TileToForeachThreadOp::build(OpBuilder &builder,
1413  OperationState &result,
1414  Value target,
1415  ArrayRef<int64_t> staticNumThreads,
1417  ArrayAttr mapping) {
1418  return build(builder, result, target,
1419  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
1420  NumThreadsSpec(), mapping);
1421 }
1422 
1423 void transform::TileToForeachThreadOp::build(
1424  OpBuilder &builder, OperationState &result, Value target,
1426  ArrayAttr mapping) {
1427  SmallVector<int64_t> staticNumThreads;
1428  SmallVector<Value> dynamicNumThreads;
1429  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
1430  staticNumThreads, ShapedType::kDynamic);
1431  // Call the default builder which sets up the proper operands segment sizes
1432  // attributes for multiple variadic operands. In the absence of this, horrible
1433  // bugs ensue.
1434  MLIRContext *ctx = builder.getContext();
1435  auto operationType = pdl::OperationType::get(ctx);
1436  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
1437  build(builder, result,
1438  /*resultTypes=*/TypeRange{operationType, operationType},
1439  /*target=*/target,
1440  /*num_threads=*/dynamicNumThreads,
1441  /*tile_sizes=*/ValueRange{},
1442  /*static_num_threads=*/staticNumThreadsAttr,
1443  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
1444  /*mapping=*/mapping);
1445 }
1446 
1447 // Given a list of OpFoldResults that are either index attrs or op
1448 // handles, return a list of OpFoldResults where all op handles are
1449 // replaced with the first (and only) OpResult of that payload op. (There
1450 // must be exactly one mapped payload op and it must have exactly one
1451 // index result.)
1453  transform::TransformState &state, TransformOpInterface transformOp,
1455  for (OpFoldResult ofr : ofrs) {
1456  // Don't try to unpack non-PDL operation.
1457  if (ofr.is<Attribute>() ||
1458  !ofr.get<Value>().getType().isa<pdl::OperationType>()) {
1459  result.push_back(ofr);
1460  continue;
1461  }
1462  ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
1463  for (Operation *op : payloadOps) {
1464  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
1466  transformOp.emitSilenceableError()
1467  << "payload op must have exactly 1 index result";
1468  diag.attachNote(op->getLoc())
1469  << "has " << op->getNumResults() << " results";
1470  return diag;
1471  }
1472  result.push_back(op->getResult(0));
1473  }
1474  }
1475 
1477 }
1478 
1480  RewriterBase &rewriter, transform::TransformState &state,
1481  TransformOpInterface transformOp, ArrayRef<Operation *> targets,
1482  ArrayRef<OpFoldResult> mixedNumThreads,
1483  ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> mapping,
1484  SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
1485  if (targets.empty())
1487 
1488  // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
1489  // Convert to OpFoldResults[index attributes or payload op].
1490  SmallVector<OpFoldResult> numThreads;
1492  unpackPDLOperations(state, transformOp, numThreads, mixedNumThreads);
1493  if (!status.succeeded())
1494  return status;
1495 
1496  // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
1497  // Convert to OpFoldResults[index attributes or payload op].
1498  SmallVector<OpFoldResult> tileSizes;
1499  status = unpackPDLOperations(state, transformOp, tileSizes, mixedTileSizes);
1500  if (!status.succeeded())
1501  return status;
1502 
1503  // Transform all targets one by one.
1504  for (Operation *target : targets) {
1505  auto tilableOp = dyn_cast<TilingInterface>(target);
1506  if (!tilableOp) {
1508  transformOp.emitSilenceableError()
1509  << "only TilingInterface ops are supported";
1510  diag.attachNote(target->getLoc()) << "target op";
1511  return diag;
1512  }
1513  rewriter.setInsertionPoint(tilableOp);
1515  if (!mixedNumThreads.empty()) {
1516  tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
1517  numThreads, mapping);
1518  } else {
1520  rewriter, tilableOp, tileSizes, mapping);
1521  }
1522 
1523  if (failed(tilingResult))
1524  return transformOp.emitDefaultSilenceableFailure(tilableOp);
1525  rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
1526 
1527  tileOps.push_back(tilingResult->tileOp);
1528  tiledOps.push_back(tilingResult->tiledOp);
1529  }
1531 }
1532 
1533 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
1534  transform::TransformResults &transformResults,
1535  transform::TransformState &state) {
1536  IRRewriter rewriter(getContext());
1537  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1538 
1539  // Result payload ops.
1540  SmallVector<Operation *> tileOps;
1541  SmallVector<Operation *> tiledOps;
1542 
1544  rewriter, state, cast<TransformOpInterface>(getOperation()), targets,
1545  getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
1546  tiledOps);
1547 
1548  if (!diag.succeeded()) {
1549  transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
1550  transformResults.set(getTiledOp().cast<OpResult>(), {});
1551  return diag;
1552  }
1553 
1554  transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
1555  transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
1556 
1558 }
1559 
1560 void transform::TileToForeachThreadOp::getEffects(
1562  consumesHandle(getTarget(), effects);
1563  onlyReadsHandle(getTileSizes(), effects);
1564  onlyReadsHandle(getNumThreads(), effects);
1565  producesHandle(getResults(), effects);
1566 }
1567 
1568 SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
1569  Builder b(getContext());
1570  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
1571 }
1572 
1573 SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
1574  Builder b(getContext());
1575  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
1576 }
1577 
1579  if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
1580  return emitOpError("either num_threads or tile_sizes must be specified");
1581  return success();
1582 }
1583 
1584 //===----------------------------------------------------------------------===//
1585 // TileToScfForOp
1586 //===----------------------------------------------------------------------===//
1587 
1589 transform::TileToScfForOp::apply(TransformResults &transformResults,
1590  TransformState &state) {
1591  ArrayRef<int64_t> tileSizes = getStaticSizes();
1592 
1593  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1594  SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
1595  dynamicSizeProducers.reserve(getDynamicSizes().size());
1596  for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
1597  dynamicSizeProducers.push_back(
1598  state.getPayloadOps(dynamicSizeProducerHandle));
1599 
1600  if (dynamicSizeProducers.back().size() != targets.size()) {
1602  emitSilenceableError()
1603  << "expected as many dynamic size-producing operations ("
1604  << dynamicSizeProducers.back().size() << ") as target ops ("
1605  << targets.size() << ")";
1606  diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1607  return diag;
1608  }
1609 
1610  for (Operation *op : dynamicSizeProducers.back()) {
1611  if (op->getNumResults() == 1 &&
1612  op->getResult(0).getType().isa<IndexType>())
1613  continue;
1615  emitSilenceableError() << "expected sizes to be produced by ops "
1616  "with a single index-type result";
1617  diag.attachNote(op->getLoc()) << "size producer op";
1618  diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1619  return diag;
1620  }
1621  }
1622 
1625  loops.resize(getLoops().size());
1626  for (auto &en : llvm::enumerate(targets)) {
1627  auto tilingInterfaceOp = dyn_cast<TilingInterface>(en.value());
1628  if (!tilingInterfaceOp) {
1630  emitSilenceableError() << "only TilingInterface ops are supported";
1631  diag.attachNote(en.value()->getLoc()) << "target op";
1632  return diag;
1633  }
1634 
1635  scf::SCFTilingOptions tilingOptions;
1636  unsigned index = en.index();
1637  if (!tileSizes.empty()) {
1638  tilingOptions.setTileSizeComputationFunction(
1639  [&, index](OpBuilder &b, Operation *) {
1640  SmallVector<Value, 4> sizes;
1641  sizes.reserve(tileSizes.size());
1642  unsigned dynamicIdx = 0;
1643  for (OpFoldResult ofr : getMixedSizes()) {
1644  if (auto attr = ofr.dyn_cast<Attribute>()) {
1645  sizes.push_back(b.create<arith::ConstantIndexOp>(
1646  getLoc(), attr.cast<IntegerAttr>().getInt()));
1647  } else {
1648  sizes.push_back(
1649  dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
1650  }
1651  }
1652  return sizes;
1653  });
1654  }
1655 
1656  tilingOptions.setInterchange(getInterchange());
1657  SimpleRewriter rewriter(tilingInterfaceOp.getContext());
1658  FailureOr<scf::SCFTilingResult> tilingResult =
1659  tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
1660  if (failed(tilingResult))
1662 
1663  rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements);
1664 
1665  tiled.append(tilingResult->tiledOps);
1666  for (const auto &en2 : llvm::enumerate(tilingResult->loops))
1667  loops[en2.index()].push_back(en2.value());
1668  }
1669 
1670  transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
1671  for (const auto &en : llvm::enumerate(loops))
1672  transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
1673 
1675 }
1676 
1678  ValueRange dynamic = getDynamicSizes();
1679  ArrayRef<int64_t> tileSizes = getStaticSizes();
1680  SmallVector<OpFoldResult> results;
1681  results.reserve(tileSizes.size());
1682  unsigned dynamicPos = 0;
1683  Builder builder(getContext());
1684  for (int64_t size : tileSizes) {
1685  if (size == ShapedType::kDynamic) {
1686  results.push_back(dynamic[dynamicPos++]);
1687  } else {
1688  results.push_back(builder.getIndexAttr(size));
1689  }
1690  }
1691  return results;
1692 }
1693 
1694 ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
1695  OperationState &result) {
1698  DenseI64ArrayAttr staticSizes;
1699  auto pdlOperationType = pdl::OperationType::get(parser.getContext());
1700  if (parser.parseOperand(target) ||
1701  parser.resolveOperand(target, pdlOperationType, result.operands) ||
1702  parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
1703  parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
1704  return ParseResult::failure();
1705 
1706  // Parse optional interchange.
1707  if (failed(parseOptionalInterchange(parser, result)))
1708  return ParseResult::failure();
1709  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
1710  size_t numExpectedLoops =
1711  staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
1712  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
1713  return success();
1714 }
1715 
1717  p << ' ' << getTarget();
1718  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
1719  printOptionalInterchange(p, getInterchange());
1720 }
1721 
1722 void transform::TileToScfForOp::getEffects(
1724  consumesHandle(getTarget(), effects);
1725  onlyReadsHandle(getDynamicSizes(), effects);
1726  producesHandle(getTiledLinalgOp(), effects);
1727  producesHandle(getLoops(), effects);
1728  modifiesPayload(effects);
1729 }
1730 
1731 //===----------------------------------------------------------------------===//
1732 // VectorizeOp
1733 //===----------------------------------------------------------------------===//
1734 
1735 void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
1736  Value target, bool vectorizePadding) {
1737  result.addOperands(target);
1738  if (vectorizePadding) {
1739  result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
1740  builder.getUnitAttr());
1741  }
1742  result.addTypes(pdl::OperationType::get(builder.getContext()));
1743 }
1744 
1745 namespace {
1746 /// This is an helper only to call vectorize via a pattern inside of
1747 /// VectorizeOp::applyToOne.
1748 struct VectorizationPattern : public RewritePattern {
1749  explicit VectorizationPattern(MLIRContext *context)
1750  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1751  LogicalResult matchAndRewrite(Operation *op,
1752  PatternRewriter &rewriter) const override {
1753  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
1754  if (!linalgOp)
1755  return failure();
1756  return vectorize(rewriter, linalgOp);
1757  }
1758 };
1759 } // namespace
1760 
1762 transform::VectorizeOp::applyToOne(Operation *target,
1764  transform::TransformState &state) {
1765  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1766  auto diag = this->emitOpError("requires isolated-from-above targets");
1767  diag.attachNote(target->getLoc()) << "non-isolated target";
1769  }
1770 
1771  MLIRContext *ctx = getContext();
1772  RewritePatternSet patterns(ctx);
1773  patterns.add<VectorizationPattern>(ctx);
1774 
1775  if (!getDisableTransferPermutationMapLoweringPatterns())
1777 
1778  if (!getDisableMultiReductionToContractPatterns())
1780 
1783  /*benefit=*/2);
1784  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
1785  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
1786 
1787  patterns.add<CopyVectorizationPattern>(ctx);
1788 
1789  if (getVectorizePadding())
1791 
1792  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
1793  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1794 
1795  results.push_back(target);
1797 }
1798 
1799 //===----------------------------------------------------------------------===//
1800 // Transform op registration
1801 //===----------------------------------------------------------------------===//
1802 
1803 namespace {
1804 /// Registers new ops and declares PDL as dependent dialect since the
1805 /// additional ops are using PDL types for operands and results.
1806 class LinalgTransformDialectExtension
1808  LinalgTransformDialectExtension> {
1809 public:
1810  using Base::Base;
1811 
1812  void init() {
1813  declareDependentDialect<pdl::PDLDialect>();
1814  declareDependentDialect<LinalgDialect>();
1815  declareGeneratedDialect<AffineDialect>();
1816  declareGeneratedDialect<arith::ArithDialect>();
1817  declareGeneratedDialect<scf::SCFDialect>();
1818  declareGeneratedDialect<vector::VectorDialect>();
1819  declareGeneratedDialect<gpu::GPUDialect>();
1820 
1821  registerTransformOps<
1822 #define GET_OP_LIST
1823 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1824  >();
1825  }
1826 };
1827 } // namespace
1828 
1829 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
1830 
1831 #define GET_OP_CLASSES
1832 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1833 
1835  DialectRegistry &registry) {
1836  registry.addExtensions<LinalgTransformDialectExtension>();
1837 }
static std::string diag(llvm::Value &value)
static constexpr const bool value
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, StringRef sizesAttrName)
Parse a tiling-like operation that returns the tiled op as well as the created tile loops.
void printOptionalInterchange(OpAsmPrinter &p, ArrayRef< int64_t > interchangeVals)
static DiagnosedSilenceableFailure unpackPDLOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
static Operation * tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
ParseResult parseOptionalInterchange(OpAsmParser &parser, OperationState &result)
static LogicalResult applyTilingToAll(Operation *transformOp, ArrayRef< Operation * > payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static Operation * tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForeachThreadOp" user of producerOp and ensure it is exactly the containi...
static SmallVector< unsigned > extractUIntArray(ArrayAttr attr)
Extracts a vector of unsigned from an array attribute.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, const SmallVectorImpl< Value > &sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static FailureOr< ForeachThreadTilingResult > tileToForeachThreadOpImpl(RewriterBase &b, TilingInterface op, ArrayRef< OpFoldResult > numThreads, Optional< ArrayRef< OpFoldResult >> nominalTileSizes, Optional< ArrayAttr > mapping, bool omitTileOffsetBoundsCheck)
Rewrite a TilingInterface op to a tiled scf.foreach_thread.
Definition: Tiling.cpp:296
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:42
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
void map(Block *from, Block *to)
Inserts a new mapping for 'from' to 'to'.
This class represents an argument of a Block.
Definition: Value.h:296
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:49
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:109
UnitAttr getUnitAttr()
Definition: Builders.cpp:99
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:157
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:331
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:113
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:88
MLIRContext * getContext() const
Definition: Builders.h:54
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:262
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:288
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:589
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:150
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:510
This class represents a single result from folding an operation.
Definition: OpDefinition.h:233
This class represents an operand of an operation.
Definition: Value.h:247
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
This is a value defined by a result of an operation.
Definition: Value.h:442
This class provides the API for ops that are known to be isolated from above.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
OpResult getOpResult(unsigned idx)
Definition: Operation.h:338
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
void setOperand(unsigned idx, Value value)
Definition: Operation.h:268
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:650
result_range getOpResults()
Definition: Operation.h:337
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:176
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:605
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:244
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:30
bool isa() const
Definition: Types.h:260
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
U cast() const
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:209
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, ArrayRef< Operation * > ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
The state maintained across applications of various ops implementing the TransformOpInterface.
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
void registerTransformDialectExtension(DialectRegistry &registry)
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
FailureOr< ForeachThreadTilingResult > tileToForeachThreadOp(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > numThreads, Optional< ArrayAttr > mapping)
Definition: Tiling.cpp:382
FailureOr< SplitReductionResult > splitReduction(PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:368
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:51
FailureOr< SplitReductionResult > splitReductionByScaling(PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< ForeachThreadReductionTilingResult > tileReductionUsingForeachThread(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, Optional< ArrayAttr > mapping)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition: Tiling.cpp:414
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:390
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:961
LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp)
Emit a suitable vector form for a Linalg op with fully static shape.
FailureOr< ForeachThreadTilingResult > tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, Optional< ArrayAttr > mapping)
Same as tileToForeachThreadOp, but calculate the number of threads required using the given tileSizes...
Definition: Tiling.cpp:391
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:117
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(PatternRewriter &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:46
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:100
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
DiagnosedSilenceableFailure tileToForeachThreadOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef< Operation * > targets, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, Optional< ArrayAttr > mapping, SmallVector< Operation * > &tileOps, SmallVector< Operation * > &tiledOps)
Implementation of tiling operations using scf.foreach_thread.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SmallVector< int64_t, 4 > extractFromI64ArrayAttr(Attribute attr)
Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers)
Printer hook for custom directive in assemblyFormat.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec, int64_t sentinel)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ValueRange operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:964
Optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above,...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1037
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers)
Parser hook for custom directive in assemblyFormat.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:372
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
This parses a single MLIR attribute to an MLIR context if it was valid.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static LogicalResult failure(bool isFailure=true)
If isFailure is true a failure result is generated, otherwise a 'success' result is generated.
Definition: LogicalResult.h:36
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:776
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
Definition: Transforms.h:715
Match and rewrite for the pattern:
Definition: Transforms.h:884
Match and rewrite for the pattern:
Definition: Transforms.h:912
LinalgPaddingOptions & setPaddingDimensions(ArrayRef< int64_t > pd)
Definition: Transforms.h:563
LinalgPaddingOptions & setTransposePaddings(ArrayRef< SmallVector< int64_t >> tp)
Definition: Transforms.h:584
LinalgPaddingOptions & setPaddingValues(ArrayRef< Attribute > pv)
Definition: Transforms.h:557
LinalgPaddingOptions & setPackPaddings(ArrayRef< bool > pp)
Definition: Transforms.h:570
LinalgPaddingOptions & setHoistPaddings(ArrayRef< int64_t > hp)
Definition: Transforms.h:576
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:274
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:280
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:286
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:263
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:252
Split Reduction options.
Definition: Transforms.h:946
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
SCFTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.