1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Support/LLVM.h"
41 #include "mlir/Support/TypeID.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Debug.h"
47 #include <type_traits>
48 
49 using namespace mlir;
50 using namespace mlir::linalg;
51 using namespace mlir::transform;
52 
53 #define DEBUG_TYPE "linalg-transforms"
54 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 #define DBGSNL() (llvm::dbgs() << "\n")
56 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
57 
58 /// Attempts to apply the pattern specified as template argument to the given
59 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
60 /// function that returns the "main" result or failure. Returns failure if the
61 /// pattern failed to apply. Extra arguments are forwarded to the pattern
62 /// constructor.
63 template <typename PatternTy, typename... Args>
64 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
65  // Check if the given operation has the type expected by the pattern.
66  using OpTy = typename llvm::function_traits<
67  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
68  auto op = dyn_cast<OpTy>(operation);
69  if (!op)
70  return failure();
71 
72  // Apply the pattern directly to the op.
73  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
74  // We want to discourage direct use of PatternRewriter in APIs, but in this
75  // very specific case, an IRRewriter is not enough.
76  struct TrivialPatternRewriter : public PatternRewriter {
77  public:
78  explicit TrivialPatternRewriter(MLIRContext *context)
79  : PatternRewriter(context) {}
80  };
81  TrivialPatternRewriter rewriter(operation->getContext());
82  rewriter.setInsertionPoint(operation);
83  auto result = pattern.returningMatchAndRewrite(op, rewriter);
84  if (failed(result))
85  return failure();
86  return cast<LinalgOp>(result->getOperation());
87 }
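// Example (illustrative sketch, not part of the upstream file): invoking a
// pattern that exposes `returningMatchAndRewrite` through `tryApply`. The
// variable `convOp` is hypothetical; the pattern is the one used by
// DecomposeOp further below.
//
//   Operation *convOp = ...; // possibly a linalg.conv_2d_nhwc_hwcf
//   FailureOr<LinalgOp> decomposed =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(convOp);
//   if (failed(decomposed)) {
//     // Wrong op type or the pattern did not apply.
//   }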
88 
89 /// Assuming that each entry of `ofrs` is an index attr or a transform dialect
90 /// handle mapped to exactly one op with one index result, append that value.
91 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
92  transform::TransformState &state, TransformOpInterface transformOp,
93  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
94  for (OpFoldResult ofr : ofrs) {
95  if (ofr.is<Attribute>()) {
96  if (!isa<IntegerAttr>(ofr.get<Attribute>()))
97  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
98  result.push_back(ofr);
99  continue;
100  }
101  auto payloadOps = state.getPayloadOps(ofr.get<Value>());
102  if (!llvm::hasSingleElement(payloadOps)) {
103  DiagnosedSilenceableFailure diag =
104  transformOp.emitSilenceableError()
105  << "handle must be mapped to exactly one payload op";
106  diag.attachNote(ofr.get<Value>().getLoc())
107  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
108  return diag;
109  }
110 
111  Operation *op = *payloadOps.begin();
112  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
113  DiagnosedSilenceableFailure diag =
114  transformOp.emitSilenceableError()
115  << "payload op must have exactly 1 index result";
116  diag.attachNote(op->getLoc())
117  << "has " << op->getNumResults() << " results";
118  return diag;
119  }
120  result.push_back(op->getResult(0));
121  }
122 
123  return DiagnosedSilenceableFailure::success();
124 }
125 
126 // Given a handle mapped to a list of payload ops, return a list of
127 // OpFoldResults where every payload op is replaced with its first (and only)
128 // OpResult. (Each mapped payload op must have exactly one index result; this
129 // mirrors the element-wise variant above but takes a single packed handle
130 // instead of a list of OpFoldResults.)
131 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
132  transform::TransformState &state, TransformOpInterface transformOp,
133  SmallVector<OpFoldResult> &result, Value packedHandle) {
134  for (Operation *op : state.getPayloadOps(packedHandle)) {
135  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
136  DiagnosedSilenceableFailure diag =
137  transformOp.emitSilenceableError()
138  << "payload op must have exactly 1 index result";
139  diag.attachNote(op->getLoc())
140  << "has " << op->getNumResults() << " results";
141  return diag;
142  }
143  result.push_back(op->getResult(0));
144  }
145 
146  return DiagnosedSilenceableFailure::success();
147 }
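// Example (illustrative sketch, not part of the upstream file): how a
// transform op's `apply` can unpack mixed attribute/handle sizes with the
// helper above, mirroring PackOp::apply further below. `sizes` is a
// hypothetical local; `getMixedPackedSizes()` is that op's accessor.
//
//   SmallVector<OpFoldResult> sizes;
//   DiagnosedSilenceableFailure status =
//       unpackSingleIndexResultPayloadOperations(state, *this, sizes,
//                                                getMixedPackedSizes());
//   if (!status.succeeded())
//     return status;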
148 
149 //===----------------------------------------------------------------------===//
150 // Apply...PatternsOp
151 //===----------------------------------------------------------------------===//
152 
153 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
154  RewritePatternSet &patterns) {
155  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
156 }
157 
158 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
159  RewritePatternSet &patterns) {
160  linalg::ControlDropUnitDims options;
161  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
162 }
163 
164 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
165  RewritePatternSet &patterns) {
166  linalg::ControlDropUnitDims options;
167  options.rankReductionStrategy =
168  linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
169  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
170 }
171 
172 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
173  RewritePatternSet &patterns) {
174  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // BufferizeToAllocationOp
179 //===----------------------------------------------------------------------===//
180 
181 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
182  OperationState &result,
183  Value target,
184  Attribute memorySpace) {
185  SmallVector<Type> resultTypes;
186  resultTypes.push_back(b.getType<transform::AnyValueType>());
187  resultTypes.push_back(b.getType<transform::AnyOpType>());
188  return build(b, result,
189  /*resultTypes=*/resultTypes,
190  /*target=*/target,
191  /*memorySpace=*/memorySpace);
192 }
193 
194 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
195  OperationState &result,
196  Value target,
197  int64_t memorySpace) {
198  SmallVector<Type> resultTypes;
199  resultTypes.push_back(b.getType<transform::AnyValueType>());
200  resultTypes.push_back(b.getType<transform::AnyOpType>());
201  return build(b, result,
202  /*resultTypes=*/resultTypes,
203  /*target=*/target,
204  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
205 }
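// Example (illustrative sketch, not part of the upstream file): creating the
// op programmatically with one of the builders above. `b`, `loc` and
// `targetHandle` are hypothetical; the int64_t overload wraps the memory
// space into an I64IntegerAttr as shown above.
//
//   auto toAlloc = b.create<transform::BufferizeToAllocationOp>(
//       loc, targetHandle, /*memorySpace=*/3);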
206 
207 namespace {
208 class NewOpsListener : public RewriterBase::ForwardingListener {
209 public:
210  using RewriterBase::ForwardingListener::ForwardingListener;
211 
212  SmallVector<Operation *> getNewOps() const {
213  return SmallVector<Operation *>(newOps.begin(), newOps.end());
214  }
215 
216 private:
217  void notifyOperationInserted(Operation *op) override {
218  ForwardingListener::notifyOperationInserted(op);
219  auto inserted = newOps.insert(op);
220  (void)inserted;
221  assert(inserted.second && "expected newly created op");
222  }
223 
224  void notifyOperationRemoved(Operation *op) override {
225  ForwardingListener::notifyOperationRemoved(op);
226  op->walk([&](Operation *op) { newOps.erase(op); });
227  }
228 
229  DenseSet<Operation *> newOps;
230 };
231 } // namespace
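// Example (illustrative sketch, not part of the upstream file): tracking ops
// created during a rewrite with NewOpsListener, mirroring
// BufferizeToAllocationOp::apply below. `rewriter` is a hypothetical
// RewriterBase whose previous listener may be null.
//
//   OpBuilder::Listener *previous = rewriter.getListener();
//   NewOpsListener listener(previous);
//   rewriter.setListener(&listener);
//   // ... perform rewrites ...
//   SmallVector<Operation *> created = listener.getNewOps();
//   rewriter.setListener(previous);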
232 
233 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
234  transform::TransformRewriter &rewriter,
235  transform::TransformResults &results, transform::TransformState &state) {
236  // Attach listener to keep track of newly created ops.
237  OpBuilder::Listener *previousListener = rewriter.getListener();
238  auto resetListener =
239  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
240  NewOpsListener newOpsListener(previousListener);
241  rewriter.setListener(&newOpsListener);
242 
243  linalg::BufferizeToAllocationOptions options;
244  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
245  options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
246  MaterializeInDestination;
247  } else if (getMemcpyOp() == "memref.copy") {
248  options.memcpyOp =
249  linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
250  } else if (getMemcpyOp() == "linalg.copy") {
251  options.memcpyOp =
252  linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
253  } else {
254  llvm_unreachable("invalid memcpy op");
255  }
256  if (getAllocOp() == "memref.alloc") {
257  options.allocOp =
258  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
259  } else if (getAllocOp() == "memref.alloca") {
260  options.allocOp =
261  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
262  } else {
263  llvm_unreachable("invalid alloc op");
264  }
265  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
266  options.emitDealloc = getEmitDealloc();
267 
268  // Bufferize ops.
269  Attribute memorySpace =
270  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
271  SmallVector<Value> allocatedBuffers;
272  for (Operation *op : state.getPayloadOps(getTarget())) {
273  Value buffer =
274  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
275  if (!buffer) {
276  DiagnosedSilenceableFailure diag = emitSilenceableError()
277  << "failed to bufferize operation";
278  diag.attachNote(op->getLoc()) << "target payload op";
279  return diag;
280  }
281  allocatedBuffers.push_back(buffer);
282  }
283 
284  // Set results.
285  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
286  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
287  return DiagnosedSilenceableFailure::success();
288 }
289 
290 void transform::BufferizeToAllocationOp::getEffects(
291  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
292  if (getBufferizeDestinationOnly()) {
293  // The destination is replaced with a newly allocated buffer, but the op
294  // itself remains in place.
295  onlyReadsHandle(getTarget(), effects);
296  } else {
297  consumesHandle(getTarget(), effects);
298  }
299  producesHandle(getAllocatedBuffer(), effects);
300  producesHandle(getNewOps(), effects);
301  modifiesPayload(effects);
302 }
303 
304 LogicalResult transform::BufferizeToAllocationOp::verify() {
305  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
306  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
307  return emitOpError() << "unsupported memcpy op";
308  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
309  return emitOpError() << "unsupported alloc op";
310  return success();
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // DecomposeOp
315 //===----------------------------------------------------------------------===//
316 
317 DiagnosedSilenceableFailure
318 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
319  LinalgOp target,
320  transform::ApplyToEachResultList &results,
321  transform::TransformState &state) {
322 #define DOWNSCALE(trans) \
323  { \
324  FailureOr<LinalgOp> res = tryApply<trans>(target); \
325  if (succeeded(res)) { \
326  results.push_back(*res); \
327  return DiagnosedSilenceableFailure::success(); \
328  } \
329  }
330 
331 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
332 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
333 
334  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
335  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
336  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
337  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
338  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
339  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
340  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
341  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
342  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
343  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
344  DOWNSCALE(DownscaleConv2DOp)
345 #undef DOWNSCALE_NORMAL
346 #undef DOWNSCALE_CALL
347 #undef DOWNSCALE
348  return emitDefaultSilenceableFailure(target);
349 }
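// Example (illustrative sketch, not part of the upstream file): what a single
// DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp) line above expands to
// after macro substitution.
//
//   {
//     FailureOr<LinalgOp> res =
//         tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                        Conv1DNwcWcfOp>>(target);
//     if (succeeded(res)) {
//       results.push_back(*res);
//       return DiagnosedSilenceableFailure::success();
//     }
//   }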
350 
351 //===----------------------------------------------------------------------===//
352 // DecomposeInterfaceOp
353 //===----------------------------------------------------------------------===//
354 
355 // Decompose the target operation if it implements the AggregatedOpInterface.
356 // Push the decomposed operations (the ones that replace the values produced
357 // by \p target) into `results`.
358 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
359  transform::TransformRewriter &rewriter, Operation *target,
360  transform::ApplyToEachResultList &results,
361  transform::TransformState &state) {
362  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
363  if (!decomposableOp) {
364  failed(rewriter.notifyMatchFailure(target,
365  "payload is not a decomposable op"));
366  return emitDefaultSilenceableFailure(target);
367  }
368 
369  FailureOr<SmallVector<Value>> maybeNewResults =
370  decomposableOp.decomposeOperation(rewriter);
371  if (failed(maybeNewResults))
372  return emitDefaultSilenceableFailure(target);
373 
374  rewriter.replaceOp(decomposableOp, *maybeNewResults);
375  for (Value val : *maybeNewResults) {
376  Operation *definition = val.getDefiningOp();
377  if (definition)
378  results.push_back(definition);
379  }
380  return DiagnosedSilenceableFailure::success();
381 }
382 
383 //===----------------------------------------------------------------------===//
384 // EliminateLinalgOpAnchoredEmptyTensorsOp
385 //===----------------------------------------------------------------------===//
386 
387 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
388  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
389  onlyReadsHandle(getTarget(), effects);
390  modifiesPayload(effects);
391 }
392 
393 DiagnosedSilenceableFailure
394 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
395  transform::TransformRewriter &rewriter, TransformResults &transformResults,
396  TransformState &state) {
397  bufferization::OneShotBufferizationOptions options;
398  options.allowReturnAllocsFromLoops = true;
399 
400  for (Operation *target : state.getPayloadOps(getTarget())) {
401  bufferization::OneShotAnalysisState state(target, options);
402  if (failed(analyzeOp(target, state)))
403  return mlir::emitSilenceableFailure(target->getLoc())
404  << "failed to analyze op";
405  if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
406  rewriter, target, state)))
407  return mlir::emitSilenceableFailure(target->getLoc())
408  << "failed to eliminate LinalgOp anchored tensor.empty ops";
409  }
410  return DiagnosedSilenceableFailure::success();
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // FuseOp
415 //===----------------------------------------------------------------------===//
416 
417 /// Apply a tiling transformation to all payload ops and store both the
418 /// tiled operation as well as the created tile loops.
419 template <typename Range>
420 static LogicalResult applyTilingToAll(
421  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
422  unsigned numLoops, transform::TransformResults &transformResults,
423  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
424  applyFn) {
425  SmallVector<Operation *> tiledLinalgOps;
426  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
427 
428  for (Operation *target : payloadOps) {
429  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
430  if (!tilingInterfaceOp)
431  return transformOp->emitError("only TilingInterface ops are supported");
432 
433  rewriter.setInsertionPoint(target);
434  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
435  applyFn(tilingInterfaceOp);
436  if (failed(tiledResults))
437  return failure();
438 
439  // Perform the replacement of tiled and fused values.
440  SmallVector<Operation *> opsToReplace{target};
441  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
442  for (Operation *toReplace : opsToReplace) {
443  for (OpResult res : toReplace->getResults())
444  if (auto replacement = tiledResults->replacements.lookup(res))
445  rewriter.replaceAllUsesWith(res, replacement);
446  if (toReplace->use_empty()) {
447  rewriter.eraseOp(toReplace);
448  }
449  }
450 
451  // Report back the relevant handles to the transform op.
452  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
453  assert(tiledResults->loops.size() == numLoops &&
454  "Mismatched number of loops, tile and fuse transform should have "
455  "failed");
456  for (unsigned int i = 0; i < numLoops; ++i)
457  loopOps[i].push_back(tiledResults->loops[i]);
458  }
459 
460  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
461  for (unsigned int i = 0; i < numLoops; ++i)
462  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
463 
464  return success();
465 }
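// Example (illustrative sketch, not part of the upstream file): the shape of
// the `applyFn` callback expected by the helper above, as used by
// FuseOp::apply below. `tileAndFuseOptions` and `numLoops` are hypothetical
// locals of the calling transform op.
//
//   LogicalResult status = applyTilingToAll(
//       rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
//       transformResults,
//       [&](TilingInterface op) -> FailureOr<scf::SCFTileAndFuseResult> {
//         return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
//             rewriter, op, tileAndFuseOptions);
//       });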
466 
467 DiagnosedSilenceableFailure
468 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
469  mlir::transform::TransformResults &transformResults,
470  mlir::transform::TransformState &state) {
471  SmallVector<int64_t> tileSizes =
472  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
473  SmallVector<int64_t> tileInterchange =
474  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
475 
476  scf::SCFTilingOptions tilingOptions;
477  tilingOptions.interchangeVector = tileInterchange;
478  SmallVector<OpFoldResult> tileSizesOfr =
479  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
480  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
481  scf::SCFTileAndFuseOptions tileAndFuseOptions;
482  tileAndFuseOptions.tilingOptions = tilingOptions;
483  LogicalResult result = applyTilingToAll(
484  rewriter, getOperation(), state.getPayloadOps(getTarget()),
485  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
486  [&](TilingInterface tilingInterfaceOp)
487  -> FailureOr<scf::SCFTileAndFuseResult> {
488  return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
489  rewriter, tilingInterfaceOp, tileAndFuseOptions);
490  });
491  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
492  : DiagnosedSilenceableFailure::success();
493 }
494 
495 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
496  OperationState &result) {
497  OpAsmParser::UnresolvedOperand targetOperand;
498  if (parser.parseOperand(targetOperand) ||
499  parser.parseOptionalAttrDict(result.attributes))
500  return failure();
501 
502  FunctionType trailingType;
503  SMLoc typeLoc;
504  if (parser.getCurrentLocation(&typeLoc) ||
505  parser.parseColonType(trailingType)) {
506  return failure();
507  }
508  if (trailingType.getNumInputs() != 1)
509  return parser.emitError(typeLoc) << "expected one input type";
510 
511  result.addTypes(trailingType.getResults());
512  if (parser.resolveOperand(targetOperand, trailingType.getInput(0),
513  result.operands))
514  return failure();
515  return success();
516 }
517 
518 void transform::FuseOp::print(OpAsmPrinter &p) {
519  p << ' ';
520  p << getTarget();
521  p.printOptionalAttrDict((*this)->getAttrs());
522  p << " : ";
523  p.printFunctionalType(TypeRange(getOperand().getType()),
524  getResults().getTypes());
525 }
526 
527 LogicalResult transform::FuseOp::verify() {
528  SmallVector<int64_t> permutation =
529  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
530  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
531  if (!std::is_permutation(sequence.begin(), sequence.end(),
532  permutation.begin(), permutation.end())) {
533  return emitOpError() << "expects interchange to be a permutation, found "
534  << getTileInterchange();
535  }
536 
537  SmallVector<int64_t> sizes =
538  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
539  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
540  if (numExpectedLoops != getNumResults() - 1)
541  return emitOpError() << "expects " << numExpectedLoops << " loop results";
542 
543  return success();
544 }
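// Illustrative sketch (hypothetical attribute values, not part of the
// upstream file): with tile_sizes = [4, 0, 8] and tile_interchange =
// [0, 2, 1], exactly one tile size is zero, so numExpectedLoops == 2 and the
// op must declare three results: the fused-op handle plus two loop handles.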
545 
546 //===----------------------------------------------------------------------===//
547 // FuseIntoContainingOp
548 //===----------------------------------------------------------------------===//
549 
550 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
551  OperationState &result,
552  Value producerOp,
553  Value containingOp) {
554  result.addOperands({producerOp, containingOp});
555  auto resultType = transform::AnyOpType::get(builder.getContext());
556  result.addTypes({resultType, resultType});
557 }
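// Example (illustrative sketch, not part of the upstream file): creating the
// op with the builder above. `builder`, `loc`, `producerHandle` and
// `forallHandle` are hypothetical; the two results are the fused op handle
// and the (possibly updated) containing op handle.
//
//   auto fuseInto = builder.create<transform::FuseIntoContainingOp>(
//       loc, producerHandle, forallHandle);
//   Value fusedOp = fuseInto.getFusedOp();
//   Value newContaining = fuseInto.getNewContainingOp();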
558 
559 /// Add new operands to the forall op for users of the producerOp
560 /// that are dominated by the containing scf.forall op.
561 static Operation *replaceForAllWithNewSignature(
562  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
563  Operation *containingOp, TilingResult &tileAndFuseResult,
564  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
565  SmallVector<OpFoldResult> &sizes) {
566 
567  // Count number of users not including the containing op
568  SetVector<Operation *> dominatedUsers;
569  DominanceInfo domInfo(containingOp);
570  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
571  if (!containingOp->isAncestor(user) &&
572  (domInfo.dominates(containingOp, user))) {
573  dominatedUsers.insert(user);
574  }
575  }
576  if (dominatedUsers.empty())
577  return nullptr;
578 
579  // Create new scf.forall op
580  auto forallOp = cast<scf::ForallOp>(containingOp);
581  OpBuilder::InsertionGuard g(rewriter);
582  rewriter.setInsertionPoint(forallOp);
583 
584  // Get new output
585  Location loc = forallOp.getLoc();
586  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
587  if (!genericOp)
588  return nullptr;
589  SmallVector<Value> outputs = genericOp.getOutputs();
590  SmallVector<Value> newOuts(forallOp.getOutputs());
591  newOuts.push_back(outputs[resultNumber]);
592 
593  // Create new scf.forall op
594  auto newforallOp = rewriter.create<scf::ForallOp>(
595  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
596  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
597  rewriter.eraseBlock(newforallOp.getBody());
598  newforallOp.getRegion().takeBody(forallOp.getRegion());
599 
600  // Add additional block argument for new value being returned
601  // and replaces all uses of the new output with corresponding bbArg
602  // inside the scf.forall to enable fusion into this new scf.forall.
603  newforallOp.getBody()->addArgument(newOuts.back().getType(),
604  newOuts.back().getLoc());
605  auto bbArgs = newforallOp.getBody()->getArguments();
606  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
607  [&](OpOperand &use) {
608  Operation *op = use.getOwner();
609  return newforallOp->isProperAncestor(op);
610  });
611 
612  // Fix terminator
613  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
614  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
615  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
616  Operation *firstYieldOp = yieldingOps.front();
617  rewriter.setInsertionPoint(firstYieldOp);
618  Value src = tileAndFuseResult.tiledValues[0];
619  Value dst = newforallOp.getOutputBlockArguments().back();
620  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
621  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
622  dst, offsets, sizes, strides);
623 
624  for (auto result : llvm::enumerate(forallOp.getResults())) {
625  rewriter.replaceAllUsesWith(result.value(),
626  newforallOp->getResult(result.index()));
627  }
628  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
629  newforallOp->getResults().back(),
630  [&](OpOperand &use) {
631  Operation *user = use.getOwner();
632  return dominatedUsers.contains(user);
633  });
634  return newforallOp;
635 }
636 
637 /// Find the first "extract" user of `producerOp` and tile it right before its
638 /// use. The tiled op is fused under the `containingOp`.
639 /// Return this fused op on success or nullptr if anything fails.
640 /// If tiled op has uses that are dominated by `containingOp`, return
641 /// a new `containingOp` with results of the fused op appended to
642 /// results of the `containingOp` or nullptr if there are no dominated uses.
643 static std::tuple<SmallVector<Operation *>, Operation *>
644 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
645  Operation *producerOp, Operation *containingOp) {
646  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
647  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
648  if (!tileableProducer) {
649  diag.attachNote(producerOp->getLoc())
650  << "producer is not a TileableInterface: " << *producerOp;
651  return {};
652  }
653 
654  // Search the producer slices accessed within the containing operation.
655  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
656  // evolve into an interface.
657  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
658  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
659  return sliceOp && containingOp->isProperAncestor(sliceOp);
660  });
661 
662  // Find a fusion opportunity.
663  if (it == tileableProducer->getUsers().end()) {
664  diag.attachNote(tileableProducer->getLoc())
665  << "could not find fusion opportunity for: " << *tileableProducer;
666  return {};
667  }
668  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
669 
670  // Try to fuse the producer in-place.
671  OpBuilder::InsertionGuard guard(rewriter);
672  rewriter.setInsertionPoint(sliceOpToTile);
673 
674  // Tile the producer.
675  int64_t resultNumber =
676  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
677  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
678 
679  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
680  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
681 
682  FailureOr<TilingResult> tileAndFuseResult =
683  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
684  sizes);
685 
686  if (failed(tileAndFuseResult)) {
687  diag.attachNote(tileableProducer->getLoc())
688  << "failed to tile producer op: " << *tileableProducer;
689  return {};
690  }
691 
692 #ifndef NDEBUG
693  for (auto tiledOp : tileAndFuseResult->tiledOps) {
694  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
695  }
696 #endif
697 
698  // Replace the extract op.
699  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
700  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
701  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
702  if (failed(maybeRankReduced)) {
703  diag.attachNote(producerOp->getLoc())
704  << "shape types don't match (missing canonicalization?):\nTiledOp: "
705  << tileAndFuseResult->tiledValues[0]
706  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
707  return {};
708  }
709  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
710 
711  // Add new outputs to containing op, if required
712  Operation *newContainingOp = replaceForAllWithNewSignature(
713  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
714  resultNumber, offsets, sizes);
715 
716  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
717 }
718 
719 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
720 /// it is exactly the `containingOp`, otherwise bail.
721 /// Then, find the first "extract" user of the tied block argument and tile it
722 /// right before its "extract" use. The tiled op is fused under the
723 /// `containingOp`.
724 /// Return this fused op on success or nullptr if anything fails.
725 static SmallVector<Operation *>
726 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
727  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
728  Operation *containingOp) {
729  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
730 
731  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
732  if (!tileableProducer) {
733  diag.attachNote(producerOp->getLoc())
734  << "producer is not a TileableInterface: " << *producerOp;
735  return {};
736  }
737 
738  // Search the first use by a "scf::ForallOp" user.
739  scf::ForallOp forallOp;
740  auto itProducerUses =
741  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
742  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
743  return forallOp;
744  });
745  // If it's not from the containing op, return.
746  if (!forallOp || forallOp != containingOp) {
747  diag.attachNote(tileableProducer->getLoc())
748  << "could not find a use by the containing op: " << *tileableProducer;
749  return {};
750  }
751 
752  // Search the producer slices accessed within the containing
753  // operation.
754  // TODO: Generalize to more extract/insert/parallel_insert triples.
755  // Maybe evolve into an interface.
756  OpOperand *pUse = &(*itProducerUses);
757  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
758 
759  // Search the producer slices accessed within the containing operation.
760  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
761  // evolve into an interface.
762  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
763  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
764  return sliceOp && containingOp->isProperAncestor(sliceOp);
765  });
766 
767  // Find a fusion opportunity.
768  if (itBBArgUsers == bbArg.getUsers().end()) {
769  diag.attachNote(containingOp->getLoc())
770  << "could not find fusion opportunity for bbArg: " << bbArg;
771  return {};
772  }
773  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
774 
775  // Try to fuse the producer in-place.
776  OpBuilder::InsertionGuard guard(rewriter);
777  rewriter.setInsertionPoint(sliceOpToTile);
778 
779  // Replace the use in the tileableProducer before tiling: clone, replace and
780  // then tile.
781  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
782  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
783 
784  // Gather destination tensors.
785  SmallVector<Value> destinationTensors;
786  if (failed(tensor::getOrCreateDestinations(
787  rewriter, tileableProducer->getLoc(), tileableProducer,
788  destinationTensors))) {
789  diag.attachNote(tileableProducer->getLoc())
790  << "failed to get destination tensors for: " << *tileableProducer;
791  return {};
792  }
793 
794  IRMapping bvm;
795  bvm.map(destinationTensors[resultNumber], bbArg);
796  auto tileableProducerClone =
797  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
798  auto scopeGuard =
799  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
800 
801  // Tile the producer.
802  FailureOr<TilingResult> tileAndFuseResult =
803  tileableProducerClone.generateResultTileValue(
804  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
805  sliceOpToTile.getMixedSizes());
806  if (failed(tileAndFuseResult)) {
807  diag.attachNote(tileableProducer->getLoc())
808  << "failed to tile producer op: " << *tileableProducer;
809  return {};
810  }
811 
812  // Replace the extract op.
813  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
814  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
815  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
816  assert(succeeded(maybeRankReduced) && "unexpected shape");
817  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
818 
819  // Replace the use in containingOp.
820  rewriter.updateRootInPlace(containingOp, [&]() {
821  containingOp->setOperand(pUse->getOperandNumber(),
822  destinationTensors.front());
823  });
824 
825  return tileAndFuseResult->tiledOps;
826 }
827 
828 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
829  Operation *producerOp,
830  Operation *containingOp) {
831  LLVM_DEBUG(DBGS() << "Try to fuse a use by cloning\n");
832 
833  // Gather all uses inside the containing op.
834  SmallVector<OpOperand *> uses;
835  for (OpResult result : producerOp->getOpResults()) {
836  for (OpOperand &use : result.getUses()) {
837  if (containingOp->isProperAncestor(use.getOwner())) {
838  uses.push_back(&use);
839  continue;
840  }
841  // Cannot clone and fuse if the use is by the containing op itself: fail
842  // immediately.
843  if (containingOp == use.getOwner()) {
844  diag.attachNote(producerOp->getLoc())
845  << "producer op use by containing op cannot be fused by cloning";
846  return nullptr;
847  }
848  }
849  }
850 
851  // Check for a non-empty list of fusion opportunities.
852  if (uses.empty()) {
853  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
854  return nullptr;
855  }
856 
857  // Clone and fuse inside the containing op.
858  Operation *fusedOp = nullptr;
859  OpOperand *use = uses.front();
860  // Parallel insert slice is not a valid clone destination.
861  // TODO: Generalize to other type of ops.
862  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
863  "Parallel insert slice is not a valid clone destination");
864  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
865  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
866 
867  OpBuilder::InsertionGuard guard(rewriter);
868  rewriter.setInsertionPoint(use->getOwner());
869  fusedOp = rewriter.clone(*producerOp);
870  rewriter.updateRootInPlace(
871  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
872 
873  return fusedOp;
874 }
875 
876 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
877  // Allow repeated handles since we are fusing everything anyway.
878  return true;
879 }
880 
881 DiagnosedSilenceableFailure
882 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
883  transform::TransformResults &results,
884  transform::TransformState &state) {
885  SmallVector<Operation *> fusedOps;
886  auto producerOps = state.getPayloadOps(getProducerOp());
887  auto containingOps = state.getPayloadOps(getContainingOp());
888  if (!llvm::hasSingleElement(containingOps)) {
889  return emitDefiniteFailure()
890  << "requires exactly one containing_op handle (got "
891  << llvm::range_size(containingOps) << ")";
892  }
893  Operation *containingOp = *containingOps.begin();
894 
895  // If nothing to fuse, propagate success.
896  if (std::empty(producerOps)) {
897  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
898  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
899  return DiagnosedSilenceableFailure::success();
900  }
901 
902  // Helper function to find the next producer that should be fused. Take any
903  // producer that has a use inside the containing op.
904  SetVector<Operation *> remainingProducers(producerOps.begin(),
905  producerOps.end());
906  auto getNextProducer = [&]() -> FailureOr<Operation *> {
907  for (const auto &it : enumerate(remainingProducers)) {
908  Operation *producerOp = it.value();
909  // The containing op may be a user of producerOp: use isAncestor.
910  int64_t numUsesInContainingOp =
911  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
912  return containingOp->isAncestor(op);
913  });
914  // TODO: When resolving the TODO below (no duplicate ops), take an op
915  // that has no use among the remaining producers. This is a topological
916  // sorting.
917  if (numUsesInContainingOp > 0) {
918  if (numUsesInContainingOp == 1)
919  remainingProducers.erase(remainingProducers.begin() + it.index());
920  return producerOp;
921  }
922  }
923  return failure();
924  };
925 
926  while (!remainingProducers.empty()) {
927  auto nextProducer = getNextProducer();
928  if (failed(nextProducer)) {
929  auto diag = mlir::emitSilenceableFailure(getLoc())
930  << "could not find next producer to fuse into container";
931  diag.attachNote(containingOp->getLoc()) << "containing op";
932  return diag;
933  }
934 
935  Operation *producerOp = *nextProducer;
936 
937  // Default diagnostic, to be complemented with more failure information.
939  diag << "could not fuse " << *producerOp << " into " << *containingOp;
940 
941  // TODO: If there are multiple uses of the producer in the containing op,
942  // we currently tile/clone the op multiple times (once per use). In some
943  // cases, we can tile/clone once and reuse the value for each use.
944  // Furthermore, producers should then be traversed according to a
945  // topological sorting.
946  auto [tiledOps, newContainingOp] =
947  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
948  if (!tiledOps.empty()) {
949  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
950  fusedOps.append(tiledOps);
951  if (newContainingOp) {
952  // Update handles associated with the containing op so we don't need to
953  // invalidate them. This is a hack to support better composability
954  // between tiling and fusion while a proper mechanism is being
955  // investigated.
956  //
957  // DO NOT replicate this elsewhere unless you understand what you are
958  // doing.
959  LogicalResult replacementStatus =
960  rewriter.notifyPayloadOperationReplaced(containingOp,
961  newContainingOp);
962  (void)replacementStatus;
963  assert(succeeded(replacementStatus) &&
964  "unable to update transform state mapping");
965  rewriter.eraseOp(containingOp);
966  containingOp = newContainingOp;
967  }
968  continue;
969  }
970 
971  SmallVector<Operation *> tiledContainingOpOperand =
972  tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
973  rewriter, diag, producerOp, containingOp);
974  if (!tiledContainingOpOperand.empty()) {
975  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
976  << *containingOp);
977  fusedOps.append(tiledContainingOpOperand);
978  continue;
979  }
980 
981  Operation *cloned =
982  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
983  if (cloned) {
984  LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
985  fusedOps.push_back(cloned);
986  continue;
987  }
988  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
989  }
990 
991  results.set(cast<OpResult>(getFusedOp()), fusedOps);
992  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
993  return DiagnosedSilenceableFailure::success();
994 }
995 
996 void transform::FuseIntoContainingOp::getEffects(
997  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
998  consumesHandle(getProducerOp(), effects);
999  onlyReadsHandle(getContainingOp(), effects);
1000  producesHandle(getResults(), effects);
1001  modifiesPayload(effects);
1002 }
1003 
1004 //===----------------------------------------------------------------------===//
1005 // GeneralizeOp
1006 //===----------------------------------------------------------------------===//
1007 
1008 DiagnosedSilenceableFailure
1009 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1010  LinalgOp target,
1011  transform::ApplyToEachResultList &results,
1012  transform::TransformState &state) {
1013  // Exit early if no transformation is needed.
1014  if (isa<GenericOp>(target)) {
1015  results.push_back(target);
1016  return DiagnosedSilenceableFailure::success();
1017  }
1018  rewriter.setInsertionPoint(target);
1019  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1020  if (succeeded(generic)) {
1021  results.push_back(generic->getOperation());
1022  return DiagnosedSilenceableFailure::success();
1023  }
1024  return emitDefaultSilenceableFailure(target);
1025 }
1026 
1027 //===----------------------------------------------------------------------===//
1028 // SpecializeOp
1029 //===----------------------------------------------------------------------===//
1030 
1031 DiagnosedSilenceableFailure
1032 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1033  LinalgOp target,
1034  transform::ApplyToEachResultList &results,
1035  transform::TransformState &state) {
1036  // Exit early if the operation is not a generic.
1037  if (!isa<GenericOp>(target)) {
1038  results.push_back(target);
1039  return DiagnosedSilenceableFailure::success();
1040  }
1041  rewriter.setInsertionPoint(target);
1042  FailureOr<LinalgOp> named =
1043  specializeGenericOp(rewriter, cast<GenericOp>(target));
1044  if (succeeded(named)) {
1045  results.push_back(named->getOperation());
1046  return DiagnosedSilenceableFailure::success();
1047  }
1048  return emitDefaultSilenceableFailure(target);
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // InterchangeOp
1053 //===----------------------------------------------------------------------===//
1054 
1055 DiagnosedSilenceableFailure
1056 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1057  GenericOp target,
1058  transform::ApplyToEachResultList &results,
1059  transform::TransformState &state) {
1060  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1061  // Exit early if no transformation is needed.
1062  if (interchangeVector.empty()) {
1063  results.push_back(target);
1064  return DiagnosedSilenceableFailure::success();
1065  }
1066 
1067  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1068  if (interchangeVector.size() != numLoops) {
1069  return emitSilenceableError()
1070  << getIteratorInterchangeAttrName() << " has length ("
1071  << interchangeVector.size()
1072  << ") different from the number of loops in the target operation ("
1073  << numLoops << ")";
1074  }
1075  FailureOr<GenericOp> res =
1076  interchangeGenericOp(rewriter, target,
1077  SmallVector<unsigned>(interchangeVector.begin(),
1078  interchangeVector.end()));
1079  if (failed(res))
1080  return emitDefiniteFailure() << "failed to apply";
1081  results.push_back(res->getOperation());
1082  return DiagnosedSilenceableFailure::success();
1083 }
1084 
1085 LogicalResult transform::InterchangeOp::verify() {
1086  ArrayRef<int64_t> permutation = getIteratorInterchange();
1087  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1088  if (!std::is_permutation(sequence.begin(), sequence.end(),
1089  permutation.begin(), permutation.end())) {
1090  return emitOpError()
1091  << "expects iterator_interchange to be a permutation, found "
1092  << getIteratorInterchange();
1093  }
1094  return success();
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // LowerPackOp
1099 //===----------------------------------------------------------------------===//
1100 
1101 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1102  transform::TransformRewriter &rewriter, tensor::PackOp target,
1103  transform::ApplyToEachResultList &transformResults,
1104  transform::TransformState &state) {
1105  rewriter.setInsertionPoint(target);
1106  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
1107  if (failed(res)) {
1108  return mlir::emitSilenceableFailure(target->getLoc())
1109  << "cannot lower to pad + expand + transpose";
1110  }
1111  transformResults.push_back(res->padOp);
1112  transformResults.push_back(res->expandShapeOp);
1113  transformResults.push_back(res->transposeOp);
1114  return DiagnosedSilenceableFailure::success();
1115 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // LowerUnPackOp
1119 //===----------------------------------------------------------------------===//
1120 
1121 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1122  transform::TransformRewriter &rewriter, tensor::UnPackOp target,
1123  transform::ApplyToEachResultList &transformResults,
1124  transform::TransformState &state) {
1125  rewriter.setInsertionPoint(target);
1126  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
1127  if (failed(res)) {
1128  return mlir::emitSilenceableFailure(target->getLoc())
1129  << "cannot rewrite to pad + expand + transpose";
1130  }
1131  transformResults.push_back(res->emptyOp);
1132  transformResults.push_back(res->transposeOp);
1133  transformResults.push_back(res->collapseShapeOp);
1134  transformResults.push_back(res->extractSliceOp);
1135  return DiagnosedSilenceableFailure::success();
1136 }
1137 
1138 //===---------------------------------------------------------------------===//
1139 // MatchOp
1140 //===---------------------------------------------------------------------===//
1141 
1142 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1143  Value target, ArrayRef<StringRef> opNames) {
1144  result.addOperands(target);
1145  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1146  builder.getStrArrayAttr(opNames));
1147  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1148 }
1149 
1150 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1151  TypeRange resultTypes, Value target,
1152  ArrayRef<StringRef> opNames) {
1153  result.addOperands(target);
1154  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1155  builder.getStrArrayAttr(opNames));
1156  result.addTypes(resultTypes);
1157 }
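// Example (illustrative sketch, not part of the upstream file): using the
// first builder above to match all linalg.matmul ops nested under the payload
// of `funcHandle` (a hypothetical handle value).
//
//   auto matched = builder.create<transform::MatchOp>(
//       loc, funcHandle, ArrayRef<StringRef>{"linalg.matmul"});
//   Value matmuls = matched->getResult(0);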
1158 
1159 DiagnosedSilenceableFailure
1160 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1161  transform::TransformResults &results,
1162  transform::TransformState &state) {
1163  llvm::StringSet<> strs;
1164  if (getOps().has_value())
1165  strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1166  getOps()->getAsValueRange<StringAttr>().end());
1167 
1168  auto payloadOps = state.getPayloadOps(getTarget());
1169  if (!llvm::hasSingleElement(payloadOps)) {
1170  return emitDefiniteFailure("requires exactly one target handle");
1171  }
1172 
1173  SmallVector<Operation *> res;
1174  auto matchFun = [&](Operation *op) {
1175  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1176  return;
1177 
1178  // Interfaces cannot be matched by name, just by ID.
1179  // So we specifically encode the interfaces we care about for this op.
1180  if (getInterface().has_value()) {
1181  auto iface = getInterface().value();
1182  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1183  !isa<LinalgOp>(op))
1184  return;
1185  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1186  !isa<TilingInterface>(op))
1187  return;
1188  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1189  !isa<LoopLikeOpInterface>(op))
1190  return;
1191  }
1192 
1193  // Check if all specified attributes match.
1194  if (getOpAttrs().has_value()) {
1195  DictionaryAttr opAttrs = getOpAttrs().value();
1196  for (NamedAttribute attr : opAttrs) {
1197  if (attr.getName() == getInterfaceAttrName() ||
1198  attr.getName() == getOpsAttrName())
1199  continue;
1200  if (!op->hasAttr(attr.getName()))
1201  return;
1202  if (op->getAttr(attr.getName()) != attr.getValue())
1203  return;
1204  }
1205  }
1206 
1207  if (getFilterResultType().has_value()) {
1208  Type t = getFilterResultType().value();
1209  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1210  return;
1211  }
1212 
1213  // All constraints are satisfied.
1214  res.push_back(op);
1215  return;
1216  };
1217 
1218  (*payloadOps.begin())->walk(matchFun);
1219  results.set(cast<OpResult>(getResult()), res);
1220  return DiagnosedSilenceableFailure::success();
1221 }
1222 
1223 //===---------------------------------------------------------------------===//
1224 // MultiTileSizesOp
1225 //===---------------------------------------------------------------------===//
1226 
1227 static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1228  Type targetType, Type lowSizeType, Type,
1229  Type) {
1230  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1231 }
1232 
1233 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1234  Type &targetType, Type &lowSizeType,
1235  Type &highSizeType,
1236  Type &splitPointType) {
1237  FunctionType funcType;
1238  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1239  if (failed(parser.parseType<FunctionType>(funcType)))
1240  return failure();
1241 
1242  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1243  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1244  "argument and one result";
1245  }
1246  targetType = funcType.getInput(0);
1247  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1248 
1249  return success();
1250 }
1251 
1252 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1253  transform::TransformRewriter &rewriter, LinalgOp target,
1254  transform::ApplyToEachResultList &results, TransformState &state) {
1255  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1256  if (target.hasDynamicShape()) {
1257  auto diag = emitSilenceableError()
1258  << "cannot compute parametric tile sizes for dynamically "
1259  "shaped payload op";
1260  diag.attachNote(target->getLoc()) << "payload op";
1261  return diag;
1262  }
1263 
1264  auto spec = mlir::linalg::computeStaticMultiTileSizes(
1265  target, getDimension(), getTargetSize(), getDivisor());
1266  if (failed(spec)) {
1267  return emitSilenceableError()
1268  << "failed to compute multi-size tiling sizes";
1269  }
1270 
1271  Builder builder(target.getContext());
1272  results.assign(llvm::map_range(
1273  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1274  spec->lowTileSize * spec->lowTripCount}),
1275  [&builder, this](int64_t value) {
1276  return builder.getIntegerAttr(
1277  cast<ParamType>(getLowSize().getType()).getType(), value);
1278  }));
1279  return DiagnosedSilenceableFailure::success();
1280  }
1281 
1282  OpBuilder builder(target.getContext());
1283  builder.setInsertionPoint(target);
1284  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1285  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1286  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1287  builder, target, getDimension(), targetSize, divisor);
1288  if (failed(spec)) {
1289  return emitSilenceableError() << "could not generate tile size computation";
1290  }
1291 
1292  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1293  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1294  Operation *splitPoint =
1295  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1296  {spec->lowTileSize, spec->lowTripCount});
1297  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1298  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1299  assert(lowTileSize && highTileSize && splitPoint &&
1300  "tile sizes are not produced by operations");
1301  results.reserve(results.size() + 3);
1302  results.push_back(lowTileSize);
1303  results.push_back(highTileSize);
1304  results.push_back(splitPoint);
1305  return DiagnosedSilenceableFailure::success();
1306 }
1307 
1308 void transform::MultiTileSizesOp::getEffects(
1309  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1310  onlyReadsHandle(getTarget(), effects);
1311  producesHandle(getResults(), effects);
1312  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1313  onlyReadsPayload(effects);
1314  else
1315  modifiesPayload(effects);
1316 }
1317 
1318 LogicalResult transform::MultiTileSizesOp::verify() {
1319  if (getLowSize().getType() != getHighSize().getType() ||
1320  getLowSize().getType() != getSplitPoint().getType()) {
1321  return emitOpError() << "expects all results type to be the same";
1322  }
1323  return success();
1324 }
1325 
1326 //===---------------------------------------------------------------------===//
1327 // PackOp
1328 //===---------------------------------------------------------------------===//
1329 
1330 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1331  Value target,
1332  ArrayRef<OpFoldResult> mixedPackedSizes) {
1333  SmallVector<int64_t> staticPackedSizes;
1334  SmallVector<Value> dynamicPackedSizes;
1335  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1336  staticPackedSizes);
1337  // Call the default builder which sets up the proper operands segment sizes
1338  // attributes for multiple variadic operands. In the absence of this, horrible
1339  // bugs ensue.
1340  Type linalgOpHType = transform::OperationType::get(
1341  builder.getContext(), GenericOp::getOperationName());
1342  build(builder, result,
1343  /*resultType=*/linalgOpHType,
1344  /*target=*/target,
1345  /*dynamic_sizes=*/dynamicPackedSizes,
1346  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1347 }
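// Example (illustrative sketch, not part of the upstream file): creating the
// op with mixed static/dynamic packed sizes. `b`, `loc`, `matmulHandle` and
// `dynSizeHandle` are hypothetical; static entries become index attributes,
// dynamic ones stay as transform handles.
//
//   SmallVector<OpFoldResult> packedSizes = {b.getIndexAttr(8),
//                                            b.getIndexAttr(16), dynSizeHandle};
//   auto packOp = b.create<transform::PackOp>(loc, matmulHandle, packedSizes);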
1348 
1349 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1350  Builder b(getContext());
1351  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1352 }
1353 
1354 DiagnosedSilenceableFailure
1355 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1356  transform::TransformResults &transformResults,
1357  transform::TransformState &state) {
1358  auto targetOps = state.getPayloadOps(getTarget());
1359  // If nothing to pack, propagate success.
1360  if (std::empty(targetOps)) {
1361  transformResults.set(cast<OpResult>(getPackedOp()),
1362  ArrayRef<Operation *>({}));
1363  return DiagnosedSilenceableFailure::success();
1364  }
1365  // Fail on multi-op handles.
1366  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1367  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1368  return emitSilenceableError()
1369  << "requires target to map to exactly 1 LinalgOp (got "
1370  << llvm::range_size(targetOps) << ")";
1371  }
1372  // Fail on mismatched number of pack sizes.
1373  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1374  return emitSilenceableError()
1375  << "requires number of packed sizes match the number of loops ("
1376  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1377  << ")";
1378  }
1379 
1380  // Unpack handles to constants or actual SSA index values.
1381  SmallVector<OpFoldResult> packedSizes;
1382  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
1383  state, *this, packedSizes, getMixedPackedSizes());
1384 
1385  rewriter.setInsertionPoint(linalgOp);
1386  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1387  if (failed(maybeResult))
1388  return emitDefiniteFailure("data tiling failed");
1389 
1390  transformResults.set(cast<OpResult>(getPackedOp()),
1391  {maybeResult->packedLinalgOp.getOperation()});
1392  return DiagnosedSilenceableFailure::success();
1393 }
1394 
1395 void transform::PackOp::getEffects(
1396  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1397  transform::consumesHandle(getTarget(), effects);
1398  transform::onlyReadsHandle(getPackedSizes(), effects);
1399  transform::producesHandle(getPackedOp(), effects);
1400  transform::modifiesPayload(effects);
1401 }
1402 
1403 //===---------------------------------------------------------------------===//
1404 // PackGreedilyOp.
1405 //===---------------------------------------------------------------------===//
1406 
1407 LogicalResult transform::PackGreedilyOp::verify() {
1408  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1409  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1410  << " is not a valid permutation";
1411  }
1412  // TODO: relax to allow empty once we have another strategy than just matmul.
1413  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1414  for (auto [s, nmo] :
1415  llvm::zip_equal(getMixedMatmulPackedSizes(),
1416  getMatmulPaddedSizesNextMultipleOf())) {
1417  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1418  if (nmo != 0 &&
1419  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1420  return emitOpError() << "at most one of the packed_size and the "
1421  "padded_sizes_next_multiple_of can be nonzero "
1422  "for the matmul strategy";
1423  }
1424  }
1425  }
1426  return success();
1427 }
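// Illustrative sketch (hypothetical attribute values, not part of the
// upstream file): for the matmul strategy, each (m, n, k) position may set
// either a packed size or a padded-size multiple, but not both. For example,
// matmul_packed_sizes = [8, 16, 0] with
// matmul_padded_sizes_next_multiple_of = [0, 0, 32] verifies, whereas
// [8, 16, 32] with [0, 0, 32] is rejected because the k position sets both.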
1428 
1429 DiagnosedSilenceableFailure
1430 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1431  transform::TransformResults &transformResults,
1432  transform::TransformState &state) {
1433  SmallVector<Operation *> results;
1434  for (Operation *op : state.getPayloadOps(getTarget())) {
1435  auto linalgOp = dyn_cast<LinalgOp>(op);
1436  if (!linalgOp)
1437  continue;
1438  // linalgOp will be replaced and the insertion point may be invalidated if
1439  // we set it before -> set it after.
1440  rewriter.setInsertionPointAfter(linalgOp);
1441  // Failing to pack greedily is perfectly fine.
1442  // In the future we will want to order packings according to some metric.
1443  FailureOr<PackResult> packResult = packMatmulGreedily(
1444  /*rewriter=*/rewriter,
1445  /*linalgOp=*/linalgOp,
1446  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1447  /*mnkPaddedSizesNextMultipleOf=*/
1448  getMatmulPaddedSizesNextMultipleOf(),
1449  /*mnkOrder=*/getMatmulInnerDimsOrder());
1450  if (succeeded(packResult)) {
1451  results.push_back(packResult->packedLinalgOp);
1452  continue;
1453  }
1454  results.push_back(linalgOp);
1455  }
1456  transformResults.set(cast<OpResult>(getPackedOp()), results);
1457  return DiagnosedSilenceableFailure::success();
1458 }
1459 
1460 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1461  Builder b(getContext());
1462  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1463  b);
1464 }
1465 
1466 void transform::PackGreedilyOp::getEffects(
1467  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1468  transform::consumesHandle(getTarget(), effects);
1469  transform::onlyReadsHandle(getMatmulPackedSizes(), effects);
1470  transform::producesHandle(getPackedOp(), effects);
1471  transform::modifiesPayload(effects);
1472 }
1473 
1474 //===---------------------------------------------------------------------===//
1475 // PackTransposeOp
1476 //===---------------------------------------------------------------------===//
1477 
1478 LogicalResult transform::PackTransposeOp::verify() {
1479  if (!isPermutationVector(getInnerPerm())) {
1480  return emitOpError() << getInnerPermAttrName()
1481  << " is not a valid permutation";
1482  }
1483  if (!isPermutationVector(getOuterPerm())) {
1484  return emitOpError() << getOuterPermAttrName()
1485  << " is not a valid permutation";
1486  }
1487  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1488  return emitOpError() << " at least one of " << getInnerPermAttrName()
1489  << " or " << getOuterPermAttrName()
1490  << " must be specified";
1491  }
1492  return success();
1493 }
1494 
1495 namespace {
1496 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1497 } // namespace
1498 
1499 /// Return true if `permutation` is a valid permutation of the
1500 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1501 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
1502 /// This is the case when the `permutation` rank matches the rank expected by
1503 /// `op` and `permutation` is itself a permutation vector.
1504 /// Return true if either `op` or `permutation` are empty to allow a simpler
1505 /// polymorphic implementation.
1506 template <typename RelayoutOpTy>
1507 static bool isValidPackingPermutation(
1508  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1509  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1510  static_assert(
1511  llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1512  "applies to only pack or unpack operations");
1513  if (!op || permutation.empty())
1514  return true;
1515  size_t innerRank = op.getInnerDimsPos().size();
1516  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1517  return permutation.size() == innerRank && isPermutationVector(permutation);
1518  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1519  // Don't rely on it.
1520  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1521  return permutation.size() == op.getSourceRank() &&
1522  isPermutationVector(permutation);
1523  }
1524  return permutation.size() == op.getDestRank() &&
1525  isPermutationVector(permutation);
1526 }
1527 
1528 DiagnosedSilenceableFailure
1529 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1530  transform::TransformResults &transformResults,
1531  transform::TransformState &state) {
1532  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1533  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1534  // Step 1. If nothing to pack, propagate success.
1535  if (std::empty(packOrUnpackOps)) {
1536  transformResults.set(cast<OpResult>(getPackedOp()), {});
1537  transformResults.set(cast<OpResult>(getPackOp()), {});
1538  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1539  return DiagnosedSilenceableFailure::success();
1540  }
1541 
1542  // Step 2. Bunch of runtime sanity checks and error messages.
1543  // Step 2.1. Fail on multi-op handles.
1544  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1545  !llvm::hasSingleElement(linalgOps)) {
1546  return emitSilenceableError()
1547  << "requires target to map to exactly 1 "
1548  "packing op and 1 packed op ("
1549  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1550  << llvm::range_size(linalgOps) << ")";
1551  }
1552 
1553  // Step 2.2. Fail on wrong type.
1554  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1555  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1556   if (!packOp && !unPackOp) {
1557  return emitSilenceableError() << "requires target to map to a "
1558  "tensor.pack or tensor.unpack";
1559  }
1560  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1561  if (!linalgOpTarget)
1562  return emitSilenceableError() << "requires a LinalgOp target";
1563 
1564  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1565  LinalgOp linalgOp;
1566  if (packOp && packOp.getResult().hasOneUse())
1567  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1568  else if (unPackOp)
1569  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1570  if (linalgOp != linalgOpTarget) {
1571  auto errorMsg =
1572  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1573  : StringLiteral{"not produced by the LinalgOp target"};
1574  return emitSilenceableError() << errorMsg;
1575  }
1576 
1577  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1578  // PackOp.
1579  if (unPackOp) {
1580  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1581  OpOperand *packUse = linalgOp.getDpsInitOperand(
1582  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1583  packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
1584  if (!packOp || !packOp.getResult().hasOneUse())
1585  return emitSilenceableError() << "could not find matching pack op";
1586  }
1587 
1588  // Step 2.5. Fail if any permutation does not validate.
1589  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1590  ArrayRef<int64_t> perm =
1591  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1592  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1593  ? StringLiteral{"invalid outer_perm"}
1594  : StringLiteral{"invalid inner_perm"};
1595  if (!isValidPackingPermutation(packOp, perm, permType) ||
1596  !isValidPackingPermutation(unPackOp, perm, permType)) {
1597  Operation *packOrUnpackOp =
1598  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1599  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1600  }
1601  }
1602 
1603  // From here on, packOp and linalgOp are always present, unPackOp may or may
1604  // not be present.
1605  assert(packOp && linalgOp && "unexpected null op");
1606 
1607   // Step 3. Actually transpose the ops.
1608   FailureOr<PackTransposeResult> res = packTranspose(
1609  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1610  // Preconditions have been checked, it is an error to fail here.
1611  assert(succeeded(res) && "unexpected packTranspose failure");
1612 
1613  // Step 4. Return results.
1614  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1615  transformResults.set(cast<OpResult>(getPackedOp()),
1616  {res->transposedLinalgOp});
1617  if (unPackOp) {
1618  transformResults.set(cast<OpResult>(getUnPackOp()),
1619  {res->transposedUnPackOp});
1620  } else {
1621  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1622  }
1623 
1625 }
1626 
1627 //===---------------------------------------------------------------------===//
1628 // PadOp
1629 //===---------------------------------------------------------------------===//
1630 
1631 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1632  ArrayRef<int64_t> paddingDimensions,
1633  ArrayRef<int64_t> padToMultipleOf,
1634  ArrayRef<int64_t> packPaddings,
1635  ArrayRef<Attribute> transposePaddings,
1636  StringRef copyBackOp) {
1637  auto resultType = transform::AnyOpType::get(b.getContext());
1638  return build(/*builder=*/b,
1639  /*result=*/result,
1640  /*types=*/TypeRange{resultType, resultType},
1641  /*target=*/target,
1642  /*paddingValues=*/ArrayAttr(), // let inference handle this
1643  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1644  /*padToMultipleOf=*/
1645  (padToMultipleOf.empty() ? ArrayAttr()
1646  : b.getI64ArrayAttr(padToMultipleOf)),
1647  /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
1648  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1649  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1650 }
1651 
1652 DiagnosedSilenceableFailure
1653 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1654  transform::TransformResults &results,
1655  transform::TransformState &state) {
1656  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1657 
1658  for (Operation *target : state.getPayloadOps(getTarget())) {
1659  auto linalgTarget = dyn_cast<LinalgOp>(target);
1660  if (!linalgTarget) {
1661  auto diag = emitSilenceableError() << "expected LinalgOp target";
1662  diag.attachNote(target->getLoc()) << "target op";
1663  return diag;
1664  }
1665 
1666  // Convert the integer packing flags to booleans.
1667  SmallVector<bool> packPaddings;
1668  for (int64_t packPadding :
1669  extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
1670  packPaddings.push_back(static_cast<bool>(packPadding));
1671 
1672  // Convert the padding values to attributes.
1673  SmallVector<Attribute> paddingValues;
1674  for (auto const &it :
1675  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1676  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1677  if (!attr) {
1678  emitOpError("expects padding values to be typed attributes");
1680  }
1681  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1682  // Try to parse string attributes to obtain an attribute of element type.
1683  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1684  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1685  stringAttr, getContext(), elementType,
1686  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
1687  if (!parsedAttr || parsedAttr.getType() != elementType) {
1688  auto diag = this->emitOpError("expects a padding that parses to ")
1689  << elementType << ", got " << std::get<0>(it);
1690  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1692  }
1693  paddingValues.push_back(parsedAttr);
1694  continue;
1695  }
1696  // Otherwise, add the attribute directly.
1697  if (attr.getType() != elementType) {
1698  auto diag = this->emitOpError("expects a padding value of type ")
1699  << elementType << ", got " << attr;
1700  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1702  }
1703  paddingValues.push_back(attr);
1704  }
1705 
1706  // Extract the transpose vectors.
1707  SmallVector<SmallVector<int64_t>> transposePaddings;
1708  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1709  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1710  cast<ArrayAttr>(transposeVector)));
1711 
1712     LinalgOp paddedOp;
1713     LinalgPaddingOptions options;
1714  options.paddingDimensions =
1715  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1716  SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
1717  if (getPadToMultipleOf().has_value())
1718  padToMultipleOf =
1719  extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
1720  options.padToMultipleOf = padToMultipleOf;
1721  options.paddingValues = paddingValues;
1722  options.packPaddings = packPaddings;
1723     if (getCopyBackOp() ==
1724         bufferization::MaterializeInDestinationOp::getOperationName()) {
1725       options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
1726           BufferizationMaterializeInDestination;
1727     } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1728       options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
1729     } else if (getCopyBackOp() == kCopyOpNone) {
1730       options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
1731     } else {
1732       llvm_unreachable("unsupported copy_back op");
1733     }
1734 
1735  SmallVector<Value> replacements;
1736  SmallVector<tensor::PadOp> newPadOps;
1737  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
1738  replacements, newPadOps))) {
1739  auto diag = emitSilenceableError() << "failed to pad op";
1740  diag.attachNote(target->getLoc()) << "target op";
1741  return diag;
1742  }
1743 
1744  // We need to perform our own replacement here because this API is still
1745  // used in patterns that "pad and hoist", for which the replacement values
1746  // need to be different.
1747  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1748  // that we have more composable abstractions.
1749  rewriter.replaceOp(linalgTarget, replacements);
1750  paddedOps.push_back(paddedOp);
1751  padOps.append(newPadOps.begin(), newPadOps.end());
1752  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
1753  for (Value v : replacements) {
1754  Operation *copyBackOp = v.getDefiningOp();
1755  if (llvm::find(copyBackOps, copyBackOp) == copyBackOps.end())
1756  copyBackOps.push_back(copyBackOp);
1757  }
1758  }
1759  }
1760 
1761  results.set(cast<OpResult>(getPadded()), paddedOps);
1762  results.set(cast<OpResult>(getPad()), padOps);
1763  results.set(cast<OpResult>(getCopy()), copyBackOps);
1765 }
1766 
1767 LogicalResult transform::PadOp::verify() {
1768  SmallVector<int64_t> packPaddings =
1769  extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
1770  if (any_of(packPaddings, [](int64_t packPadding) {
1771  return packPadding != 0 && packPadding != 1;
1772  })) {
1773  return emitOpError()
1774  << "expects pack_paddings to contain booleans (0/1), found "
1775  << getPackPaddings();
1776  }
1777 
1778  SmallVector<int64_t> paddingDimensions =
1779  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1780  if (any_of(paddingDimensions,
1781  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
1782     return emitOpError() << "expects padding_dimensions to contain "
1783                             "non-negative integers, found "
1784  << getPaddingDimensions();
1785  }
1786  if (getPadToMultipleOf().has_value()) {
1787  if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
1788  return emitOpError() << "expects as many multiples as padding_dimensions";
1789  }
1790  }
1791  ArrayAttr transposes = getTransposePaddings();
1792  for (Attribute attr : transposes) {
1793  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
1794  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1795  if (!std::is_permutation(sequence.begin(), sequence.end(),
1796  transpose.begin(), transpose.end())) {
1797  return emitOpError()
1798  << "expects transpose_paddings to be a permutation, found "
1799  << attr;
1800  }
1801  }
1802  if (getCopyBackOp() !=
1803  bufferization::MaterializeInDestinationOp::getOperationName() &&
1804  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1805  getCopyBackOp() != kCopyOpNone)
1806  return emitOpError() << "invalid copy_back_op";
1807  return success();
1808 }
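// Illustrative attribute combination that satisfies the verifier above (the
// concrete values are examples only): pack_paddings = [1, 0] (booleans),
// padding_dimensions = [0, 2] (non-negative), pad_to_multiple_of = [16, 32]
// (same arity as padding_dimensions), transpose_paddings = [[1, 0]] (each
// entry a permutation), and copy_back_op set to one of the two op names
// checked above or to the kCopyOpNone sentinel.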
1809 
1810 //===---------------------------------------------------------------------===//
1811 // HoistPadOp
1812 //===---------------------------------------------------------------------===//
1813 
1814 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
1815  transform::TransformRewriter &rewriter,
1816  transform::TransformResults &transformResults,
1817  transform::TransformState &state) {
1818  auto targetOps = state.getPayloadOps(getTarget());
1819  auto loopOps = state.getPayloadOps(getLoop());
1820  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1821  return emitDefiniteFailure()
1822  << "requires exactly one target and one loop handle (got "
1823  << llvm::range_size(targetOps) << " and "
1824  << llvm::range_size(loopOps) << ")";
1825  }
1826 
1827  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1828  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1829  if (!padOp || !loopOp)
1830  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
1831 
1833  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
1834  getTranspose());
1835  if (failed(result))
1836  return emitDefiniteFailure() << "could not build packing loop nest";
1837 
1838  if (result->clonedLoopIvs.empty()) {
1839  transformResults.set(cast<OpResult>(getPackingLoop()),
1840                          {result->hoistedPadOp.getOperation()});
1841     return DiagnosedSilenceableFailure::success();
1842  }
1843  auto outerPackedLoop =
1844  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
1845  transformResults.set(cast<OpResult>(getPackingLoop()),
1846  {outerPackedLoop.getOperation()});
1848 }
1849 
1850 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
1851  ArrayRef<int64_t> transpose = getTranspose();
1852  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1853  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
1854  transpose.end())) {
1855  return emitOpError() << "expects transpose to be a permutation, found "
1856  << getTranspose();
1857  }
1858  return success();
1859 }
1860 
1861 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
1862     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1863  transform::onlyReadsHandle(getTarget(), effects);
1864  transform::onlyReadsHandle(getLoop(), effects);
1865  transform::producesHandle(getPackingLoop(), effects);
1866  transform::modifiesPayload(effects);
1867 }
1868 
1869 DiagnosedSilenceableFailure
1870 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
1871                                   tensor::PadOp target,
1872                                   transform::ApplyToEachResultList &results,
1873  transform::TransformState &state) {
1874  tensor::PadOp hoistedPadOp;
1875  SmallVector<GenericOp> transposeOps;
1876  FailureOr<Value> result =
1877  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
1878  hoistedPadOp, transposeOps);
1879  if (succeeded(result)) {
1880  // We need to perform our own replacement here because this API is still
1881  // used in patterns that "pad and hoist", for which the replacement values
1882  // need to be different.
1883  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1884  // that we have more composable abstractions.
1885  rewriter.replaceOp(target, *result);
1886  results.push_back(hoistedPadOp);
1888  }
1889  return emitDefaultSilenceableFailure(target);
1890 }
1891 
1892 LogicalResult transform::HoistPadOp::verify() {
1893  ArrayRef<int64_t> transpose = getTranspose();
1894  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1895  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
1896  transpose.end())) {
1897  return emitOpError() << "expects transpose to be a permutation, found "
1898  << getTranspose();
1899  }
1900  return success();
1901 }
1902 
1903 //===----------------------------------------------------------------------===//
1904 // PromoteOp
1905 //===----------------------------------------------------------------------===//
1906 
1907 DiagnosedSilenceableFailure
1908 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
1909                                  LinalgOp target,
1910                                  transform::ApplyToEachResultList &results,
1911  transform::TransformState &state) {
1912  LinalgPromotionOptions promotionOptions;
1913  if (!getOperandsToPromote().empty())
1914  promotionOptions = promotionOptions.setOperandsToPromote(
1915  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
1916  if (getUseFullTilesByDefault())
1917  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
1918  getUseFullTilesByDefault());
1919  if (getUseAlloca())
1920  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
1921  if (!getUseFullTileBuffers().empty())
1922  promotionOptions = promotionOptions.setUseFullTileBuffers(
1923  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
1924  if (getAlignment().has_value())
1925  promotionOptions = promotionOptions.setAlignment(*getAlignment());
1926  if (getMemorySpace().has_value())
1927  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
1928 
1929  if (getMapping().has_value()) {
1930     // The mapping should contain at most one element.
1931  auto mapping = *getMapping();
1932  if (mapping.size() > 1)
1933  return emitDefaultDefiniteFailure(target);
1934 
1935  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
1936 
1937  if (addressSpace.getAddressSpace() ==
1938  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
1939  promotionOptions =
1940  promotionOptions
1944  .setUseFullTileBuffers({false, false});
1945  } else if (addressSpace.getAddressSpace() ==
1946  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
1947  promotionOptions =
1948  promotionOptions
1952  .setUseFullTileBuffers({false, false});
1953  } else {
1954  return emitDefaultDefiniteFailure(target);
1955  }
1956  }
1957 
1958  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
1959  return emitDefaultDefiniteFailure(target);
1960 
1961  rewriter.setInsertionPoint(target);
1962  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
1963  if (failed(res))
1964  return emitDefaultDefiniteFailure(target);
1965  results.push_back(target);
1967 }
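// Note on the mapping handling above: when `mapping` is provided, it must
// hold exactly one GPU memory-space mapping attribute; a workgroup space
// selects workgroup-memory promotion, a private space selects private-memory
// promotion, and any other address space aborts with a definite failure.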
1968 
1969 //===----------------------------------------------------------------------===//
1970 // ReplaceOp
1971 //===----------------------------------------------------------------------===//
1972 
1973 DiagnosedSilenceableFailure
1974 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
1975  TransformResults &transformResults,
1976  TransformState &state) {
1977  auto payload = state.getPayloadOps(getTarget());
1978 
1979  // Check for invalid targets.
1980  for (Operation *target : payload) {
1981  if (target->getNumOperands() > 0)
1982  return emitDefiniteFailure() << "expected target without operands";
1983  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1984  target->getNumRegions() > 0)
1985  return emitDefiniteFailure()
1986  << "expected target that is isolated from above";
1987  }
1988 
1989  // Clone and replace.
1990  Operation *pattern = &getBodyRegion().front().front();
1991  SmallVector<Operation *> replacements;
1992  for (Operation *target : payload) {
1993  if (getOperation()->isAncestor(target))
1994  continue;
1995  rewriter.setInsertionPoint(target);
1996  Operation *replacement = rewriter.clone(*pattern);
1997  rewriter.replaceOp(target, replacement->getResults());
1998  replacements.push_back(replacement);
1999  }
2000  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2002 }
2003 
2004 void transform::ReplaceOp::getEffects(
2005     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2006  consumesHandle(getTarget(), effects);
2007  producesHandle(getReplacement(), effects);
2008  modifiesPayload(effects);
2009 }
2010 
2011 LogicalResult transform::ReplaceOp::verify() {
2012  if (!getBodyRegion().hasOneBlock())
2013  return emitOpError() << "expected one block";
2014  if (std::distance(getBodyRegion().front().begin(),
2015  getBodyRegion().front().end()) != 1)
2016  return emitOpError() << "expected one operation in block";
2017  Operation *replacement = &getBodyRegion().front().front();
2018  if (replacement->getNumOperands() > 0)
2019  return replacement->emitOpError()
2020  << "expected replacement without operands";
2021  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2022  replacement->getNumRegions() > 0)
2023  return replacement->emitOpError()
2024  << "expect op that is isolated from above";
2025  return success();
2026 }
2027 
2028 //===----------------------------------------------------------------------===//
2029 // ScalarizeOp
2030 //===----------------------------------------------------------------------===//
2031 
2032 DiagnosedSilenceableFailure
2033 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2034                                    LinalgOp target,
2035                                    transform::ApplyToEachResultList &results,
2036  transform::TransformState &state) {
2037  scf::SCFTilingOptions tilingOptions;
2038  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2039  SmallVector<OpFoldResult> tileSizes;
2040  Location loc = target.getLoc();
2041  SmallVector<OpFoldResult> allShapeSizes =
2042  target.createFlatListOfOperandDims(b, loc);
2043  AffineMap map = target.getShapesToLoopsMap();
2044  if (!map)
2045  return tileSizes;
2046  SmallVector<OpFoldResult> shapeSizes =
2048  allShapeSizes);
2049  // If the shape size is dynamic, tile by 1.
2050  // Otherwise, do not tile (i.e. tile size 0).
2051  for (OpFoldResult shapeSize : shapeSizes) {
2052  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2053  : b.getIndexAttr(1));
2054  }
2055  return tileSizes;
2056  });
2057  SmallVector<int64_t> emptyTileSizes;
2058   rewriter.setInsertionPoint(target);
2059   FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
2060  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2061  if (failed(maybeTilingResult))
2062  return emitDefaultDefiniteFailure(target);
2063 
2064  if (target->getNumResults())
2065  rewriter.replaceOp(target, maybeTilingResult->replacements);
2066  else
2067  rewriter.eraseOp(target);
2068 
2069  results.reserve(maybeTilingResult->tiledOps.size());
2070  for (Operation *tiled : maybeTilingResult->tiledOps)
2071  results.push_back(tiled);
2073 }
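// Illustrative example (not from the original source): for a target whose
// loops derive from a tensor<?x128xf32> operand, the callback above yields
// tile sizes [1, 0]: the dynamic dimension is tiled by 1 (scalarized) and the
// static dimension is left untiled (size 0).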
2074 
2075 //===----------------------------------------------------------------------===//
2076 // RewriteInDestinationPassingStyleOp
2077 //===----------------------------------------------------------------------===//
2078 
2079 DiagnosedSilenceableFailure
2080 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2081     transform::TransformRewriter &rewriter, Operation *target,
2082     transform::ApplyToEachResultList &results,
2083  transform::TransformState &state) {
2085  rewriter.setInsertionPoint(target);
2086  FailureOr<Operation *> maybeResult =
2088  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2089  [&rewriter](auto op) {
2090  return rewriteInDestinationPassingStyle(rewriter, op);
2091  });
2092  if (failed(maybeResult))
2093  return emitDefaultSilenceableFailure(target);
2094  results.push_back(*maybeResult);
2096 }
2097 
2098 //===----------------------------------------------------------------------===//
2099 // SplitOp
2100 //===----------------------------------------------------------------------===//
2101 
2102 DiagnosedSilenceableFailure
2103 SplitOp::apply(transform::TransformRewriter &rewriter,
2104  TransformResults &results, TransformState &state) {
2105  // Collect the dynamic split points if provided.
2106  SmallVector<Operation *> payload =
2107  llvm::to_vector(state.getPayloadOps(getTarget()));
2108  SmallVector<OpFoldResult> splitPoints;
2109  splitPoints.reserve(payload.size());
2110  if (getDynamicSplitPoint()) {
2112  if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
2113  splitPoints = llvm::to_vector(llvm::map_range(
2114  state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
2115  if (op->getNumResults() != 1 ||
2116  !op->getResult(0).getType().isIndex()) {
2117  diag = emitSilenceableError()
2118  << "expected dynamic split point handle to point to a "
2119  "single-result index-typed op";
2120  diag.attachNote(op->getLoc()) << "dynamic split point";
2121  }
2122  return OpFoldResult(op->getResult(0));
2123  }));
2124  } else {
2125  splitPoints = llvm::to_vector(
2126  llvm::map_range(state.getParams(getDynamicSplitPoint()),
2127  [](Attribute attr) { return OpFoldResult(attr); }));
2128  }
2129  if (diag.isSilenceableFailure())
2130  return diag;
2131 
2132  if (splitPoints.size() != payload.size()) {
2133  return emitDefiniteFailure()
2134  << "expected the dynamic split point handle to point to as "
2135  "many operations ("
2136  << splitPoints.size() << ") as the target handle ("
2137  << payload.size() << ")";
2138  }
2139  } else {
2140  splitPoints.resize(payload.size(),
2141  rewriter.getIndexAttr(getStaticSplitPoint()));
2142  }
2143 
2144  // Split each target operation.
2145  SmallVector<Operation *> first, second;
2146  Operation *noSecondPart = nullptr;
2147  for (const auto &pair : llvm::zip(payload, splitPoints)) {
2148  Operation *target = std::get<0>(pair);
2149  auto linalgOp = dyn_cast<LinalgOp>(target);
2150  if (!linalgOp) {
2151  auto diag = emitSilenceableError() << "only applies to structured ops";
2152  diag.attachNote(target->getLoc()) << "target op";
2153  return diag;
2154  }
2155 
2156  if (getDimension() >= linalgOp.getNumLoops()) {
2157  auto diag = emitSilenceableError() << "dimension " << getDimension()
2158  << " does not exist in target op";
2159  diag.attachNote(target->getLoc()) << "target op";
2160  return diag;
2161  }
2162 
2163  rewriter.setInsertionPoint(linalgOp);
2164  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2165  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2166  getDimension(), std::get<1>(pair));
2167 
2168  // Propagate errors.
2169  if (!first.back() && !second.back()) {
2170  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2171  diag.attachNote(target->getLoc()) << "target op";
2172  return diag;
2173  }
2174 
2175  // Do not add null second parts.
2176  if (!second.back()) {
2177  noSecondPart = target;
2178  second.pop_back();
2179  }
2180  }
2181 
2182  if (second.size() != first.size() && !second.empty()) {
2183  auto diag = emitSilenceableError()
2184  << "splitting does not produce the second part for a subset "
2185  "of targets";
2186  diag.attachNote() << "expected splitting to produce the second part of all "
2187  "or none of the targets";
2188  diag.attachNote(noSecondPart->getLoc())
2189  << "first target with no second part";
2190  return diag;
2191  }
2192 
2193  results.set(cast<OpResult>(getFirst()), first);
2194  results.set(cast<OpResult>(getSecond()), second);
2196 }
2197 
2198 void SplitOp::getEffects(
2199     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2200  consumesHandle(getTarget(), effects);
2201  if (getDynamicSplitPoint())
2202  onlyReadsHandle(getDynamicSplitPoint(), effects);
2203  producesHandle(getResults(), effects);
2204  modifiesPayload(effects);
2205 }
2206 
2207 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2208  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
2209  IntegerAttr staticSplitPoint;
2210  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2211  return failure();
2212 
2213  OptionalParseResult dynamicPointParseResult =
2214  parser.parseOptionalOperand(dynamicSplitPoint);
2215  if (!dynamicPointParseResult.has_value()) {
2216  int64_t staticSplitPointValue;
2217  if (failed(parser.parseInteger(staticSplitPointValue)))
2218  return failure();
2219 
2220  staticSplitPoint =
2221  parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
2222  }
2223 
2224  Type targetType;
2225  if (parser.parseOptionalAttrDict(result.attributes) ||
2226  parser.parseColonType(targetType) ||
2227  parser.resolveOperand(target, targetType, result.operands)) {
2228  return failure();
2229  }
2230  if (dynamicPointParseResult.has_value()) {
2231  Type splitPointType;
2232  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2233  parser.parseType(splitPointType) ||
2234  parser.resolveOperand(dynamicSplitPoint, splitPointType,
2235  result.operands)) {
2236  return failure();
2237  }
2238 
2239  staticSplitPoint =
2240  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2241  }
2242 
2243  result.addAttribute(
2244  SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
2245  staticSplitPoint);
2246  result.addTypes({targetType, targetType});
2247  return success();
2248 }
2249 
2250 void SplitOp::print(OpAsmPrinter &printer) {
2251  printer << " " << getTarget() << " after ";
2252  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
2253  if (staticSplitSize != ShapedType::kDynamic)
2254  printer << staticSplitSize;
2255  else
2256  printer << getDynamicSplitPoint();
2257  printer << " ";
2258  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2259  {getStaticSplitPointAttrName()});
2260  printer << " : " << getTarget().getType();
2261  if (staticSplitSize == ShapedType::kDynamic)
2262  printer << ", " << getDynamicSplitPoint().getType();
2263 }
2264 
2265 LogicalResult SplitOp::verify() {
2266  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
2267  (getDynamicSplitPoint() == nullptr)) {
2268  return emitOpError() << "expects either a dynamic or a static split "
2269  "point to be provided";
2270  }
2271  return success();
2272 }
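// Illustrative textual forms accepted by the parser/printer above. The op
// mnemonic and handle types are assumptions shown only for illustration; the
// authoritative syntax lives in the op's TableGen declaration:
//   %first, %second = transform.structured.split %target after 16
//       : !transform.any_op
//   %first, %second = transform.structured.split %target after %split_point
//       : !transform.any_op, !transform.any_op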
2273 
2274 //===----------------------------------------------------------------------===//
2275 // SplitReductionOp
2276 //===----------------------------------------------------------------------===//
2277 
2278 void transform::SplitReductionOp::build(
2279  OpBuilder &builder, OperationState &result, Value target,
2280  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2281  bool useScalingAlgorithm, bool useAlloc) {
2282  MLIRContext *ctx = builder.getContext();
2283  result.addOperands(target);
2284  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2285  builder.getI64IntegerAttr(splitFactor));
2286  result.addAttribute(
2287  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2288  builder.getI64IntegerAttr(insertSplitDimension));
2289  if (innerParallel) {
2290  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2291  builder.getUnitAttr());
2292  }
2293  if (useScalingAlgorithm) {
2294  result.addAttribute(
2295  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2296  builder.getUnitAttr());
2297  }
2298  if (useAlloc) {
2299  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2300  builder.getUnitAttr());
2301  }
2302  auto resultType = transform::AnyOpType::get(ctx);
2303  result.addTypes({resultType, resultType, resultType, resultType});
2304 }
2305 
2306 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2307     transform::TransformRewriter &rewriter, LinalgOp target,
2308     transform::ApplyToEachResultList &results,
2309  transform::TransformState &state) {
2310  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2311  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2312  unsigned(getInsertSplitDimension()),
2313  bool(getInnerParallel())};
2314  };
2315  rewriter.setInsertionPoint(target);
2316  FailureOr<SplitReductionResult> splitResult =
2317  (getUseScalingAlgorithm())
2318  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2319  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2320  if (failed(splitResult))
2321  return emitDefaultDefiniteFailure(target);
2322 
2323  results.push_back(splitResult->initOrAlloc);
2324  results.push_back(splitResult->fillOp);
2325  results.push_back(splitResult->splitLinalgOp);
2326  results.push_back(splitResult->resultCombiningLinalgOp);
2328 }
2329 
2330 //===----------------------------------------------------------------------===//
2331 // TileReductionUsingForOp
2332 //===----------------------------------------------------------------------===//
2333 
2334 void transform::TileReductionUsingForOp::build(
2335  OpBuilder &builder, OperationState &result, Value target,
2336  ArrayRef<int64_t> staticTileSizes) {
2337  // Call the default builder.
2338  // This is future-proof re mixed static-dynamic and setting up the proper
2339  // operands segment sizes attributes for multiple variadic operands.
2340  // In the absence of this, horrible bugs ensue.
2341  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2342  MLIRContext *ctx = builder.getContext();
2343  auto opTy = transform::AnyOpType::get(ctx);
2344  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2345  build(builder, result,
2346  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2347  /*target=*/target,
2348  /*tile_sizes=*/staticTileSizesAttr);
2349 }
2350 
2351 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2352     transform::TransformRewriter &rewriter, LinalgOp target,
2353     transform::ApplyToEachResultList &results,
2354  transform::TransformState &state) {
2355  rewriter.setInsertionPoint(target);
2357  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2358  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2359 
2360  if (failed(result))
2361  return emitDefaultSilenceableFailure(target);
2362  results.push_back(result->initialOp);
2363  results.push_back(result->parallelTiledOp);
2364  results.push_back(result->mergeOp);
2365  results.push_back(result->loops.front());
2367 }
2368 
2369 //===----------------------------------------------------------------------===//
2370 // TileReductionUsingForallOp
2371 //===----------------------------------------------------------------------===//
2372 
2373 void transform::TileReductionUsingForallOp::build(
2374  OpBuilder &builder, OperationState &result, Value target,
2375  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2376  ArrayAttr mapping) {
2377  // Call the default builder.
2378  // This is future-proof re mixed static-dynamic and setting up the proper
2379  // operands segment sizes attributes for multiple variadic operands.
2380  // In the absence of this, horrible bugs ensue.
2381  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2382  MLIRContext *ctx = builder.getContext();
2383  auto opTy = transform::AnyOpType::get(ctx);
2384  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2385  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2386  build(builder, result,
2387  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2388  /*target=*/target,
2389  /*num_threads=*/staticNumThreadsAttr,
2390  /*tile_sizes=*/staticTileSizesAttr,
2391  /*mapping=*/mapping);
2392 }
2393 
2394 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2395     transform::TransformRewriter &rewriter, LinalgOp target,
2396     transform::ApplyToEachResultList &results,
2397  transform::TransformState &state) {
2398  rewriter.setInsertionPoint(target);
2399  SmallVector<OpFoldResult> numThreads =
2400  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2401  SmallVector<OpFoldResult> tileSizes =
2402  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2405  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2406  numThreads, tileSizes, getMapping());
2407 
2408  if (failed(result)) {
2409  auto diag = emitSilenceableError() << "could not tile reduction";
2410  diag.attachNote(target.getLoc()) << "target operation";
2411  return diag;
2412  }
2413  results.push_back(result->initialOp);
2414  results.push_back(result->parallelTiledOp);
2415  results.push_back(result->mergeOp);
2416  results.push_back(result->loops);
2418 }
2419 
2420 //===----------------------------------------------------------------------===//
2421 // TileUsingForOp
2422 //===----------------------------------------------------------------------===//
2423 
2424 void transform::TileUsingForOp::build(
2425  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2426  Value target, ArrayRef<int64_t> staticTileSizes,
2427  ArrayRef<int64_t> interchange,
2428  std::optional<ArrayRef<bool>> scalableSizes) {
2429  return build(builder, result, loopTypes,
2430  /*target=*/target,
2431  /*mixedTileSizes=*/
2432  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2433  interchange, scalableSizes);
2434 }
2435 
2436 void transform::TileUsingForOp::build(
2437  OpBuilder &builder, OperationState &result, Value target,
2438  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2439  std::optional<ArrayRef<bool>> scalableSizes) {
2440  build(builder, result, target,
2441  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2442  interchange, scalableSizes);
2443 }
2444 
2445 void transform::TileUsingForOp::build(
2446  OpBuilder &builder, OperationState &result, Value target,
2447  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2448  std::optional<ArrayRef<bool>> scalableSizes) {
2449   // Loop types are automatically splat by the callee, setting up one is
2450  // enough.
2451  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2452  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2453  scalableSizes);
2454 }
2455 
2456 void transform::TileUsingForOp::build(
2457  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2458  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2459  ArrayRef<int64_t> interchange,
2460  std::optional<ArrayRef<bool>> scalableSizes) {
2461  SmallVector<int64_t> staticTileSizes;
2462  SmallVector<Value> dynamicTileSizes;
2463  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2464  // Call the default builder which sets up the proper operands segment sizes
2465  // attributes for multiple variadic operands. In the absence of this,
2466  // horrible bugs ensue.
2467  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2468  unsigned numExpectedLoops =
2469  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2470  SmallVector<Type> resultTypes;
2471  resultTypes.reserve(numExpectedLoops);
2472  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2473  "expected one loop type or as many as loops");
2474  if (loopTypes.size() == 1)
2475  resultTypes.append(numExpectedLoops, loopTypes[0]);
2476  else
2477  llvm::append_range(resultTypes, loopTypes);
2478  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2479  if (scalableSizes.has_value())
2480  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2481  build(builder, result, /*tiled_linalg_op=*/target.getType(),
2482  /*loops=*/resultTypes,
2483  /*target=*/target,
2484  /*dynamic_sizes=*/dynamicTileSizes,
2485  /*static_sizes=*/staticTileSizesAttr,
2486  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2487  /*scalable_sizes=*/expandedScalableSizes);
2488 }
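// Illustrative example: with mixedTileSizes = [8, 0, 4], the builder above
// creates two loop result types, because a static tile size of 0 means the
// corresponding dimension is not tiled and thus produces no loop
// (numExpectedLoops = 3 - count(sizes, 0) = 2).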
2489 
2490 LogicalResult transform::TileUsingForOp::verify() {
2491  if (getMixedSizes().size() != getScalableSizes().size())
2492  return emitOpError("expected same number of sizes (")
2493            << getMixedSizes().size() << ") and scalable sizes ("
2494  << getScalableSizes().size() << ")";
2495  return success();
2496 }
2497 
2498 DiagnosedSilenceableFailure
2499 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2500  TransformResults &transformResults,
2501  TransformState &state) {
2502  ArrayRef<int64_t> tileSizes = getStaticSizes();
2503 
2504  SmallVector<Operation *> targets =
2505  llvm::to_vector(state.getPayloadOps(getTarget()));
2506  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2508  dynamicSizeProducers.reserve(getDynamicSizes().size());
2509  paramSizes.reserve(getDynamicSizes().size());
2510  for (Value transformValue : getDynamicSizes()) {
2511  if (isa<ParamType>(transformValue.getType())) {
2512  dynamicSizeProducers.push_back({});
2513  ArrayRef<Attribute> params = state.getParams(transformValue);
2514  paramSizes.push_back(
2515  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
2516  return cast<IntegerAttr>(attr).getValue().getSExtValue();
2517  })));
2518 
2519  if (paramSizes.back().size() != targets.size()) {
2521  emitSilenceableError()
2522  << "expected as many parameter values ("
2523  << dynamicSizeProducers.back().size() << ") as target ops ("
2524  << targets.size() << ")";
2525  diag.attachNote(transformValue.getLoc()) << "for this parameter";
2526  return diag;
2527  }
2528 
2529  continue;
2530  }
2531  paramSizes.push_back({});
2532  dynamicSizeProducers.push_back(
2533  llvm::to_vector(state.getPayloadOps(transformValue)));
2534 
2535  if (dynamicSizeProducers.back().size() != targets.size()) {
2537  emitSilenceableError()
2538  << "expected as many dynamic size-producing operations ("
2539  << dynamicSizeProducers.back().size() << ") as target ops ("
2540  << targets.size() << ")";
2541  diag.attachNote(transformValue.getLoc()) << "for this handle";
2542  return diag;
2543  }
2544 
2545  for (Operation *op : dynamicSizeProducers.back()) {
2546  if (op->getNumResults() == 1 &&
2547  isa<IndexType>(op->getResult(0).getType())) {
2548  continue;
2549  }
2550 
2552  emitSilenceableError() << "expected sizes to be produced by ops "
2553  "with a single index-type result";
2554  diag.attachNote(op->getLoc()) << "size producer op";
2555  diag.attachNote(transformValue.getLoc()) << "for this handle";
2556  return diag;
2557  }
2558  }
2559 
2562  loops.resize(getLoops().size());
2563  auto scalableSizes = getScalableSizes();
2564  for (auto [i, op] : llvm::enumerate(targets)) {
2565  auto tilingInterface = dyn_cast<TilingInterface>(op);
2566  if (!tilingInterface) {
2568  emitSilenceableError()
2569  << "only ops implementing TilingInterface are supported";
2570  diag.attachNote(op->getLoc()) << "target op";
2571  return diag;
2572  }
2573  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
2575  emitSilenceableError()
2576  << "too many tiles provided, expected at most "
2577  << tilingInterface.getLoopIteratorTypes().size() << " found "
2578  << tileSizes.size();
2579  diag.attachNote(op->getLoc()) << "target op";
2580  return diag;
2581  }
2582 
2583  scf::SCFTilingOptions tilingOptions;
2584  if (!tileSizes.empty()) {
2585  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
2586  Operation *) {
2588  sizes.reserve(tileSizes.size());
2589  unsigned dynamicIdx = 0;
2590 
2591  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
2592  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
2593  if (scalableSizes[ofrIdx]) {
2594  auto val = b.create<arith::ConstantIndexOp>(
2595  getLoc(), attr.cast<IntegerAttr>().getInt());
2596  Value vscale =
2597  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
2598  sizes.push_back(
2599  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
2600  } else {
2601  sizes.push_back(attr);
2602  }
2603  continue;
2604  }
2605  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
2606  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
2607  ++dynamicIdx;
2608  assert((dynamicSizes.empty() ^ params.empty()) &&
2609  "expected either dynamic sizes or parameters");
2610  if (!params.empty()) {
2611  sizes.push_back(b.getIndexAttr(params[index]));
2612  } else {
2613  sizes.push_back(dynamicSizes[index]->getResult(0));
2614  }
2615  }
2616  return sizes;
2617  });
2618  }
2619 
2620  tilingOptions.setInterchange(getInterchange());
2621  FailureOr<scf::SCFTilingResult> maybeTilingResult =
2622  tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
2623  if (failed(maybeTilingResult))
2625 
2626  rewriter.replaceOp(op, maybeTilingResult->replacements);
2627 
2628  tiled.append(maybeTilingResult->tiledOps);
2629  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
2630  loops[en2.index()].push_back(en2.value());
2631  }
2632 
2633  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
2634  for (const auto &en : llvm::enumerate(loops))
2635  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
2636 
2638 }
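// Note on scalable sizes in the tile-size callback above: a static size whose
// scalable_sizes bit is set is materialized as `size * vector.vscale`, so the
// effective tile size becomes a runtime multiple of the hardware vector scale.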
2639 
2640 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
2641  ValueRange dynamic = getDynamicSizes();
2642  ArrayRef<int64_t> tileSizes = getStaticSizes();
2643  SmallVector<OpFoldResult> results;
2644  results.reserve(tileSizes.size());
2645  unsigned dynamicPos = 0;
2646  Builder builder(getContext());
2647  for (int64_t size : tileSizes) {
2648  if (size == ShapedType::kDynamic) {
2649  results.push_back(dynamic[dynamicPos++]);
2650  } else {
2651  results.push_back(builder.getIndexAttr(size));
2652  }
2653  }
2654  return results;
2655 }
2656 
2657 // We want to parse `DenseI64ArrayAttr` using the short form without the
2658 // `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
2659 static ParseResult parseOptionalInterchange(OpAsmParser &parser,
2660  OperationState &result) {
2661  if (succeeded(parser.parseOptionalLBrace())) {
2662  if (failed(parser.parseKeyword("interchange")))
2663  return parser.emitError(parser.getNameLoc()) << "expect `interchange`";
2664  if (failed(parser.parseEqual()))
2665  return parser.emitError(parser.getNameLoc()) << "expect `=`";
2666  result.addAttribute("interchange",
2667  DenseI64ArrayAttr::parse(parser, Type{}));
2668  if (failed(parser.parseRBrace()))
2669  return parser.emitError(parser.getNameLoc()) << "expect `}`";
2670  }
2671  return success();
2672 }
2673 
2674 static void printOptionalInterchange(OpAsmPrinter &p,
2675  ArrayRef<int64_t> interchangeVals) {
2676  if (!interchangeVals.empty()) {
2677  p << " {interchange = [";
2678  llvm::interleaveComma(interchangeVals, p,
2679  [&](int64_t integer) { p << integer; });
2680  p << "]}";
2681  }
2682 }
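// Illustrative example: an interchange of [0, 2, 1] round-trips through the
// helpers above as the trailing ` {interchange = [0, 2, 1]}` clause.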
2683 
2685  OperationState &result) {
2688  DenseI64ArrayAttr staticSizes;
2689  FunctionType functionalType;
2690  llvm::SMLoc operandLoc;
2691  DenseBoolArrayAttr scalableVals;
2692 
2693  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
2694  parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
2695  parseOptionalInterchange(parser, result) ||
2696  parser.parseColonType(functionalType))
2697  return ParseResult::failure();
2698 
2699  size_t numExpectedLoops =
2700  staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
2701  if (functionalType.getNumResults() != numExpectedLoops + 1) {
2702  return parser.emitError(parser.getNameLoc())
2703  << "expected " << (numExpectedLoops + 1) << " result type(s)";
2704  }
2705  if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
2706  return parser.emitError(operandLoc)
2707  << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
2708  }
2709  if (parser.resolveOperand(target, functionalType.getInputs().front(),
2710  result.operands) ||
2711  parser.resolveOperands(dynamicSizes,
2712  functionalType.getInputs().drop_front(),
2713  operandLoc, result.operands)) {
2714  return failure();
2715  }
2716 
2717  result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
2718 
2719  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
2720  result.addTypes(functionalType.getResults());
2721  return success();
2722 }
2723 
2724 void transform::TileUsingForOp::print(OpAsmPrinter &p) {
2725  p << ' ' << getTarget();
2726  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
2727  /*valueTypes=*/{}, getScalableSizesAttr(),
2729  printOptionalInterchange(p, getInterchange());
2730  p << " : ";
2731  p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
2732 }
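// Illustrative textual form produced by the printer above (the op mnemonic
// and handle types are assumptions shown only for illustration):
//   %tiled, %loop = transform.structured.tile_using_for %target [0, 8]
//       {interchange = [1, 0]}
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)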
2733 
2734 void transform::TileUsingForOp::getEffects(
2735     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2736  consumesHandle(getTarget(), effects);
2737  onlyReadsHandle(getDynamicSizes(), effects);
2738  producesHandle(getTiledLinalgOp(), effects);
2739  producesHandle(getLoops(), effects);
2740  modifiesPayload(effects);
2741 }
2742 
2743 //===----------------------------------------------------------------------===//
2744 // TileUsingForallOp
2745 //===----------------------------------------------------------------------===//
2746 
2747 void transform::TileUsingForallOp::build(OpBuilder &builder,
2748  OperationState &result, Value target,
2749  ArrayRef<int64_t> staticTileSizes,
2751  ArrayAttr mapping) {
2752  return build(builder, result,
2753  /*target=*/target,
2754  /*mixedTileSizes=*/
2755  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2756  /*_=*/TileSizesSpec(),
2757  /*mapping=*/mapping);
2758 }
2759 
2760 void transform::TileUsingForallOp::build(OpBuilder &builder,
2761  OperationState &result, Value target,
2762  ArrayRef<OpFoldResult> mixedTileSizes,
2764  ArrayAttr mapping) {
2765  SmallVector<int64_t> staticTileSizes;
2766  SmallVector<Value> dynamicTileSizes;
2767  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2768  // Call the default builder which sets up the proper operands segment sizes
2769  // attributes for multiple variadic operands. In the absence of this,
2770  // horrible bugs ensue.
2771  MLIRContext *ctx = builder.getContext();
2772  auto operationType = transform::AnyOpType::get(ctx);
2773  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2774  build(builder, result,
2775  /*resultTypes=*/TypeRange{operationType, operationType},
2776  /*target=*/target,
2777  /*num_threads=*/ValueRange{},
2778  /*tile_sizes=*/dynamicTileSizes,
2779  /*packed_num_threads=*/Value(),
2780  /*packed_tile_sizes=*/Value(),
2781  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
2782  /*static_tile_sizes=*/staticTileSizesAttr,
2783  /*mapping=*/mapping);
2784 }
2785 
2786 void transform::TileUsingForallOp::build(OpBuilder &builder,
2787  OperationState &result, Value target,
2788  ArrayRef<int64_t> staticNumThreads,
2790  ArrayAttr mapping) {
2791  return build(builder, result, target,
2792  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
2793  NumThreadsSpec(), mapping);
2794 }
2795 
2796 void transform::TileUsingForallOp::build(OpBuilder &builder,
2797  OperationState &result, Value target,
2798  ArrayRef<OpFoldResult> mixedNumThreads,
2800  ArrayAttr mapping) {
2801  SmallVector<int64_t> staticNumThreads;
2802  SmallVector<Value> dynamicNumThreads;
2803  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
2804  staticNumThreads);
2805  // Call the default builder which sets up the proper operands segment sizes
2806  // attributes for multiple variadic operands. In the absence of this,
2807  // horrible bugs ensue.
2808  MLIRContext *ctx = builder.getContext();
2809  auto operationType = transform::AnyOpType::get(ctx);
2810  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2811  build(builder, result,
2812  /*resultTypes=*/TypeRange{operationType, operationType},
2813  /*target=*/target,
2814  /*num_threads=*/dynamicNumThreads,
2815  /*tile_sizes=*/ValueRange{},
2816  /*packed_num_threads=*/Value(),
2817  /*packed_tile_sizes=*/Value(),
2818  /*static_num_threads=*/staticNumThreadsAttr,
2819  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
2820  /*mapping=*/mapping);
2821 }
2822 
2823 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
2824  RewriterBase &rewriter, transform::TransformState &state,
2825  TransformOpInterface transformOp, Operation *target,
2826  ArrayRef<OpFoldResult> mixedNumThreads,
2827  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
2828  linalg::ForallTilingResult &tilingResult) {
2829  // Transform all targets one by one.
2830  auto tileableOp = dyn_cast<TilingInterface>(target);
2831  if (!tileableOp) {
2833  transformOp.emitSilenceableError()
2834  << "only TilingInterface ops are supported";
2835  diag.attachNote(target->getLoc()) << "target op";
2836  return diag;
2837  }
2838  rewriter.setInsertionPoint(tileableOp);
2839  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
2840  if (!mixedNumThreads.empty()) {
2841  maybeTilingResult =
2842  linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
2843  } else {
2844  maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
2845  rewriter, tileableOp, mixedTileSizes, mapping);
2846  }
2847 
2848  if (failed(maybeTilingResult))
2849  return transformOp.emitDefaultSilenceableFailure(tileableOp);
2850  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
2851 
2852  tilingResult = *maybeTilingResult;
2854 }
2855 
2856 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
2857  transform::TransformRewriter &rewriter,
2858  transform::TransformResults &transformResults,
2859  transform::TransformState &state) {
2860  auto transformOp = cast<TransformOpInterface>(getOperation());
2861 
2862  // Result payload ops.
2863  SmallVector<Operation *> tileOps;
2864  SmallVector<Operation *> tiledOps;
2865 
2866  // Unpack handles.
2867  SmallVector<OpFoldResult> mixedNumThreads;
2869  getPackedNumThreads()
2871  state, transformOp, mixedNumThreads, getPackedNumThreads())
2873  state, transformOp, mixedNumThreads, getMixedNumThreads());
2874  if (!status.succeeded())
2875  return status;
2876  SmallVector<OpFoldResult> mixedTileSizes;
2877  status = getPackedTileSizes()
2879  state, transformOp, mixedTileSizes, getPackedTileSizes())
2881  state, transformOp, mixedTileSizes, getMixedTileSizes());
2882  if (!status.succeeded())
2883  return status;
2884 
2885  for (Operation *target : state.getPayloadOps(getTarget())) {
2886     linalg::ForallTilingResult tilingResult;
2887     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
2888  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
2889  getMapping(), tilingResult);
2890  if (!diag.succeeded())
2891  return diag;
2892  tileOps.push_back(tilingResult.tileOp);
2893  tiledOps.push_back(tilingResult.tiledOp);
2894  }
2895 
2896  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
2897  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
2898 
2900 }
2901 
2902 void transform::TileUsingForallOp::getEffects(
2903     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2904  consumesHandle(getTarget(), effects);
2905  onlyReadsHandle(getTileSizes(), effects);
2906  onlyReadsHandle(getNumThreads(), effects);
2907  onlyReadsHandle(getPackedNumThreads(), effects);
2908  onlyReadsHandle(getPackedTileSizes(), effects);
2909  producesHandle(getResults(), effects);
2910  modifiesPayload(effects);
2911 }
2912 
2913 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
2914  Builder b(getContext());
2915  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
2916 }
2917 
2918 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
2919  Builder b(getContext());
2920  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
2921 }
2922 
2923 LogicalResult TileUsingForallOp::verify() {
2924  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
2925  static_cast<int>(getPackedNumThreads() != Value());
2926  if (numThreadsSpec > 1)
2927  return emitOpError(
2928  "num_threads and packed_num_threads are mutually exclusive");
2929  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
2930  static_cast<int>(getPackedTileSizes() != Value());
2931  if (tileSizesSpec > 1)
2932  return emitOpError(
2933  "tile_sizes and packed_tile_sizes are mutually exclusive");
2934  if (numThreadsSpec == 0 && tileSizesSpec == 0)
2935  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
2936  "must be specified");
2937  return success();
2938 }
2939 
2940 //===----------------------------------------------------------------------===//
2941 // VectorizeChildrenAndApplyPatternsOp
2942 //===----------------------------------------------------------------------===//
2943 
2944 void transform::VectorizeChildrenAndApplyPatternsOp::build(
2945  OpBuilder &builder, OperationState &result, Value target,
2946  bool vectorizePadding, bool vectorizeExtract) {
2947  result.addOperands(target);
2948  if (vectorizePadding) {
2949  result.addAttribute(
2950  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
2951  result.name),
2952  builder.getUnitAttr());
2953  }
2954  if (vectorizeExtract) {
2955  result.addAttribute(
2956  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
2957  result.name),
2958  builder.getUnitAttr());
2959  }
2960  result.addTypes(transform::AnyOpType::get(builder.getContext()));
2961 }
2962 
2963 namespace {
2964 /// This is a helper only to call vectorize via a pattern inside of
2965 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
2966 struct VectorizationPattern : public RewritePattern {
2967  explicit VectorizationPattern(MLIRContext *context,
2968  bool vectorizeExtract = false)
2969  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
2970  vectorizeNDExtract(vectorizeExtract) {}
2971  LogicalResult matchAndRewrite(Operation *op,
2972  PatternRewriter &rewriter) const override {
2973  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
2974  if (!linalgOp)
2975  return rewriter.notifyMatchFailure(op, "expected Linalg Op");
2976  return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
2977  /*scalableVecDims=*/{}, vectorizeNDExtract);
2978  }
2979 
2980 private:
2981  /// Controls whether to vectorize `tensor.extract` when the input tensor is
2982  /// rank >= 2.
2983  bool vectorizeNDExtract = false;
2984 };
2985 } // namespace
2986 
2987 DiagnosedSilenceableFailure
2988 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
2989     transform::TransformRewriter &rewriter, Operation *target,
2990     transform::ApplyToEachResultList &results,
2991  transform::TransformState &state) {
2992  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
2993  auto diag = this->emitOpError("requires isolated-from-above targets");
2994  diag.attachNote(target->getLoc()) << "non-isolated target";
2996  }
2997 
2998  MLIRContext *ctx = getContext();
2999  RewritePatternSet patterns(ctx);
3000  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
3001 
3002  if (!getDisableTransferPermutationMapLoweringPatterns())
3004 
3005  if (!getDisableMultiReductionToContractPatterns())
3007 
3009 
3012  /*benefit=*/2);
3013  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3014  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3016 
3017  patterns.add<CopyVectorizationPattern>(ctx);
3018 
3019  if (getVectorizePadding())
3021 
3022  TrackingListener listener(state, *this);
3023  GreedyRewriteConfig config;
3024  config.listener = &listener;
3025  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
3026  return emitDefaultDefiniteFailure(target);
3027 
3028  results.push_back(target);
3030 }
3031 
3032 //===----------------------------------------------------------------------===//
3033 // VectorizeOp
3034 //===----------------------------------------------------------------------===//
3035 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3036  transform::TransformRewriter &rewriter,
3037  mlir::transform::TransformResults &transformResults,
3038  mlir::transform::TransformState &state) {
3039  auto targets = state.getPayloadOps(getTarget());
3040  if (std::empty(targets))
3041  return DiagnosedSilenceableFailure::success();
3042 
3043  SmallVector<int64_t> vectorSizes;
3044  for (OpFoldResult sz : getMixedVectorSizes()) {
3045  if (sz.is<Attribute>()) {
3046  auto attr = sz.get<Attribute>();
3047  vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
3048  continue;
3049  }
3050 
3051  auto szPayloads = state.getPayloadOps(sz.get<Value>());
3052  if (!llvm::hasSingleElement(szPayloads)) {
3053  auto diag = this->emitOpError(
3054  "requires vector size handle that is mapped to 1 payload op");
3055  diag.attachNote(sz.get<Value>().getLoc())
3056  << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
3057  return DiagnosedSilenceableFailure::definiteFailure();
3058  }
3059 
3060  Operation *szPayloadOp = *szPayloads.begin();
3061  if (szPayloadOp->getNumResults() != 1 ||
3062  !szPayloadOp->getResult(0).getType().isIndex()) {
3063  auto diag = this->emitOpError(
3064  "requires vector size payload op with 1 index result");
3065  diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
3066  return DiagnosedSilenceableFailure::definiteFailure();
3067  }
3068 
3069  IntegerAttr attr;
3070  if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
3071  auto diag = this->emitOpError("requires constant vector size");
3072  diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
3073  return DiagnosedSilenceableFailure::definiteFailure();
3074  }
3075 
3076  vectorSizes.push_back(attr.getInt());
3077  }
3078 
3079  // TODO: Check that the correct number of vectorSizes was provided.
3080  for (Operation *target : targets) {
3081  if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
3082  return mlir::emitSilenceableFailure(target->getLoc())
3083  << "Unsupported Op, cannot vectorize";
3084  }
3085 
3086  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3087  getScalableSizes(),
3088  getVectorizeNdExtract().has_value()
3089  ? getVectorizeNdExtract().value()
3090  : false))) {
3091  return mlir::emitSilenceableFailure(target->getLoc())
3092  << "Attempted to vectorize, but failed";
3093  }
3094  }
3095 
3096  return DiagnosedSilenceableFailure::success();
3097 }
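// Illustrative note (a sketch, not from this file): a vector size may be
// given either as a static integer attribute or as a handle mapped to exactly
// one payload op with a single, constant index result, e.g. a payload
//
//   %c8 = arith.constant 8 : index
//
// whose handle then satisfies the checks above (one payload op, one index
// result, foldable to a constant).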
3098 
3099 void transform::VectorizeOp::getEffects(
3100  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3101  consumesHandle(getTarget(), effects);
3102  onlyReadsHandle(getVectorSizes(), effects);
3103  modifiesPayload(effects);
3104 }
3105 
3106 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3107  OpBuilder b(getContext());
3108  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3109 }
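// For intuition (sketch, relying on the documented getMixedValues behavior):
// with static sizes [4, kDynamic, 8] and a single dynamic size operand %n,
// the returned OpFoldResults are {IntegerAttr 4, Value %n, IntegerAttr 8};
// each kDynamic placeholder is substituted by the next dynamic operand.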
3110 
3111 LogicalResult transform::VectorizeOp::verify() {
3112  if (getStaticVectorSizes().size() != getScalableSizes().size())
3113  return emitOpError("expected same number of vector sizes (")
3114  << getStaticVectorSizes().size() << ") and scalable sizes ("
3115  << getScalableSizes().size() << ")";
3116  return success();
3117 }
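// Example of the invariant checked above (illustrative): static vector sizes
// [8, 16] must be paired with two scalable flags, e.g. [false, true], where
// `true` marks the corresponding size as a scalable (vscale-multiplied)
// vector dimension.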
3118 
3119 //===----------------------------------------------------------------------===//
3120 // HoistRedundantVectorTransfersOp
3121 //===----------------------------------------------------------------------===//
3122 
3123 DiagnosedSilenceableFailure
3124 transform::HoistRedundantVectorTransfersOp::applyToOne(
3125  transform::TransformRewriter &rewriter, func::FuncOp target,
3126  transform::ApplyToEachResultList &results,
3127  transform::TransformState &state) {
3128  // WARNING: This hoisting does not model parallelism and is generally
3129  // incorrect when used on distributed loops with memref semantics!
3130  // TODO: obsolete and should be retired.
3131  linalg::hoistRedundantVectorTransfers(target);
3132  results.push_back(target);
3133  return DiagnosedSilenceableFailure::success();
3134 }
3135 
3136 //===----------------------------------------------------------------------===//
3137 // ConvertConv2DToImg2ColOp.
3138 //===----------------------------------------------------------------------===//
3139 
3140 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3141  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3142  transform::ApplyToEachResultList &results,
3143  transform::TransformState &state) {
3144  rewriter.setInsertionPoint(target);
3145  auto maybeTransformed =
3146  TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3147  target)
3148  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3149  return rewriteInIm2Col(rewriter, op);
3150  })
3151  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3152  return rewriteInIm2Col(rewriter, op);
3153  })
3154  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3155  return rewriteInIm2Col(rewriter, op);
3156  })
3157  .Case([&](linalg::Conv2DNchwFchwOp op) {
3158  return rewriteInIm2Col(rewriter, op);
3159  })
3160  .Default([&](Operation *op) {
3161  return rewriter.notifyMatchFailure(op, "not supported");
3162  });
3163  if (failed(maybeTransformed))
3164  return emitDefaultSilenceableFailure(target);
3165  // Handle to the operation producing the img2col tensor.
3166  results.push_back(maybeTransformed->first);
3167  // Handle to the operation that replaces the original convolution.
3168  results.push_back(maybeTransformed->second);
3169  return DiagnosedSilenceableFailure::success();
3170 }
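// Shape intuition for the im2col rewrite (a rough sketch, not normative): for
// a conv_2d_nhwc_hwcf with input N x H x W x C and filter KH x KW x C x F
// (OH, OW = output spatial sizes), the first returned handle points at the op
// producing the (N*OH*OW) x (KH*KW*C) patch matrix and the second at the
// matmul-like contraction against the filter viewed as (KH*KW*C) x F.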
3171 
3172 //===----------------------------------------------------------------------===//
3173 // TransposeConv2DOp
3174 //===----------------------------------------------------------------------===//
3175 
3176 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3177  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3178  transform::ApplyToEachResultList &results,
3179  transform::TransformState &state) {
3180  rewriter.setInsertionPoint(target);
3181  auto maybeTransformed =
3182  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3183  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3184  return transposeConv2D(rewriter, op);
3185  })
3186  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3187  return transposeConv2D(rewriter, op);
3188  })
3189  .Default([&](Operation *op) {
3190  return rewriter.notifyMatchFailure(op, "not supported");
3191  });
3192  if (failed(maybeTransformed))
3193  return emitDefaultSilenceableFailure(target);
3194  // Handle to the new Conv2D operation with transposed filters
3195  results.push_back(*maybeTransformed);
3196  return DiagnosedSilenceableFailure::success();
3197 }
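// In effect (sketch): a linalg.conv_2d_nhwc_fhwc whose filter is laid out as
// F x KH x KW x C is rewritten into a transpose of the filter to
// KH x KW x C x F followed by linalg.conv_2d_nhwc_hwcf; the returned handle
// points at that new convolution.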
3198 
3199 //===----------------------------------------------------------------------===//
3200 // InsertSliceToCopyOp
3201 //===----------------------------------------------------------------------===//
3202 template <typename OpTy>
3203 static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3204  transform::ApplyToEachResultList &results,
3205  transform::TransformState &state) {
3206  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3207  tensor::ParallelInsertSliceOp>() &&
3208  "wrong op type");
3209 
3210  if (auto copySource =
3211  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3212  results.push_back(copySource);
3213  return DiagnosedSilenceableFailure::success();
3214  }
3215 
3216  // If we are inside an InParallel region, temporarily set the insertion point
3217  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3218  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3219  rewriter.setInsertionPoint(
3220  target->template getParentOfType<scf::InParallelOp>());
3221  }
3222 
3223  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3224  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3225  target.getMixedSizes(), target.getMixedStrides());
3226  Value copied = rewriter
3227  .create<linalg::CopyOp>(target.getLoc(),
3228  target.getSource(), extracted)
3229  .getResult(0);
3230  // Reset the insertion point.
3231  rewriter.setInsertionPoint(target);
3232  rewriter.replaceOpWithNewOp<OpTy>(
3233  target, copied, target.getDest(), target.getMixedOffsets(),
3234  target.getMixedSizes(), target.getMixedStrides());
3235 
3236  results.push_back(copied.getDefiningOp());
3237  return DiagnosedSilenceableFailure::success();
3238 }
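// Net effect of the rewrite above, in payload IR terms (sketch):
//
//   %r = tensor.insert_slice %src into %dest[...][...][...]
//
// becomes
//
//   %e = tensor.extract_slice %dest[...][...][...]
//   %c = linalg.copy ins(%src) outs(%e)
//   %r = tensor.insert_slice %c into %dest[...][...][...]
//
// so the returned handle points at the linalg.copy (or at a pre-existing copy
// that already feeds the insert_slice).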
3239 
3240 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3241  transform::TransformRewriter &rewriter, Operation *targetOp,
3242  transform::ApplyToEachResultList &results,
3243  transform::TransformState &state) {
3244 
3245  rewriter.setInsertionPoint(targetOp);
3246  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3247  return doit(rewriter, target, results, state);
3248  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3249  return doit(rewriter, target, results, state);
3250 
3251  DiagnosedSilenceableFailure diag =
3252  emitSilenceableError()
3253  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3254  diag.attachNote(targetOp->getLoc()) << "target op";
3255  return diag;
3256 }
3257 
3258 //===----------------------------------------------------------------------===//
3259 // MapCopyToThreadsOp
3260 //===----------------------------------------------------------------------===//
3261 
3262 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3263  transform::TransformRewriter &rewriter, Operation *target,
3264  transform::ApplyToEachResultList &results,
3265  transform::TransformState &state) {
3266  // Check if the op is supported.
3267  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3268  DiagnosedSilenceableFailure diag =
3269  emitSilenceableError()
3270  << "only linalg.copy and tensor.pad target ops are supported";
3271  diag.attachNote(target->getLoc()) << "target op";
3272  return diag;
3273  }
3274  assert(target->getNumResults() == 1 && "expected single result");
3275  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3276  if (!resultShapedType.hasStaticShape()) {
3277  DiagnosedSilenceableFailure diag =
3278  emitSilenceableError()
3279  << "only statically sized ops of rank <= 3 are supported";
3280  diag.attachNote(target->getLoc()) << "target op";
3281  return diag;
3282  }
3283 
3284  // Conservatively set the minimum viable desired bitwidth alignment.
3285  int64_t desiredBitAlignment = getDesiredBitAlignment();
3286  int64_t eltBitwidth =
3287  resultShapedType.getElementType().getIntOrFloatBitWidth();
3288  if (desiredBitAlignment % eltBitwidth != 0) {
3289  desiredBitAlignment = eltBitwidth;
3290  }
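  // Worked example (illustrative): a 32-bit alignment request with f64
  // elements fails the divisibility check (32 % 64 != 0) and falls back to
  // 64; a 128-bit request with f32 elements keeps 128 (128 % 32 == 0).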
3291 
3292  gpu::CopyMappingInfo mapping(
3293  /*ctx=*/getContext(),
3294  /*totalNumThreads=*/getTotalNumThreads(),
3295  /*alignment=*/desiredBitAlignment,
3296  /*sizes=*/resultShapedType.getShape(),
3297  /*favorPredication=*/false,
3298  /*elementalBitwidth=*/
3299  resultShapedType.getElementType().getIntOrFloatBitWidth());
3300  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3301  DiagnosedSilenceableFailure diag =
3302  emitSilenceableError()
3303  << "too few threads to map copy op to threads on the most minor "
3304  "dimension, given alignment and vector size constraints, try "
3305  "smaller tile size of mapping to more threads";
3306  diag.attachNote(target->getLoc()) << "target op";
3307  return diag;
3308  }
3309 
3310  // OpBuilder only used to compute attributes.
3311  OpBuilder b(getContext());
3312  linalg::ForallTilingResult tilingResult;
3313  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3314  /*rewriter=*/rewriter,
3315  /*state=*/state,
3316  /*transformOp=*/*this,
3317  /*target=*/target,
3318  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3319  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3320  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3321  /*tilingResult=*/tilingResult);
3322  if (!diag.succeeded())
3323  return diag;
3324 
3325  results.push_back(tilingResult.tileOp);
3326  results.push_back(tilingResult.tiledOp);
3327  return DiagnosedSilenceableFailure::success();
3328 }
3329 
3330 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3331 
3332 #define GET_OP_CLASSES
3333 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"