LinalgTransformOps.cpp
1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Support/LLVM.h"
41 #include "mlir/Support/TypeID.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Debug.h"
47 #include <type_traits>
48 
49 using namespace mlir;
50 using namespace mlir::linalg;
51 using namespace mlir::transform;
52 
53 #define DEBUG_TYPE "linalg-transforms"
54 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 #define DBGSNL() (llvm::dbgs() << "\n")
56 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
57 
58 /// Attempts to apply the pattern specified as template argument to the given
59 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
60 /// function that returns the "main" result or failure. Returns failure if the
61 /// pattern failed to apply. Extra arguments are forwarded to the pattern
62 /// constructor.
63 template <typename PatternTy, typename... Args>
64 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
65  // Check if the given operation has the type expected by the pattern.
66  using OpTy = typename llvm::function_traits<
67  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
68  auto op = dyn_cast<OpTy>(operation);
69  if (!op)
70  return failure();
71 
72  // Apply the pattern directly to the op.
73  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
74  // We want to discourage direct use of PatternRewriter in APIs, but in this
75  // very specific case, an IRRewriter is not enough.
76  struct TrivialPatternRewriter : public PatternRewriter {
77  public:
78  explicit TrivialPatternRewriter(MLIRContext *context)
79  : PatternRewriter(context) {}
80  };
81  TrivialPatternRewriter rewriter(operation->getContext());
82  rewriter.setInsertionPoint(operation);
83  auto result = pattern.returningMatchAndRewrite(op, rewriter);
84  if (failed(result))
85  return failure();
86  return cast<LinalgOp>(result->getOperation());
87 }
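
// Editor's note (illustrative, not part of the upstream file): expanded, a
// typical use of `tryApply` looks like the DOWNSCALE macros in
// DecomposeOp::applyToOne further below:
//
//   FailureOr<LinalgOp> res =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(target);
//   // On success, `*res` is the replacement LinalgOp produced by the pattern.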
88 
89 /// Assuming that each `ofr` is an index attr, a param of index type, or a
90 /// transform dialect handle mapped to exactly one op with one index result,
91 /// append the corresponding index value to `result`.
92 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
93  transform::TransformState &state, TransformOpInterface transformOp,
94  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
95  for (OpFoldResult ofr : ofrs) {
96  if (ofr.is<Attribute>()) {
97  if (!isa<IntegerAttr>(ofr.get<Attribute>()))
98  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
99  result.push_back(ofr);
100  continue;
101  }
102 
103  Value transformValue = ofr.get<Value>();
104  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
105  ArrayRef<Attribute> params = state.getParams(transformValue);
106  if (params.size() != 1)
107  return transformOp.emitDefiniteFailure()
108  << "requires exactly one parameter associated";
109  result.push_back(params[0]);
110  continue;
111  }
112 
113  auto payloadOps = state.getPayloadOps(transformValue);
114  if (!llvm::hasSingleElement(payloadOps)) {
115  DiagnosedSilenceableFailure diag =
116  transformOp.emitSilenceableError()
117  << "handle must be mapped to exactly one payload op";
118  diag.attachNote(transformValue.getLoc())
119  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
120  return diag;
121  }
122 
123  Operation *op = *payloadOps.begin();
124  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
125  DiagnosedSilenceableFailure diag =
126  transformOp.emitSilenceableError()
127  << "payload op must have exactly 1 index result";
128  diag.attachNote(op->getLoc())
129  << "has " << op->getNumResults() << " results";
130  return diag;
131  }
132  result.push_back(op->getResult(0));
133  }
134 
135  return DiagnosedSilenceableFailure::success();
136 }
137 
138 // Given a list of params that are index attrs or a list of OpFoldResults
139 // that are either index attrs or op handles, return a list of OpFoldResults
140 // of index attrs or a list of OpFoldResults where all op handles are
141 // replaced with the first (and only) OpResult of that payload op.
142 // (There must be exactly one parameter associated with the AnyParamType or
143 // one mapped payload op which must have exactly one index result.)
144 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
145  transform::TransformState &state, TransformOpInterface transformOp,
146  SmallVector<OpFoldResult> &result, Value packedHandle) {
147  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
148  ArrayRef<Attribute> params = state.getParams(packedHandle);
149  for (auto param : params) {
150  if (!isa<IntegerAttr>(param))
151  return transformOp.emitDefiniteFailure()
152  << "expected the parameter to be associated with an integer "
153  "attribute";
154  result.push_back(param);
155  }
156  return DiagnosedSilenceableFailure::success();
157  }
158 
159  for (Operation *op : state.getPayloadOps(packedHandle)) {
160  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
161  DiagnosedSilenceableFailure diag =
162  transformOp.emitSilenceableError()
163  << "payload op must have exactly 1 index result";
164  diag.attachNote(op->getLoc())
165  << "has " << op->getNumResults() << " results";
166  return diag;
167  }
168  result.push_back(op->getResult(0));
169  }
170 
171  return DiagnosedSilenceableFailure::success();
172 }
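
// Editor's note (illustrative, not part of the upstream file): as a concrete
// example, a !transform.param handle carrying [2 : index, 4 : index] unpacks
// into those two attributes, while an op handle whose payload ops each produce
// a single `index` result unpacks into those results; anything else triggers
// the failures above.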
173 
174 //===----------------------------------------------------------------------===//
175 // Apply...PatternsOp
176 //===----------------------------------------------------------------------===//
177 
178 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
179  RewritePatternSet &patterns) {
181 }
182 
183 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
184  RewritePatternSet &patterns) {
187 }
188 
189 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
190  RewritePatternSet &patterns) {
192  options.rankReductionStrategy =
195 }
196 
197 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
198  RewritePatternSet &patterns) {
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // BufferizeToAllocationOp
204 //===----------------------------------------------------------------------===//
205 
206 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
207  OperationState &result,
208  Value target,
209  Attribute memorySpace) {
210  SmallVector<Type> resultTypes;
211  resultTypes.push_back(b.getType<transform::AnyValueType>());
212  resultTypes.push_back(b.getType<transform::AnyOpType>());
213  return build(b, result,
214  /*resultTypes=*/resultTypes,
215  /*target=*/target,
216  /*memorySpace=*/memorySpace);
217 }
218 
219 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
220  OperationState &result,
221  Value target,
222  int64_t memorySpace) {
223  SmallVector<Type> resultTypes;
224  resultTypes.push_back(b.getType<transform::AnyValueType>());
225  resultTypes.push_back(b.getType<transform::AnyOpType>());
226  return build(b, result,
227  /*resultTypes=*/resultTypes,
228  /*target=*/target,
229  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
230 }
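
// Editor's sketch (assumption, not upstream code): given an OpBuilder `b`, a
// Location `loc`, and a transform handle `targetHandle`, the integer overload
// above is reached as follows; memory space 3 is an arbitrary example value.
//
//   auto bufferizeOp = b.create<transform::BufferizeToAllocationOp>(
//       loc, targetHandle, /*memorySpace=*/int64_t(3));
//   Value allocatedBuffer = bufferizeOp.getAllocatedBuffer();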
231 
232 namespace {
233 class NewOpsListener : public RewriterBase::ForwardingListener {
234 public:
235  using RewriterBase::ForwardingListener::ForwardingListener;
236 
237  SmallVector<Operation *> getNewOps() const {
238  return SmallVector<Operation *>(newOps.begin(), newOps.end());
239  }
240 
241 private:
242  void notifyOperationInserted(Operation *op,
243  OpBuilder::InsertPoint previous) override {
244  ForwardingListener::notifyOperationInserted(op, previous);
245  // We only care about newly created ops.
246  if (previous.isSet())
247  return;
248  auto inserted = newOps.insert(op);
249  (void)inserted;
250  assert(inserted.second && "expected newly created op");
251  }
252 
253  void notifyOperationErased(Operation *op) override {
254  ForwardingListener::notifyOperationErased(op);
255  op->walk([&](Operation *op) { newOps.erase(op); });
256  }
257 
258  DenseSet<Operation *> newOps;
259 };
260 } // namespace
261 
262 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
263  transform::TransformRewriter &rewriter,
264  transform::TransformResults &results, transform::TransformState &state) {
265  // Attach listener to keep track of newly created ops.
266  OpBuilder::Listener *previousListener = rewriter.getListener();
267  auto resetListener =
268  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
269  NewOpsListener newOpsListener(previousListener);
270  rewriter.setListener(&newOpsListener);
271 
272  linalg::BufferizeToAllocationOptions options;
273  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
276  } else if (getMemcpyOp() == "memref.copy") {
277  options.memcpyOp =
279  } else if (getMemcpyOp() == "linalg.copy") {
280  options.memcpyOp =
282  } else {
283  llvm_unreachable("invalid memcpy op");
284  }
285  if (getAllocOp() == "memref.alloc") {
286  options.allocOp =
288  } else if (getAllocOp() == "memref.alloca") {
289  options.allocOp =
291  } else {
292  llvm_unreachable("invalid alloc op");
293  }
294  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
295  options.emitDealloc = getEmitDealloc();
296 
297  // Bufferize ops.
298  Attribute memorySpace =
299  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
300  SmallVector<Value> allocatedBuffers;
301  for (Operation *op : state.getPayloadOps(getTarget())) {
302  Value buffer =
303  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
304  if (!buffer) {
305  DiagnosedSilenceableFailure diag = emitSilenceableError()
306  << "failed to bufferize operation";
307  diag.attachNote(op->getLoc()) << "target payload op";
308  return diag;
309  }
310  allocatedBuffers.push_back(buffer);
311  }
312 
313  // Set results.
314  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
315  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
317 }
318 
319 void transform::BufferizeToAllocationOp::getEffects(
320  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
321  if (getBufferizeDestinationOnly()) {
322  // The destination is replaced with a newly allocated buffer, but the op
323  // itself remains in place.
324  onlyReadsHandle(getTarget(), effects);
325  } else {
326  consumesHandle(getTarget(), effects);
327  }
328  producesHandle(getAllocatedBuffer(), effects);
329  producesHandle(getNewOps(), effects);
330  modifiesPayload(effects);
331 }
332 
333 LogicalResult transform::BufferizeToAllocationOp::verify() {
334  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
335  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
336  return emitOpError() << "unsupported memcpy op";
337  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
338  return emitOpError() << "unsupported alloc op";
339  return success();
340 }
341 
342 //===----------------------------------------------------------------------===//
343 // DecomposeOp
344 //===----------------------------------------------------------------------===//
345 
346 DiagnosedSilenceableFailure
347 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
348  LinalgOp target,
349  transform::ApplyToEachResultList &results,
350  transform::TransformState &state) {
351 #define DOWNSCALE(trans) \
352  { \
353  FailureOr<LinalgOp> res = tryApply<trans>(target); \
354  if (succeeded(res)) { \
355  results.push_back(*res); \
356  return DiagnosedSilenceableFailure::success(); \
357  } \
358  }
359 
360 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
361 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
362 
363  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
364  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
365  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
366  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
367  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
368  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
369  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
370  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
371  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
374 #undef DOWNSCALE_NORMAL
375 #undef DOWNSCALE_CALL
376 #undef DOWNSCALE
377  return emitDefaultSilenceableFailure(target);
378 }
379 
380 //===----------------------------------------------------------------------===//
381 // DecomposeInterfaceOp
382 //===----------------------------------------------------------------------===//
383 
384 // Decompose the target operation if it implements the AggregatedOpInterface.
385 // Push the decomposed operations (the ones that replace the values produced by
386 // \p target) into `results`.
387 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
388  transform::TransformRewriter &rewriter, Operation *target,
389  transform::ApplyToEachResultList &results,
390  transform::TransformState &state) {
391  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
392  if (!decomposableOp) {
393  failed(rewriter.notifyMatchFailure(target,
394  "payload is not a decomposable op"));
395  return emitDefaultSilenceableFailure(target);
396  }
397 
398  FailureOr<SmallVector<Value>> maybeNewResults =
399  decomposableOp.decomposeOperation(rewriter);
400  if (failed(maybeNewResults))
401  return emitDefaultSilenceableFailure(target);
402 
403  rewriter.replaceOp(decomposableOp, *maybeNewResults);
404  for (Value val : *maybeNewResults) {
405  Operation *definition = val.getDefiningOp();
406  if (definition)
407  results.push_back(definition);
408  }
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // EliminateLinalgOpAnchoredEmptyTensorsOp
414 //===----------------------------------------------------------------------===//
415 
416 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
417  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
418  onlyReadsHandle(getTarget(), effects);
419  modifiesPayload(effects);
420 }
421 
422 DiagnosedSilenceableFailure
423 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
424  transform::TransformRewriter &rewriter, TransformResults &transformResults,
425  TransformState &state) {
426  bufferization::OneShotBufferizationOptions options;
427  options.allowReturnAllocsFromLoops = true;
428 
429  for (Operation *target : state.getPayloadOps(getTarget())) {
431  if (failed(analyzeOp(target, state)))
432  return mlir::emitSilenceableFailure(target->getLoc())
433  << "failed to analyze op";
434  if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
435  rewriter, target, state)))
436  return mlir::emitSilenceableFailure(target->getLoc())
437  << "failed to eliminate LinalgOp anchored tensor.empty ops";
438  }
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // FuseOp
444 //===----------------------------------------------------------------------===//
445 
446 /// Apply a tiling transformation to all payload ops and store both the
447 /// tiled operation as well as the created tile loops.
448 template <typename Range>
449 static LogicalResult applyTilingToAll(
450  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
451  unsigned numLoops, transform::TransformResults &transformResults,
452  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
453  applyFn) {
454  SmallVector<Operation *> tiledLinalgOps;
455  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
456 
457  for (Operation *target : payloadOps) {
458  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
459  if (!tilingInterfaceOp)
460  return transformOp->emitError("only TilingInterface ops are supported");
461 
462  rewriter.setInsertionPoint(target);
463  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
464  applyFn(tilingInterfaceOp);
465  if (failed(tiledResults))
466  return failure();
467 
468  // Perform the replacement of tiled and fused values.
469  SmallVector<Operation *> opsToReplace{target};
470  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
471  for (Operation *toReplace : opsToReplace) {
472  for (OpResult res : toReplace->getResults())
473  if (auto replacement = tiledResults->replacements.lookup(res))
474  rewriter.replaceAllUsesWith(res, replacement);
475  if (toReplace->use_empty()) {
476  rewriter.eraseOp(toReplace);
477  }
478  }
479 
480  // Report back the relevant handles to the transform op.
481  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
482  assert(tiledResults->loops.size() == numLoops &&
483  "Mismatched number of loops, tile and fuse transform should have "
484  "failed");
485  for (unsigned int i = 0; i < numLoops; ++i)
486  loopOps[i].push_back(tiledResults->loops[i]);
487  }
488 
489  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
490  for (unsigned int i = 0; i < numLoops; ++i)
491  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
492 
493  return success();
494 }
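
// Editor's note (not part of the upstream file): the helper above fixes the
// result layout that FuseOp relies on below: result 0 of the transform op
// receives the tiled (and fused) ops and results 1..numLoops receive one
// handle per generated loop.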
495 
496 DiagnosedSilenceableFailure
497 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
498  mlir::transform::TransformResults &transformResults,
499  mlir::transform::TransformState &state) {
500  SmallVector<int64_t> tileSizes =
501  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
502  SmallVector<int64_t> tileInterchange =
503  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
504 
505  scf::SCFTilingOptions tilingOptions;
506  tilingOptions.interchangeVector = tileInterchange;
507  SmallVector<OpFoldResult> tileSizesOfr =
508  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
509  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
510  scf::SCFTileAndFuseOptions tileAndFuseOptions;
511  tileAndFuseOptions.tilingOptions = tilingOptions;
512  LogicalResult result = applyTilingToAll(
513  rewriter, getOperation(), state.getPayloadOps(getTarget()),
514  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
515  [&](TilingInterface tilingInterfaceOp)
516  -> FailureOr<scf::SCFTileAndFuseResult> {
517  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
518  tileAndFuseOptions);
519  });
520  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
521  : DiagnosedSilenceableFailure::success();
522 }
523 
524 LogicalResult transform::FuseOp::verify() {
525  SmallVector<int64_t> permutation =
526  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
527  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
528  if (!std::is_permutation(sequence.begin(), sequence.end(),
529  permutation.begin(), permutation.end())) {
530  return emitOpError() << "expects interchange to be a permutation, found "
531  << getTileInterchange();
532  }
533 
534  SmallVector<int64_t> sizes =
535  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
536  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
537  if (numExpectedLoops != getNumResults() - 1)
538  return emitOpError() << "expects " << numExpectedLoops << " loop results";
539 
540  return success();
541 }
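
// Editor's note (illustrative, not part of the upstream file): for example,
// with tile_sizes = [8, 0, 16] the zero entry produces no loop, so
// numExpectedLoops is 2 and the op must declare 1 + 2 results: the fused op
// handle plus one handle per loop.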
542 
543 //===----------------------------------------------------------------------===//
544 // FuseIntoContainingOp
545 //===----------------------------------------------------------------------===//
546 
547 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
548  OperationState &result,
549  Value producerOp,
550  Value containingOp) {
551  result.addOperands({producerOp, containingOp});
552  auto resultType = transform::AnyOpType::get(builder.getContext());
553  result.addTypes({resultType, resultType});
554 }
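
// Editor's sketch (assumption, not upstream code): the builder above is
// reached through the generic create API with the two payload handles:
//
//   auto fuseOp = b.create<transform::FuseIntoContainingOp>(
//       loc, producerHandle, forallHandle);
//   Value fused = fuseOp.getFusedOp();
//   Value newContaining = fuseOp.getNewContainingOp();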
555 
556 /// Add new operands to the forall op for users of the producerOp
557 /// that are dominated by the containing scf.forall op.
558 static Operation *replaceForAllWithNewSignature(
559  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
560  Operation *containingOp, TilingResult &tileAndFuseResult,
561  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
562  SmallVector<OpFoldResult> &sizes) {
563 
564  // Count number of users not including the containing op
565  SetVector<Operation *> dominatedUsers;
566  DominanceInfo domInfo(containingOp);
567  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
568  if (!containingOp->isAncestor(user) &&
569  (domInfo.dominates(containingOp, user))) {
570  dominatedUsers.insert(user);
571  }
572  }
573  if (dominatedUsers.empty())
574  return nullptr;
575 
576  // Create new scf.forall op
577  auto forallOp = cast<scf::ForallOp>(containingOp);
578  OpBuilder::InsertionGuard g(rewriter);
579  rewriter.setInsertionPoint(forallOp);
580 
581  // Get new output
582  Location loc = forallOp.getLoc();
583  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
584  if (!genericOp)
585  return nullptr;
586  SmallVector<Value> outputs = genericOp.getOutputs();
587  SmallVector<Value> newOuts(forallOp.getOutputs());
588  newOuts.push_back(outputs[resultNumber]);
589 
590  // Create new scf.forall op
591  auto newforallOp = rewriter.create<scf::ForallOp>(
592  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
593  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
594  rewriter.eraseBlock(newforallOp.getBody());
595  newforallOp.getRegion().takeBody(forallOp.getRegion());
596 
597  // Add additional block argument for new value being returned
599  // and replace all uses of the new output with the corresponding bbArg
599  // inside the scf.forall to enable fusion into this new scf.forall.
600  newforallOp.getBody()->addArgument(newOuts.back().getType(),
601  newOuts.back().getLoc());
602  auto bbArgs = newforallOp.getBody()->getArguments();
603  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
604  [&](OpOperand &use) {
605  Operation *op = use.getOwner();
606  return newforallOp->isProperAncestor(op);
607  });
608 
609  // Fix terminator
610  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
611  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
612  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
613  Operation *firstYieldOp = yieldingOps.front();
614  rewriter.setInsertionPoint(firstYieldOp);
615  Value src = tileAndFuseResult.tiledValues[0];
616  Value dst = newforallOp.getRegionIterArgs().back();
617  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
618  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
619  dst, offsets, sizes, strides);
620 
621  for (auto result : llvm::enumerate(forallOp.getResults())) {
622  rewriter.replaceAllUsesWith(result.value(),
623  newforallOp->getResult(result.index()));
624  }
625  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
626  newforallOp->getResults().back(),
627  [&](OpOperand &use) {
628  Operation *user = use.getOwner();
629  return dominatedUsers.contains(user);
630  });
631  return newforallOp;
632 }
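
// Editor's note (schematic sketch, not part of the upstream file):
// conceptually the rewrite above turns
//   %r = scf.forall ... shared_outs(%out = %init) { ... }
// into
//   %r:2 = scf.forall ... shared_outs(%out = %init, %out2 = %producerInit) {
//     ... tensor.parallel_insert_slice %tiledProducer into %out2 ...
//   }
// and redirects uses of the producer result that are dominated by the forall
// to %r#1; %producerInit and %tiledProducer are illustrative names.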
633 
634 /// Find the first "extract" user of `producerOp` and tile it right before its
635 /// use. The tiled op is fused under the `containingOp`.
636 /// Return this fused op on success or nullptr if anything fails.
637 /// If the tiled op has uses that are dominated by `containingOp`, return
638 /// a new `containingOp` with the results of the fused op appended to the
639 /// results of the `containingOp`, or nullptr if there are no dominated uses.
640 static std::tuple<SmallVector<Operation *>, Operation *>
641 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
642  Operation *producerOp, Operation *containingOp) {
643  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
644  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
645  if (!tileableProducer) {
646  diag.attachNote(producerOp->getLoc())
647  << "producer is not a TileableInterface: " << *producerOp;
648  return {};
649  }
650 
651  // Search the producer slices accessed within the containing operation.
652  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
653  // evolve into an interface.
654  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
655  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
656  return sliceOp && containingOp->isProperAncestor(sliceOp);
657  });
658 
659  // Find a fusion opportunity.
660  if (it == tileableProducer->getUsers().end()) {
661  diag.attachNote(tileableProducer->getLoc())
662  << "could not find fusion opportunity for: " << *tileableProducer;
663  return {};
664  }
665  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
666 
667  // Try to fuse the producer in-place.
668  OpBuilder::InsertionGuard guard(rewriter);
669  rewriter.setInsertionPoint(sliceOpToTile);
670 
671  // Tile the producer.
672  int64_t resultNumber =
673  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
674  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
675 
676  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
677  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
678 
679  FailureOr<TilingResult> tileAndFuseResult =
680  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
681  sizes);
682 
683  if (failed(tileAndFuseResult)) {
684  diag.attachNote(tileableProducer->getLoc())
685  << "failed to tile producer op: " << *tileableProducer;
686  return {};
687  }
688 
689 #ifndef NDEBUG
690  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
691  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
692  }
693 #endif
694 
695  // Replace the extract op.
696  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
697  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
698  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
699  if (failed(maybeRankReduced)) {
700  diag.attachNote(producerOp->getLoc())
701  << "shape types don't match (missing canonicalization?):\nTiledOp: "
702  << tileAndFuseResult->tiledValues[0]
703  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
704  return {};
705  }
706  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
707 
708  // Add new outputs to containing op, if required
709  Operation *newContainingOp = replaceForAllWithNewSignature(
710  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
711  resultNumber, offsets, sizes);
712 
713  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
714 }
715 
716 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
717 /// it is exactly the `containingOp`, otherwise bail.
718 /// Then, find the first "extract" user of the tied block argument and tile it
719 /// right before its "extract" use. The tiled op is fused under the
720 /// `containingOp`.
721 /// Return this fused op on success or nullptr if anything fails.
722 static SmallVector<Operation *>
723 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
724  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
725  Operation *containingOp) {
726  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
727 
728  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
729  if (!tileableProducer) {
730  diag.attachNote(producerOp->getLoc())
731  << "producer is not a TileableInterface: " << *producerOp;
732  return {};
733  }
734 
735  // Search the first use by a "scf::ForallOp" user.
736  scf::ForallOp forallOp;
737  auto itProducerUses =
738  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
739  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
740  return forallOp;
741  });
742  // If it's not from the containing op, return.
743  if (!forallOp || forallOp != containingOp) {
744  diag.attachNote(tileableProducer->getLoc())
745  << "could not find a use by the containing op: " << *tileableProducer;
746  return {};
747  }
748 
749  // Search the producer slices accessed within the containing
750  // operation.
751  // TODO: Generalize to more extract/insert/parallel_insert triples.
752  // Maybe evolve into an interface.
753  OpOperand *pUse = &(*itProducerUses);
754  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
755 
756  // Search the producer slices accessed within the containing operation.
757  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
758  // evolve into an interface.
759  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
760  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
761  return sliceOp && containingOp->isProperAncestor(sliceOp);
762  });
763 
764  // Find a fusion opportunity.
765  if (itBBArgUsers == bbArg.getUsers().end()) {
766  diag.attachNote(containingOp->getLoc())
767  << "could not find fusion opportunity for bbArg: " << bbArg;
768  return {};
769  }
770  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
771 
772  // Try to fuse the producer in-place.
773  OpBuilder::InsertionGuard guard(rewriter);
774  rewriter.setInsertionPoint(sliceOpToTile);
775 
776  // Replace the use in the tileableProducer before tiling: clone, replace and
777  // then tile.
778  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
779  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
780 
781  // Gather destination tensors.
782  SmallVector<Value> destinationTensors;
783  if (failed(tensor::getOrCreateDestinations(
784  rewriter, tileableProducer->getLoc(), tileableProducer,
785  destinationTensors))) {
786  diag.attachNote(tileableProducer->getLoc())
787  << "failed to get destination tensors for: " << *tileableProducer;
788  return {};
789  }
790 
791  IRMapping bvm;
792  bvm.map(destinationTensors[resultNumber], bbArg);
793  auto tileableProducerClone =
794  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
795  auto scopeGuard =
796  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
797 
798  // Tile the producer.
799  FailureOr<TilingResult> tileAndFuseResult =
800  tileableProducerClone.generateResultTileValue(
801  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
802  sliceOpToTile.getMixedSizes());
803  if (failed(tileAndFuseResult)) {
804  diag.attachNote(tileableProducer->getLoc())
805  << "failed to tile producer op: " << *tileableProducer;
806  return {};
807  }
808 
809  // Replace the extract op.
810  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
811  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
812  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
813  assert(succeeded(maybeRankReduced) && "unexpected shape");
814  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
815 
816  // Replace the use in containingOp.
817  rewriter.modifyOpInPlace(containingOp, [&]() {
818  containingOp->setOperand(pUse->getOperandNumber(),
819  destinationTensors.front());
820  });
821 
822  return tileAndFuseResult->tiledOps;
823 }
824 
825 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
826  Operation *producerOp,
827  Operation *containingOp) {
828  LLVM_DEBUG(DBGS() << "Try to fuse a use by cloning\n");
829 
830  // Gather all uses inside the containing op.
831  SmallVector<OpOperand *> uses;
832  for (OpResult result : producerOp->getOpResults()) {
833  for (OpOperand &use : result.getUses()) {
834  if (containingOp->isProperAncestor(use.getOwner())) {
835  uses.push_back(&use);
836  continue;
837  }
838  // Cannot clone and fuse if the use is by the containing op itself: fail
839  // immediately.
840  if (containingOp == use.getOwner()) {
841  diag.attachNote(producerOp->getLoc())
842  << "producer op use by containing op cannot be fused by cloning";
843  return nullptr;
844  }
845  }
846  }
847 
848  // Check for a non-empty list of fusion opportunities.
849  if (uses.empty()) {
850  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
851  return nullptr;
852  }
853 
854  // Clone and fuse inside the containing op.
855  Operation *fusedOp = nullptr;
856  OpOperand *use = uses.front();
857  // Parallel insert slice is not a valid clone destination.
858  // TODO: Generalize to other type of ops.
859  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
860  "Parallel insert slice is not a valid clone destination");
861  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
862  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
863 
864  OpBuilder::InsertionGuard guard(rewriter);
865  rewriter.setInsertionPoint(use->getOwner());
866  fusedOp = rewriter.clone(*producerOp);
867  rewriter.modifyOpInPlace(
868  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
869 
870  return fusedOp;
871 }
872 
873 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
874  // Allow repeated handles since we are fusing everything anyway.
875  return true;
876 }
877 
878 DiagnosedSilenceableFailure
879 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
880  transform::TransformResults &results,
881  transform::TransformState &state) {
882  SmallVector<Operation *> fusedOps;
883  auto producerOps = state.getPayloadOps(getProducerOp());
884  auto containingOps = state.getPayloadOps(getContainingOp());
885  if (!llvm::hasSingleElement(containingOps)) {
886  return emitDefiniteFailure()
887  << "requires exactly one containing_op handle (got "
888  << llvm::range_size(containingOps) << ")";
889  }
890  Operation *containingOp = *containingOps.begin();
891 
892  // If nothing to fuse, propagate success.
893  if (std::empty(producerOps)) {
894  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
895  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
896  return DiagnosedSilenceableFailure::success();
897  }
898 
899  // Helper function to find the next producer that should be fused. Take any
900  // producer that has a use inside the containing op.
901  SetVector<Operation *> remainingProducers(producerOps.begin(),
902  producerOps.end());
903  auto getNextProducer = [&]() -> FailureOr<Operation *> {
904  for (const auto &it : enumerate(remainingProducers)) {
905  Operation *producerOp = it.value();
906  // The containing op may be a user of producerOp: use isAncestor.
907  int64_t numUsesInContainingOp =
908  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
909  return containingOp->isAncestor(op);
910  });
911  // TODO: When resolving the TODO below (no duplicate ops), take an op
912  // that has no use among the remaining producers. This is a topological
913  // sorting.
914  if (numUsesInContainingOp > 0) {
915  if (numUsesInContainingOp == 1)
916  remainingProducers.erase(remainingProducers.begin() + it.index());
917  return producerOp;
918  }
919  }
920  return failure();
921  };
922 
923  while (!remainingProducers.empty()) {
924  auto nextProducer = getNextProducer();
925  if (failed(nextProducer)) {
926  auto diag = mlir::emitSilenceableFailure(getLoc())
927  << "could not find next producer to fuse into container";
928  diag.attachNote(containingOp->getLoc()) << "containing op";
929  return diag;
930  }
931 
932  Operation *producerOp = *nextProducer;
933 
934  // Default diagnostic, to be complemented with more failure information.
935  Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
936  diag << "could not fuse " << *producerOp << " into " << *containingOp;
937 
938  // TODO: If there are multiple uses of the producer in the containing op,
939  // we currently tile/clone the op multiple times (once per use). In some
940  // cases, we can tile/clone once and reuse the value for each use.
941  // Furthermore, producers should then be traversed according to a
942  // topological sorting.
943  auto [tiledOps, newContainingOp] =
944  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
945  if (!tiledOps.empty()) {
946  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
947  fusedOps.append(tiledOps);
948  if (newContainingOp) {
949  // Update handles associated with the containing op so we don't need to
950  // invalidate them. This is a hack to support better composability
951  // between tiling and fusion while a proper mechanism is being
952  // investigated.
953  //
954  // DO NOT replicate this elsewhere unless you understand what you are
955  // doing.
956  LogicalResult replacementStatus =
957  rewriter.notifyPayloadOperationReplaced(containingOp,
958  newContainingOp);
959  (void)replacementStatus;
960  assert(succeeded(replacementStatus) &&
961  "unable to update transform state mapping");
962  rewriter.eraseOp(containingOp);
963  containingOp = newContainingOp;
964  }
965  continue;
966  }
967 
968  SmallVector<Operation *> tiledContainingOpOperand =
969  tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
970  rewriter, diag, producerOp, containingOp);
971  if (!tiledContainingOpOperand.empty()) {
972  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
973  << *containingOp);
974  fusedOps.append(tiledContainingOpOperand);
975  continue;
976  }
977 
978  Operation *cloned =
979  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
980  if (cloned) {
981  LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
982  fusedOps.push_back(cloned);
983  continue;
984  }
985  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
986  }
987 
988  results.set(cast<OpResult>(getFusedOp()), fusedOps);
989  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
990  return DiagnosedSilenceableFailure::success();
991 }
992 
993 void transform::FuseIntoContainingOp::getEffects(
994  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
995  consumesHandle(getProducerOp(), effects);
996  onlyReadsHandle(getContainingOp(), effects);
997  producesHandle(getResults(), effects);
998  modifiesPayload(effects);
999 }
1000 
1001 //===----------------------------------------------------------------------===//
1002 // GeneralizeOp
1003 //===----------------------------------------------------------------------===//
1004 
1005 DiagnosedSilenceableFailure
1006 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1007  LinalgOp target,
1008  transform::ApplyToEachResultList &results,
1009  transform::TransformState &state) {
1010  // Exit early if no transformation is needed.
1011  if (isa<GenericOp>(target)) {
1012  results.push_back(target);
1014  }
1015  rewriter.setInsertionPoint(target);
1016  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1017  if (succeeded(generic)) {
1018  results.push_back(generic->getOperation());
1020  }
1021  return emitDefaultSilenceableFailure(target);
1022 }
1023 
1024 //===----------------------------------------------------------------------===//
1025 // SpecializeOp
1026 //===----------------------------------------------------------------------===//
1027 
1028 DiagnosedSilenceableFailure
1029 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1030  LinalgOp target,
1031  transform::ApplyToEachResultList &results,
1032  transform::TransformState &state) {
1033  // Exit early if the operation is not a generic.
1034  if (!isa<GenericOp>(target)) {
1035  results.push_back(target);
1037  }
1038  rewriter.setInsertionPoint(target);
1039  FailureOr<LinalgOp> named =
1040  specializeGenericOp(rewriter, cast<GenericOp>(target));
1041  if (succeeded(named)) {
1042  results.push_back(named->getOperation());
1044  }
1045  return emitDefaultSilenceableFailure(target);
1046 }
1047 
1048 //===----------------------------------------------------------------------===//
1049 // InterchangeOp
1050 //===----------------------------------------------------------------------===//
1051 
1052 DiagnosedSilenceableFailure
1053 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1054  GenericOp target,
1055  transform::ApplyToEachResultList &results,
1056  transform::TransformState &state) {
1057  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1058  // Exit early if no transformation is needed.
1059  if (interchangeVector.empty()) {
1060  results.push_back(target);
1062  }
1063 
1064  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1065  if (interchangeVector.size() != numLoops) {
1066  return emitSilenceableError()
1067  << getIteratorInterchangeAttrName() << " has length ("
1068  << interchangeVector.size()
1069  << ") different from the number of loops in the target operation ("
1070  << numLoops << ")";
1071  }
1072  FailureOr<GenericOp> res =
1073  interchangeGenericOp(rewriter, target,
1074  SmallVector<unsigned>(interchangeVector.begin(),
1075  interchangeVector.end()));
1076  if (failed(res))
1077  return emitDefiniteFailure() << "failed to apply";
1078  results.push_back(res->getOperation());
1080 }
1081 
1082 LogicalResult transform::InterchangeOp::verify() {
1083  ArrayRef<int64_t> permutation = getIteratorInterchange();
1084  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1085  if (!std::is_permutation(sequence.begin(), sequence.end(),
1086  permutation.begin(), permutation.end())) {
1087  return emitOpError()
1088  << "expects iterator_interchange to be a permutation, found "
1089  << getIteratorInterchange();
1090  }
1091  return success();
1092 }
1093 
1094 //===----------------------------------------------------------------------===//
1095 // LowerPackOp
1096 //===----------------------------------------------------------------------===//
1097 
1098 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1099  transform::TransformRewriter &rewriter, tensor::PackOp target,
1100  transform::ApplyToEachResultList &transformResults,
1101  transform::TransformState &state) {
1102  rewriter.setInsertionPoint(target);
1103  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
1104  if (failed(res)) {
1105  return mlir::emitSilenceableFailure(target->getLoc())
1106  << "cannot lower to pad + expand + transpose";
1107  }
1108  transformResults.push_back(res->padOp);
1109  transformResults.push_back(res->expandShapeOp);
1110  transformResults.push_back(res->transposeOp);
1112 }
1113 
1114 //===----------------------------------------------------------------------===//
1115 // LowerUnPackOp
1116 //===----------------------------------------------------------------------===//
1117 
1118 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1119  transform::TransformRewriter &rewriter, tensor::UnPackOp target,
1120  transform::ApplyToEachResultList &transformResults,
1121  transform::TransformState &state) {
1122  rewriter.setInsertionPoint(target);
1123  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
1124  if (failed(res)) {
1125  DiagnosedSilenceableFailure diag =
1126  emitSilenceableError()
1127  << "cannot lower to transpose + collapse + extract";
1128  diag.attachNote(target->getLoc()) << "target payload op";
1129  return diag;
1130  }
1131  transformResults.push_back(res->emptyOp);
1132  transformResults.push_back(res->transposeOp);
1133  transformResults.push_back(res->collapseShapeOp);
1134  transformResults.push_back(res->extractSliceOp);
1136 }
1137 
1138 //===---------------------------------------------------------------------===//
1139 // MatchOp
1140 //===---------------------------------------------------------------------===//
1141 
1142 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1143  Value target, ArrayRef<StringRef> opNames) {
1144  result.addOperands(target);
1145  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1146  builder.getStrArrayAttr(opNames));
1147  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1148 }
1149 
1150 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1151  TypeRange resultTypes, Value target,
1152  ArrayRef<StringRef> opNames) {
1153  result.addOperands(target);
1154  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1155  builder.getStrArrayAttr(opNames));
1156  result.addTypes(resultTypes);
1157 }
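
// Editor's sketch (assumption, not upstream code): a typical programmatic use
// of the first builder above, matching named ops under a payload root handle:
//
//   auto matchOp = b.create<transform::MatchOp>(
//       loc, rootHandle, ArrayRef<StringRef>{"linalg.matmul", "linalg.generic"});
//   Value matched = matchOp.getResult();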
1158 
1159 DiagnosedSilenceableFailure
1160 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1161  transform::TransformResults &results,
1162  transform::TransformState &state) {
1163  llvm::StringSet<> strs;
1164  if (getOps().has_value())
1165  strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1166  getOps()->getAsValueRange<StringAttr>().end());
1167 
1168  auto payloadOps = state.getPayloadOps(getTarget());
1169  if (!llvm::hasSingleElement(payloadOps)) {
1170  return emitDefiniteFailure("requires exactly one target handle");
1171  }
1172 
1173  SmallVector<Operation *> res;
1174  bool incorrectNumOperandTypes = false;
1175  auto matchFun = [&](Operation *op) {
1176  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1177  return;
1178 
1179  // Interfaces cannot be matched by name, just by ID.
1180  // So we specifically encode the interfaces we care about for this op.
1181  if (getInterface().has_value()) {
1182  auto iface = getInterface().value();
1183  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1184  !isa<LinalgOp>(op))
1185  return;
1186  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1187  !isa<TilingInterface>(op))
1188  return;
1189  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1190  !isa<LoopLikeOpInterface>(op))
1191  return;
1192  }
1193 
1194  // Check if all specified attributes match.
1195  if (getOpAttrs().has_value()) {
1196  DictionaryAttr opAttrs = getOpAttrs().value();
1197  for (NamedAttribute attr : opAttrs) {
1198  if (attr.getName() == getInterfaceAttrName() ||
1199  attr.getName() == getOpsAttrName())
1200  continue;
1201  if (!op->hasAttr(attr.getName()))
1202  return;
1203  if (op->getAttr(attr.getName()) != attr.getValue())
1204  return;
1205  }
1206  }
1207 
1208  if (getFilterResultType().has_value()) {
1209  Type t = getFilterResultType().value();
1210  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1211  return;
1212  }
1213 
1214  if (getFilterOperandTypes().has_value()) {
1215  mlir::ArrayAttr types = getFilterOperandTypes().value();
1216  auto operandTypes = op->getOperandTypes();
1217 
1218  if (types.size() == 1) {
1219  // All the operands must be equal to the specified type.
1220  auto typeattr =
1221  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1222  Type t = typeattr.getValue().cast<::mlir::Type>();
1223  if (!llvm::all_of(op->getOperandTypes(),
1224  [&](Type operandType) { return operandType == t; }))
1225  return;
1226  } else {
1227  // The operand types must match all the types in the list (in the same
1228  // order in which they are specified).
1229  if (types.size() != operandTypes.size()) {
1230  incorrectNumOperandTypes = true;
1231  return;
1232  }
1233 
1234  for (auto [attr, operandType] :
1235  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1236  auto typeattr = cast<mlir::TypeAttr>(attr);
1237  Type type = typeattr.getValue().cast<::mlir::Type>();
1238 
1239  if (type != operandType)
1240  return;
1241  }
1242  }
1243  }
1244 
1245  // All constraints are satisfied.
1246  res.push_back(op);
1247  return;
1248  };
1249 
1250  (*payloadOps.begin())->walk(matchFun);
1251  if (incorrectNumOperandTypes)
1252  return emitDefiniteFailure("If filter_operand_types contains more than "
1253  "one type, then it must contain as many types as "
1254  "the number of operands in the target ops");
1255  results.set(cast<OpResult>(getResult()), res);
1257 }
1258 
1259 //===---------------------------------------------------------------------===//
1260 // MultiTileSizesOp
1261 //===---------------------------------------------------------------------===//
1262 
1264  Type targetType, Type lowSizeType, Type,
1265  Type) {
1266  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1267 }
1268 
1270  Type &targetType, Type &lowSizeType,
1271  Type &highSizeType,
1272  Type &splitPointType) {
1273  FunctionType funcType;
1274  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1275  if (failed(parser.parseType<FunctionType>(funcType)))
1276  return failure();
1277 
1278  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1279  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1280  "argument and one result";
1281  }
1282  targetType = funcType.getInput(0);
1283  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1284 
1285  return success();
1286 }
1287 
1288 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1289  transform::TransformRewriter &rewriter, LinalgOp target,
1290  transform::ApplyToEachResultList &results, TransformState &state) {
1291  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1292  if (target.hasDynamicShape()) {
1293  auto diag = emitSilenceableError()
1294  << "cannot compute parametric tile sizes for dynamically "
1295  "shaped payload op";
1296  diag.attachNote(target->getLoc()) << "payload op";
1297  return diag;
1298  }
1299 
1300  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1301  target, getDimension(), getTargetSize(), getDivisor());
1302  if (failed(spec)) {
1303  return emitSilenceableError()
1304  << "failed to compute multi-size tiling sizes";
1305  }
1306 
1307  Builder builder(target.getContext());
1308  results.assign(llvm::map_range(
1309  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1310  spec->lowTileSize * spec->lowTripCount}),
1311  [&builder, this](int64_t value) {
1312  return builder.getIntegerAttr(
1313  cast<ParamType>(getLowSize().getType()).getType(), value);
1314  }));
1316  }
1317 
1318  OpBuilder builder(target.getContext());
1319  builder.setInsertionPoint(target);
1320  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1321  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1322  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1323  builder, target, getDimension(), targetSize, divisor);
1324  if (failed(spec)) {
1325  return emitSilenceableError() << "could not generate tile size computation";
1326  }
1327 
1328  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1329  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1330  Operation *splitPoint =
1331  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1332  {spec->lowTileSize, spec->lowTripCount});
1333  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1334  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1335  assert(lowTileSize && highTileSize && splitPoint &&
1336  "tile sizes are not produced by operations");
1337  results.reserve(results.size() + 3);
1338  results.push_back(lowTileSize);
1339  results.push_back(highTileSize);
1340  results.push_back(splitPoint);
1342 }
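
// Editor's note (illustrative, not part of the upstream file): the split point
// computed above is lowTileSize * lowTripCount, e.g. a low tile size of 4 used
// for 8 tiles gives a split point of 32; the rest of the dimension is covered
// by tiles of highTileSize.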
1343 
1344 void transform::MultiTileSizesOp::getEffects(
1345  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1346  onlyReadsHandle(getTarget(), effects);
1347  producesHandle(getResults(), effects);
1348  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1349  onlyReadsPayload(effects);
1350  else
1351  modifiesPayload(effects);
1352 }
1353 
1354 LogicalResult transform::MultiTileSizesOp::verify() {
1355  if (getLowSize().getType() != getHighSize().getType() ||
1356  getLowSize().getType() != getSplitPoint().getType()) {
1357  return emitOpError() << "expects all results type to be the same";
1358  }
1359  return success();
1360 }
1361 
1362 //===---------------------------------------------------------------------===//
1363 // PackOp
1364 //===---------------------------------------------------------------------===//
1365 
1366 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1367  Value target,
1368  ArrayRef<OpFoldResult> mixedPackedSizes) {
1369  SmallVector<int64_t> staticPackedSizes;
1370  SmallVector<Value> dynamicPackedSizes;
1371  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1372  staticPackedSizes);
1373  // Call the default builder which sets up the proper operands segment sizes
1374  // attributes for multiple variadic operands. In the absence of this, horrible
1375  // bugs ensue.
1376  Type linalgOpHType = transform::OperationType::get(
1377  builder.getContext(), GenericOp::getOperationName());
1378  build(builder, result,
1379  /*resultType=*/linalgOpHType,
1380  /*target=*/target,
1381  /*dynamic_sizes=*/dynamicPackedSizes,
1382  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1383 }
1384 
1385 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1386  Builder b(getContext());
1387  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1388 }
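
// Editor's sketch (assumption, not upstream code): packing a matmul handle
// with static sizes via the builder above; the mixed sizes may also be SSA
// values produced by other transform ops.
//
//   SmallVector<OpFoldResult> sizes = {b.getIndexAttr(8), b.getIndexAttr(16),
//                                      b.getIndexAttr(32)};
//   auto packOp = b.create<transform::PackOp>(loc, matmulHandle, sizes);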
1389 
1390 DiagnosedSilenceableFailure
1391 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1392  transform::TransformResults &transformResults,
1393  transform::TransformState &state) {
1394  auto targetOps = state.getPayloadOps(getTarget());
1395  // If nothing to pack, propagate success.
1396  if (std::empty(targetOps)) {
1397  transformResults.set(cast<OpResult>(getPackedOp()),
1398  ArrayRef<Operation *>({}));
1400  }
1401  // Fail on multi-op handles.
1402  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1403  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1404  return emitSilenceableError()
1405  << "requires target to map to exactly 1 LinalgOp (got "
1406  << llvm::range_size(targetOps) << ")";
1407  }
1408  // Fail on mismatched number of pack sizes.
1409  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1410  return emitSilenceableError()
1411  << "requires number of packed sizes match the number of loops ("
1412  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1413  << ")";
1414  }
1415 
1416  // Unpack handles to constants or actual SSA index values.
1417  SmallVector<OpFoldResult> packedSizes;
1418  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
1419  state, *this, packedSizes, getMixedPackedSizes());
1420 
1421  rewriter.setInsertionPoint(linalgOp);
1422  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1423  if (failed(maybeResult))
1424  return emitDefiniteFailure("data tiling failed");
1425 
1426  transformResults.set(cast<OpResult>(getPackedOp()),
1427  {maybeResult->packedLinalgOp.getOperation()});
1429 }
1430 
1431 void transform::PackOp::getEffects(
1432  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1433  transform::consumesHandle(getTarget(), effects);
1434  transform::onlyReadsHandle(getPackedSizes(), effects);
1435  transform::producesHandle(getPackedOp(), effects);
1436  transform::modifiesPayload(effects);
1437 }
1438 
1439 //===---------------------------------------------------------------------===//
1440 // PackGreedilyOp.
1441 //===---------------------------------------------------------------------===//
1442 
1443 LogicalResult transform::PackGreedilyOp::verify() {
1444  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1445  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1446  << " is not a valid permutation";
1447  }
1448  // TODO: relax to allow empty once we have another strategy than just matmul.
1449  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1450  for (auto [s, nmo] :
1451  llvm::zip_equal(getMixedMatmulPackedSizes(),
1452  getMatmulPaddedSizesNextMultipleOf())) {
1453  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1454  if (nmo != 0 &&
1455  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1456  return emitOpError() << "at most one of the packed_size and the "
1457  "padded_sizes_next_multiple_of can be nonzero "
1458  "for the matmul strategy";
1459  }
1460  }
1461  }
1462  return success();
1463 }
1464 
1465 DiagnosedSilenceableFailure
1466 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1467  transform::TransformResults &transformResults,
1468  transform::TransformState &state) {
1469  SmallVector<Operation *> results;
1470  for (Operation *op : state.getPayloadOps(getTarget())) {
1471  auto linalgOp = dyn_cast<LinalgOp>(op);
1472  if (!linalgOp)
1473  continue;
1474  // linalgOp will be replaced and the insertion point may be invalidated if
1475  // we set it before the op, so set it after instead.
1476  rewriter.setInsertionPointAfter(linalgOp);
1477  // Failing to pack greedily is perfectly fine.
1478  // In the future we will want to order packings according to some metric.
1479  FailureOr<PackResult> packResult = packMatmulGreedily(
1480  /*rewriter=*/rewriter,
1481  /*linalgOp=*/linalgOp,
1482  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1483  /*mnkPaddedSizesNextMultipleOf=*/
1484  getMatmulPaddedSizesNextMultipleOf(),
1485  /*mnkOrder=*/getMatmulInnerDimsOrder());
1486  if (succeeded(packResult)) {
1487  results.push_back(packResult->packedLinalgOp);
1488  continue;
1489  }
1490  results.push_back(linalgOp);
1491  }
1492  transformResults.set(cast<OpResult>(getPackedOp()), results);
1494 }
1495 
1496 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1497  Builder b(getContext());
1498  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1499  b);
1500 }
1501 
1502 void transform::PackGreedilyOp::getEffects(
1503  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1504  transform::consumesHandle(getTarget(), effects);
1505  transform::onlyReadsHandle(getMatmulPackedSizes(), effects);
1506  transform::producesHandle(getPackedOp(), effects);
1507  transform::modifiesPayload(effects);
1508 }
1509 
1510 //===---------------------------------------------------------------------===//
1511 // PackTransposeOp
1512 //===---------------------------------------------------------------------===//
1513 
1514 LogicalResult transform::PackTransposeOp::verify() {
1513 
1515  if (!isPermutationVector(getInnerPerm())) {
1516  return emitOpError() << getInnerPermAttrName()
1517  << " is not a valid permutation";
1518  }
1519  if (!isPermutationVector(getOuterPerm())) {
1520  return emitOpError() << getOuterPermAttrName()
1521  << " is not a valid permutation";
1522  }
1523  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1524  return emitOpError() << " at least one of " << getInnerPermAttrName()
1525  << " or " << getOuterPermAttrName()
1526  << " must be specified";
1527  }
1528  return success();
1529 }
1530 
1531 namespace {
1532 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1533 } // namespace
1534 
1535 /// Return true if `permutation` is a valid permutation of the
1536 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1537 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
1538 /// This is the case when the `permutation` rank matches the rank expected by
1539 /// `op` and `permutation` is itself a permutation vector.
1540 /// Return true if either `op` or `permutation` are empty to allow a simpler
1541 /// polymorphic implementation.
1542 template <typename RelayoutOpTy>
1543 static bool isValidPackingPermutation(
1544  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1545  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1546  static_assert(
1547  llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1548  "applies to only pack or unpack operations");
1549  if (!op || permutation.empty())
1550  return true;
1551  size_t innerRank = op.getInnerDimsPos().size();
1552  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1553  return permutation.size() == innerRank && isPermutationVector(permutation);
1554  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1555  // Don't rely on it.
1556  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1557  return permutation.size() == op.getSourceRank() &&
1558  isPermutationVector(permutation);
1559  }
1560  return permutation.size() == op.getDestRank() &&
1561  isPermutationVector(permutation);
1562 }
1563 
1564 DiagnosedSilenceableFailure
1565 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1566  transform::TransformResults &transformResults,
1567  transform::TransformState &state) {
1568  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1569  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1570  // Step 1. If nothing to pack, propagate success.
1571  if (std::empty(packOrUnpackOps)) {
1572  transformResults.set(cast<OpResult>(getPackedOp()), {});
1573  transformResults.set(cast<OpResult>(getPackOp()), {});
1574  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1575  return DiagnosedSilenceableFailure::success();
1576  }
1577 
1578  // Step 2. A bunch of runtime sanity checks and error messages.
1579  // Step 2.1. Fail on multi-op handles.
1580  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1581  !llvm::hasSingleElement(linalgOps)) {
1582  return emitSilenceableError()
1583  << "requires target to map to exactly 1 "
1584  "packing op and 1 packed op ("
1585  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1586  << llvm::range_size(linalgOps) << ")";
1587  }
1588 
1589  // Step 2.2. Fail on wrong type.
1590  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1591  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1592  if ((!packOp && !unPackOp)) {
1593  return emitSilenceableError() << "requires target to map to a "
1594  "tensor.pack or tensor.unpack";
1595  }
1596  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1597  if (!linalgOpTarget)
1598  return emitSilenceableError() << "requires a LinalgOp target";
1599 
1600  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1601  LinalgOp linalgOp;
1602  if (packOp && packOp.getResult().hasOneUse())
1603  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1604  else if (unPackOp)
1605  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1606  if (linalgOp != linalgOpTarget) {
1607  auto errorMsg =
1608  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1609  : StringLiteral{"not produced by the LinalgOp target"};
1610  return emitSilenceableError() << errorMsg;
1611  }
1612 
1613  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1614  // PackOp.
1615  if (unPackOp) {
1616  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1617  OpOperand *packUse = linalgOp.getDpsInitOperand(
1618  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1619  packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
1620  if (!packOp || !packOp.getResult().hasOneUse())
1621  return emitSilenceableError() << "could not find matching pack op";
1622  }
1623 
1624  // Step 2.5. Fail if any permutation does not validate.
1625  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1626  ArrayRef<int64_t> perm =
1627  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1628  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1629  ? StringLiteral{"invalid outer_perm"}
1630  : StringLiteral{"invalid inner_perm"};
1631  if (!isValidPackingPermutation(packOp, perm, permType) ||
1632  !isValidPackingPermutation(unPackOp, perm, permType)) {
1633  Operation *packOrUnpackOp =
1634  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1635  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1636  }
1637  }
1638 
1639  // From here on, packOp and linalgOp are always present, unPackOp may or may
1640  // not be present.
1641  assert(packOp && linalgOp && "unexpected null op");
1642 
1643  // Step 3. Actually transpose the ops.
1644  FailureOr<PackTransposeResult> res = linalg::packTranspose(
1645  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1646  // Preconditions have been checked, it is an error to fail here.
1647  assert(succeeded(res) && "unexpected packTranspose failure");
1648 
1649  // Step 4. Return results.
1650  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1651  transformResults.set(cast<OpResult>(getPackedOp()),
1652  {res->transposedLinalgOp});
1653  if (unPackOp) {
1654  transformResults.set(cast<OpResult>(getUnPackOp()),
1655  {res->transposedUnPackOp});
1656  } else {
1657  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1658  }
1659 
1660  return DiagnosedSilenceableFailure::success();
1661 }
1662 
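// Illustrative transform-IR sketch (handles %pack and %generic are assumed to
// come from earlier matches; the exact assembly syntax may differ):
//
//   %pack2, %packed2, %unpack2 = transform.structured.pack_transpose %pack
//       with_compute_op(%generic) outer_perm = [1, 0]
//     : (!transform.any_op, !transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)
//
// Both handles must map to exactly one op, and the pack/unpack op must be
// directly connected to the LinalgOp target, as enforced by the checks above.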
1663 //===---------------------------------------------------------------------===//
1664 // PadOp
1665 //===---------------------------------------------------------------------===//
1666 
1667 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1668  ArrayRef<int64_t> paddingDimensions,
1669  ArrayRef<int64_t> padToMultipleOf,
1670  ArrayRef<int64_t> packPaddings,
1671  ArrayRef<Attribute> transposePaddings,
1672  StringRef copyBackOp) {
1673  auto resultType = transform::AnyOpType::get(b.getContext());
1674  return build(/*builder=*/b,
1675  /*result=*/result,
1676  /*types=*/TypeRange{resultType, resultType},
1677  /*target=*/target,
1678  /*paddingValues=*/ArrayAttr(), // let inference handle this
1679  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1680  /*padToMultipleOf=*/
1681  (padToMultipleOf.empty() ? ArrayAttr()
1682  : b.getI64ArrayAttr(padToMultipleOf)),
1683  /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
1684  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1685  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1686 }
1687 
1688 DiagnosedSilenceableFailure
1689 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1690  transform::TransformResults &results,
1691  transform::TransformState &state) {
1692  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1693 
1694  for (Operation *target : state.getPayloadOps(getTarget())) {
1695  auto linalgTarget = dyn_cast<LinalgOp>(target);
1696  if (!linalgTarget) {
1697  auto diag = emitSilenceableError() << "expected LinalgOp target";
1698  diag.attachNote(target->getLoc()) << "target op";
1699  return diag;
1700  }
1701 
1702  // Convert the integer packing flags to booleans.
1703  SmallVector<bool> packPaddings;
1704  for (int64_t packPadding :
1705  extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
1706  packPaddings.push_back(static_cast<bool>(packPadding));
1707 
1708  // Convert the padding values to attributes.
1709  SmallVector<Attribute> paddingValues;
1710  for (auto const &it :
1711  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1712  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1713  if (!attr) {
1714  emitOpError("expects padding values to be typed attributes");
1715  return DiagnosedSilenceableFailure::definiteFailure();
1716  }
1717  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1718  // Try to parse string attributes to obtain an attribute of element type.
1719  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1720  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1721  stringAttr, getContext(), elementType,
1722  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
1723  if (!parsedAttr || parsedAttr.getType() != elementType) {
1724  auto diag = this->emitOpError("expects a padding that parses to ")
1725  << elementType << ", got " << std::get<0>(it);
1726  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1727  return DiagnosedSilenceableFailure::definiteFailure();
1728  }
1729  paddingValues.push_back(parsedAttr);
1730  continue;
1731  }
1732  // Otherwise, add the attribute directly.
1733  if (attr.getType() != elementType) {
1734  auto diag = this->emitOpError("expects a padding value of type ")
1735  << elementType << ", got " << attr;
1736  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1737  return DiagnosedSilenceableFailure::definiteFailure();
1738  }
1739  paddingValues.push_back(attr);
1740  }
1741 
1742  // Extract the transpose vectors.
1743  SmallVector<SmallVector<int64_t>> transposePaddings;
1744  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1745  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1746  cast<ArrayAttr>(transposeVector)));
1747 
1748  LinalgOp paddedOp;
1749  LinalgPaddingOptions options;
1750  options.paddingDimensions =
1751  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1752  SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
1753  if (getPadToMultipleOf().has_value())
1754  padToMultipleOf =
1755  extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
1756  options.padToMultipleOf = padToMultipleOf;
1757  options.paddingValues = paddingValues;
1758  options.packPaddings = packPaddings;
1759  if (getCopyBackOp() ==
1760  bufferization::MaterializeInDestinationOp::getOperationName()) {
1761  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
1762  BufferizationMaterializeInDestination;
1763  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1764  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
1765  } else if (getCopyBackOp() == kCopyOpNone) {
1766  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
1767  } else {
1768  llvm_unreachable("unsupported copy_back op");
1769  }
1770 
1771  SmallVector<Value> replacements;
1772  SmallVector<tensor::PadOp> newPadOps;
1773  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
1774  replacements, newPadOps))) {
1775  auto diag = emitSilenceableError() << "failed to pad op";
1776  diag.attachNote(target->getLoc()) << "target op";
1777  return diag;
1778  }
1779 
1780  // We need to perform our own replacement here because this API is still
1781  // used in patterns that "pad and hoist", for which the replacement values
1782  // need to be different.
1783  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1784  // that we have more composable abstractions.
1785  rewriter.replaceOp(linalgTarget, replacements);
1786  paddedOps.push_back(paddedOp);
1787  padOps.append(newPadOps.begin(), newPadOps.end());
1788  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
1789  for (Value v : replacements) {
1790  Operation *copyBackOp = v.getDefiningOp();
1791  if (!llvm::is_contained(copyBackOps, copyBackOp))
1792  copyBackOps.push_back(copyBackOp);
1793  }
1794  }
1795  }
1796 
1797  results.set(cast<OpResult>(getPadded()), paddedOps);
1798  results.set(cast<OpResult>(getPad()), padOps);
1799  results.set(cast<OpResult>(getCopy()), copyBackOps);
1800  return DiagnosedSilenceableFailure::success();
1801 }
1802 
1803 LogicalResult transform::PadOp::verify() {
1804  SmallVector<int64_t> packPaddings =
1805  extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
1806  if (any_of(packPaddings, [](int64_t packPadding) {
1807  return packPadding != 0 && packPadding != 1;
1808  })) {
1809  return emitOpError()
1810  << "expects pack_paddings to contain booleans (0/1), found "
1811  << getPackPaddings();
1812  }
1813 
1814  SmallVector<int64_t> paddingDimensions =
1815  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1816  if (any_of(paddingDimensions,
1817  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
1818  return emitOpError() << "expects padding_dimensions to contain positive "
1819  "integers, found "
1820  << getPaddingDimensions();
1821  }
1822  if (getPadToMultipleOf().has_value()) {
1823  if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
1824  return emitOpError() << "expects as many multiples as padding_dimensions";
1825  }
1826  }
1827  ArrayAttr transposes = getTransposePaddings();
1828  for (Attribute attr : transposes) {
1829  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
1830  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1831  if (!std::is_permutation(sequence.begin(), sequence.end(),
1832  transpose.begin(), transpose.end())) {
1833  return emitOpError()
1834  << "expects transpose_paddings to be a permutation, found "
1835  << attr;
1836  }
1837  }
1838  if (getCopyBackOp() !=
1839  bufferization::MaterializeInDestinationOp::getOperationName() &&
1840  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1841  getCopyBackOp() != kCopyOpNone)
1842  return emitOpError() << "invalid copy_back_op";
1843  return success();
1844 }
1845 
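// Illustrative transform-IR sketch of a typical pad configuration (handle
// %matmul is a placeholder; attribute spelling may differ across versions):
//
//   %padded, %pad, %copy = transform.structured.pad %matmul {
//       padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
//       padding_dimensions = [0, 1, 2],
//       pack_paddings = [1, 1, 0],
//       copy_back_op = "linalg.copy"
//   } : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)
//
// String padding values (e.g. "0.0 : f32") are also accepted and are parsed
// against the operand element type, as implemented in apply() above.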
1846 //===---------------------------------------------------------------------===//
1847 // HoistPadOp
1848 //===---------------------------------------------------------------------===//
1849 
1850 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
1851  transform::TransformRewriter &rewriter,
1852  transform::TransformResults &transformResults,
1853  transform::TransformState &state) {
1854  auto targetOps = state.getPayloadOps(getTarget());
1855  auto loopOps = state.getPayloadOps(getLoop());
1856  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1857  return emitDefiniteFailure()
1858  << "requires exactly one target and one loop handle (got "
1859  << llvm::range_size(targetOps) << " and "
1860  << llvm::range_size(loopOps) << ")";
1861  }
1862 
1863  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1864  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1865  if (!padOp || !loopOp)
1866  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
1867 
1868  FailureOr<linalg::detail::PackingResult> result =
1869  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
1870  getTranspose());
1871  if (failed(result))
1872  return emitDefiniteFailure() << "could not build packing loop nest";
1873 
1874  if (result->clonedLoopIvs.empty()) {
1875  transformResults.set(cast<OpResult>(getPackingLoop()),
1876  {result->hoistedPadOp.getOperation()});
1877  return DiagnosedSilenceableFailure::success();
1878  }
1879  auto outerPackedLoop =
1880  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
1881  transformResults.set(cast<OpResult>(getPackingLoop()),
1882  {outerPackedLoop.getOperation()});
1883  return DiagnosedSilenceableFailure::success();
1884 }
1885 
1886 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
1887  ArrayRef<int64_t> transpose = getTranspose();
1888  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1889  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
1890  transpose.end())) {
1891  return emitOpError() << "expects transpose to be a permutation, found "
1892  << getTranspose();
1893  }
1894  return success();
1895 }
1896 
1897 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
1898  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1899  transform::onlyReadsHandle(getTarget(), effects);
1900  transform::onlyReadsHandle(getLoop(), effects);
1901  transform::producesHandle(getPackingLoop(), effects);
1902  transform::modifiesPayload(effects);
1903 }
1904 
1905 DiagnosedSilenceableFailure
1906 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
1907  tensor::PadOp target,
1908  transform::ApplyToEachResultList &results,
1909  transform::TransformState &state) {
1910  tensor::PadOp hoistedPadOp;
1911  SmallVector<GenericOp> transposeOps;
1912  FailureOr<Value> result =
1913  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
1914  hoistedPadOp, transposeOps);
1915  if (succeeded(result)) {
1916  // We need to perform our own replacement here because this API is still
1917  // used in patterns that "pad and hoist", for which the replacement values
1918  // need to be different.
1919  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1920  // that we have more composable abstractions.
1921  rewriter.replaceOp(target, *result);
1922  results.push_back(hoistedPadOp);
1923  return DiagnosedSilenceableFailure::success();
1924  }
1925  return emitDefaultSilenceableFailure(target);
1926 }
1927 
1928 LogicalResult transform::HoistPadOp::verify() {
1929  ArrayRef<int64_t> transpose = getTranspose();
1930  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1931  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
1932  transpose.end())) {
1933  return emitOpError() << "expects transpose to be a permutation, found "
1934  << getTranspose();
1935  }
1936  return success();
1937 }
1938 
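// Illustrative transform-IR sketch (handle %pad is assumed to come from a
// previous padding step; the exact assembly syntax may differ across
// versions):
//
//   %hoisted = transform.structured.hoist_pad %pad by 1 loops
//     : (!transform.any_op) -> !transform.any_op
//
// The transpose attribute, when present, must be a permutation of the padded
// tensor's dimensions, as checked by the verifiers above.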
1939 //===----------------------------------------------------------------------===//
1940 // PromoteOp
1941 //===----------------------------------------------------------------------===//
1942 
1943 DiagnosedSilenceableFailure
1944 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
1945  LinalgOp target,
1946  transform::ApplyToEachResultList &results,
1947  transform::TransformState &state) {
1948  LinalgPromotionOptions promotionOptions;
1949  if (!getOperandsToPromote().empty())
1950  promotionOptions = promotionOptions.setOperandsToPromote(
1951  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
1952  if (getUseFullTilesByDefault())
1953  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
1954  getUseFullTilesByDefault());
1955  if (getUseAlloca())
1956  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
1957  if (!getUseFullTileBuffers().empty())
1958  promotionOptions = promotionOptions.setUseFullTileBuffers(
1959  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
1960  if (getAlignment().has_value())
1961  promotionOptions = promotionOptions.setAlignment(*getAlignment());
1962  if (getMemorySpace().has_value())
1963  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
1964 
1965  if (getMapping().has_value()) {
1966  // The mapping should contain at most one element.
1967  auto mapping = *getMapping();
1968  if (mapping.size() > 1)
1969  return emitDefaultDefiniteFailure(target);
1970 
1971  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
1972 
1973  if (addressSpace.getAddressSpace() ==
1974  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
1975  promotionOptions =
1976  promotionOptions
1977  .setAllocationDeallocationFns(allocateWorkgroupMemory,
1978  deallocateWorkgroupMemory)
1979  .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
1980  .setUseFullTileBuffers({false, false});
1981  } else if (addressSpace.getAddressSpace() ==
1982  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
1983  promotionOptions =
1984  promotionOptions
1985  .setAllocationDeallocationFns(allocateGPUPrivateMemory,
1986  deallocateGPUPrivateMemory)
1987  .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
1988  .setUseFullTileBuffers({false, false});
1989  } else {
1990  return emitDefaultDefiniteFailure(target);
1991  }
1992  }
1993 
1994  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
1995  return emitDefaultDefiniteFailure(target);
1996 
1997  rewriter.setInsertionPoint(target);
1998  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
1999  if (failed(res))
2000  return emitDefaultDefiniteFailure(target);
2001  results.push_back(target);
2002  return DiagnosedSilenceableFailure::success();
2003 }
2004 
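// Illustrative transform-IR sketch (handle %tiled and the gpu memory-space
// attribute spelling are assumptions; syntax may differ across versions):
//
//   %promoted = transform.structured.promote %tiled {
//       operands_to_promote = [0, 1],
//       use_full_tile_buffers = [false, false],
//       mapping = [#gpu.memory_space<workgroup>]
//   } : (!transform.any_op) -> !transform.any_op
//
// When a gpu memory-space mapping is given, at most one element is allowed
// and it selects workgroup or private promotion, as handled above.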
2005 //===----------------------------------------------------------------------===//
2006 // ReplaceOp
2007 //===----------------------------------------------------------------------===//
2008 
2009 DiagnosedSilenceableFailure
2010 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2011  TransformResults &transformResults,
2012  TransformState &state) {
2013  auto payload = state.getPayloadOps(getTarget());
2014 
2015  // Check for invalid targets.
2016  for (Operation *target : payload) {
2017  if (target->getNumOperands() > 0)
2018  return emitDefiniteFailure() << "expected target without operands";
2019  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2020  target->getNumRegions() > 0)
2021  return emitDefiniteFailure()
2022  << "expected target that is isolated from above";
2023  }
2024 
2025  // Clone and replace.
2026  Operation *pattern = &getBodyRegion().front().front();
2027  SmallVector<Operation *> replacements;
2028  for (Operation *target : payload) {
2029  if (getOperation()->isAncestor(target))
2030  continue;
2031  rewriter.setInsertionPoint(target);
2032  Operation *replacement = rewriter.clone(*pattern);
2033  rewriter.replaceOp(target, replacement->getResults());
2034  replacements.push_back(replacement);
2035  }
2036  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2037  return DiagnosedSilenceableFailure::success();
2038 }
2039 
2040 void transform::ReplaceOp::getEffects(
2041  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2042  consumesHandle(getTarget(), effects);
2043  producesHandle(getReplacement(), effects);
2044  modifiesPayload(effects);
2045 }
2046 
2047 LogicalResult transform::ReplaceOp::verify() {
2048  if (!getBodyRegion().hasOneBlock())
2049  return emitOpError() << "expected one block";
2050  if (std::distance(getBodyRegion().front().begin(),
2051  getBodyRegion().front().end()) != 1)
2052  return emitOpError() << "expected one operation in block";
2053  Operation *replacement = &getBodyRegion().front().front();
2054  if (replacement->getNumOperands() > 0)
2055  return replacement->emitOpError()
2056  << "expected replacement without operands";
2057  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2058  replacement->getNumRegions() > 0)
2059  return replacement->emitOpError()
2060  << "expect op that is isolated from above";
2061  return success();
2062 }
2063 
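// Illustrative transform-IR sketch (the function @dummy is a placeholder;
// syntax may differ across versions):
//
//   %new = transform.structured.replace %target {
//     func.call @dummy() : () -> ()
//   } : (!transform.any_op) -> !transform.any_op
//
// The single op in the body region is cloned in place of every payload op;
// both the targets and the replacement must have no operands and be isolated
// from above, per the checks above.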
2064 //===----------------------------------------------------------------------===//
2065 // ScalarizeOp
2066 //===----------------------------------------------------------------------===//
2067 
2068 DiagnosedSilenceableFailure
2069 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2070  LinalgOp target,
2071  transform::ApplyToEachResultList &results,
2072  transform::TransformState &state) {
2073  scf::SCFTilingOptions tilingOptions;
2074  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2075  SmallVector<OpFoldResult> tileSizes;
2076  Location loc = target.getLoc();
2077  SmallVector<OpFoldResult> allShapeSizes =
2078  target.createFlatListOfOperandDims(b, loc);
2079  AffineMap map = target.getShapesToLoopsMap();
2080  if (!map)
2081  return tileSizes;
2082  SmallVector<OpFoldResult> shapeSizes =
2083  affine::makeComposedFoldedMultiResultAffineApply(b, loc, map,
2084  allShapeSizes);
2085  // If the shape size is dynamic, tile by 1.
2086  // Otherwise, do not tile (i.e. tile size 0).
2087  for (OpFoldResult shapeSize : shapeSizes) {
2088  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2089  : b.getIndexAttr(1));
2090  }
2091  return tileSizes;
2092  });
2093  SmallVector<int64_t> emptyTileSizes;
2094  rewriter.setInsertionPoint(target);
2095  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2096  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2097  if (failed(maybeTilingResult))
2098  return emitDefaultDefiniteFailure(target);
2099 
2100  if (target->getNumResults())
2101  rewriter.replaceOp(target, maybeTilingResult->replacements);
2102  else
2103  rewriter.eraseOp(target);
2104 
2105  results.reserve(maybeTilingResult->tiledOps.size());
2106  for (Operation *tiled : maybeTilingResult->tiledOps)
2107  results.push_back(tiled);
2108  return DiagnosedSilenceableFailure::success();
2109 }
2110 
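// Illustrative transform-IR sketch (handle %target is a placeholder; syntax
// may differ across versions):
//
//   %scalarized = transform.structured.scalarize %target
//     : (!transform.any_op) -> !transform.any_op
//
// As implemented above, every dynamic dimension is tiled by 1 while static
// dimensions are left untiled, so the resulting op bodies operate on
// unit-sized dynamic dimensions.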
2111 //===----------------------------------------------------------------------===//
2112 // ConvertToLoopsOp
2113 //===----------------------------------------------------------------------===//
2114 
2115 DiagnosedSilenceableFailure
2116 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2117  transform::TransformResults &results,
2118  transform::TransformState &state) {
2119  SmallVector<Operation *> loops;
2120  for (Operation *target : state.getPayloadOps(getTarget())) {
2121  auto tilingOp = dyn_cast<TilingInterface>(*target);
2122  if (!tilingOp) {
2123  DiagnosedSilenceableFailure diag =
2124  emitSilenceableError()
2125  << "expected the payload to implement TilingInterface";
2126  diag.attachNote(target->getLoc()) << "payload op";
2127  return diag;
2128  }
2129  rewriter.setInsertionPoint(target);
2130  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2131  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2132  if (failed(generatedLoops))
2133  return emitDefaultDefiniteFailure(target);
2134  for (scf::ForOp &loop : *generatedLoops) {
2135  loops.push_back(loop.getOperation());
2136  }
2137  rewriter.eraseOp(target);
2138  }
2139  results.set(cast<OpResult>(getResult()), loops);
2140  return DiagnosedSilenceableFailure::success();
2141 }
2142 
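// Illustrative transform-IR sketch (handle %tiled is a placeholder; syntax
// may differ across versions):
//
//   %loops = transform.structured.convert_to_loops %tiled
//     : (!transform.any_op) -> !transform.any_op
//
// All scf.for loops generated for all payload ops are collected into the
// single result handle, in creation order.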
2143 //===----------------------------------------------------------------------===//
2144 // RewriteInDestinationPassingStyleOp
2145 //===----------------------------------------------------------------------===//
2146 
2147 DiagnosedSilenceableFailure
2148 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2149  transform::TransformRewriter &rewriter, Operation *target,
2150  transform::ApplyToEachResultList &results,
2151  transform::TransformState &state) {
2153  rewriter.setInsertionPoint(target);
2154  FailureOr<Operation *> maybeResult =
2155  llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2156  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2157  [&rewriter](auto op) {
2158  return rewriteInDestinationPassingStyle(rewriter, op);
2159  });
2160  if (failed(maybeResult))
2161  return emitDefaultSilenceableFailure(target);
2162  results.push_back(*maybeResult);
2164 }
2165 
2166 //===----------------------------------------------------------------------===//
2167 // SplitOp
2168 //===----------------------------------------------------------------------===//
2169 
2170 DiagnosedSilenceableFailure
2171 SplitOp::apply(transform::TransformRewriter &rewriter,
2172  TransformResults &results, TransformState &state) {
2173  // Collect the dynamic split points if provided.
2174  SmallVector<Operation *> payload =
2175  llvm::to_vector(state.getPayloadOps(getTarget()));
2176  SmallVector<OpFoldResult> splitPoints;
2177  splitPoints.reserve(payload.size());
2178  if (getDynamicSplitPoint()) {
2179  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
2180  if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
2181  splitPoints = llvm::to_vector(llvm::map_range(
2182  state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
2183  if (op->getNumResults() != 1 ||
2184  !op->getResult(0).getType().isIndex()) {
2185  diag = emitSilenceableError()
2186  << "expected dynamic split point handle to point to a "
2187  "single-result index-typed op";
2188  diag.attachNote(op->getLoc()) << "dynamic split point";
2189  }
2190  return OpFoldResult(op->getResult(0));
2191  }));
2192  } else {
2193  splitPoints = llvm::to_vector(
2194  llvm::map_range(state.getParams(getDynamicSplitPoint()),
2195  [](Attribute attr) { return OpFoldResult(attr); }));
2196  }
2197  if (diag.isSilenceableFailure())
2198  return diag;
2199 
2200  if (splitPoints.size() != payload.size()) {
2201  return emitDefiniteFailure()
2202  << "expected the dynamic split point handle to point to as "
2203  "many operations ("
2204  << splitPoints.size() << ") as the target handle ("
2205  << payload.size() << ")";
2206  }
2207  } else {
2208  splitPoints.resize(payload.size(),
2209  rewriter.getIndexAttr(getStaticSplitPoint()));
2210  }
2211 
2212  // Split each target operation.
2213  SmallVector<Operation *> first, second;
2214  Operation *noSecondPart = nullptr;
2215  for (const auto &pair : llvm::zip(payload, splitPoints)) {
2216  Operation *target = std::get<0>(pair);
2217  auto linalgOp = dyn_cast<LinalgOp>(target);
2218  if (!linalgOp) {
2219  auto diag = emitSilenceableError() << "only applies to structured ops";
2220  diag.attachNote(target->getLoc()) << "target op";
2221  return diag;
2222  }
2223 
2224  if (getDimension() >= linalgOp.getNumLoops()) {
2225  auto diag = emitSilenceableError() << "dimension " << getDimension()
2226  << " does not exist in target op";
2227  diag.attachNote(target->getLoc()) << "target op";
2228  return diag;
2229  }
2230 
2231  rewriter.setInsertionPoint(linalgOp);
2232  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2233  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2234  getDimension(), std::get<1>(pair));
2235 
2236  // Propagate errors.
2237  if (!first.back() && !second.back()) {
2238  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2239  diag.attachNote(target->getLoc()) << "target op";
2240  return diag;
2241  }
2242 
2243  // Do not add null second parts.
2244  if (!second.back()) {
2245  noSecondPart = target;
2246  second.pop_back();
2247  }
2248  }
2249 
2250  if (second.size() != first.size() && !second.empty()) {
2251  auto diag = emitSilenceableError()
2252  << "splitting does not produce the second part for a subset "
2253  "of targets";
2254  diag.attachNote() << "expected splitting to produce the second part of all "
2255  "or none of the targets";
2256  diag.attachNote(noSecondPart->getLoc())
2257  << "first target with no second part";
2258  return diag;
2259  }
2260 
2261  results.set(cast<OpResult>(getFirst()), first);
2262  results.set(cast<OpResult>(getSecond()), second);
2263  return DiagnosedSilenceableFailure::success();
2264 }
2265 
2266 void SplitOp::getEffects(
2267  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2268  consumesHandle(getTarget(), effects);
2269  if (getDynamicSplitPoint())
2270  onlyReadsHandle(getDynamicSplitPoint(), effects);
2271  producesHandle(getResults(), effects);
2272  modifiesPayload(effects);
2273 }
2274 
2275 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2276  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
2277  IntegerAttr staticSplitPoint;
2278  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2279  return failure();
2280 
2281  OptionalParseResult dynamicPointParseResult =
2282  parser.parseOptionalOperand(dynamicSplitPoint);
2283  if (!dynamicPointParseResult.has_value()) {
2284  int64_t staticSplitPointValue;
2285  if (failed(parser.parseInteger(staticSplitPointValue)))
2286  return failure();
2287 
2288  staticSplitPoint =
2289  parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
2290  }
2291 
2292  Type targetType;
2293  if (parser.parseOptionalAttrDict(result.attributes) ||
2294  parser.parseColonType(targetType) ||
2295  parser.resolveOperand(target, targetType, result.operands)) {
2296  return failure();
2297  }
2298  if (dynamicPointParseResult.has_value()) {
2299  Type splitPointType;
2300  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2301  parser.parseType(splitPointType) ||
2302  parser.resolveOperand(dynamicSplitPoint, splitPointType,
2303  result.operands)) {
2304  return failure();
2305  }
2306 
2307  staticSplitPoint =
2308  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2309  }
2310 
2311  result.addAttribute(
2312  SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
2313  staticSplitPoint);
2314  result.addTypes({targetType, targetType});
2315  return success();
2316 }
2317 
2318 void SplitOp::print(OpAsmPrinter &printer) {
2319  printer << " " << getTarget() << " after ";
2320  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
2321  if (staticSplitSize != ShapedType::kDynamic)
2322  printer << staticSplitSize;
2323  else
2324  printer << getDynamicSplitPoint();
2325  printer << " ";
2326  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2327  {getStaticSplitPointAttrName()});
2328  printer << " : " << getTarget().getType();
2329  if (staticSplitSize == ShapedType::kDynamic)
2330  printer << ", " << getDynamicSplitPoint().getType();
2331 }
2332 
2333 LogicalResult SplitOp::verify() {
2334  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
2335  (getDynamicSplitPoint() == nullptr)) {
2336  return emitOpError() << "expects either a dynamic or a static split "
2337  "point to be provided";
2338  }
2339  return success();
2340 }
2341 
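// Illustrative usage sketches matching the custom parser/printer above
// (handle names are placeholders; exact types may differ):
//
//   // Static split point.
//   %first, %second = transform.structured.split %target after 16
//       {dimension = 1 : i64} : !transform.any_op
//
//   // Dynamic split point coming from another handle or param.
//   %a, %b = transform.structured.split %target after %splitpt
//       {dimension = 0 : i64} : !transform.any_op, !transform.any_op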
2342 //===----------------------------------------------------------------------===//
2343 // SplitReductionOp
2344 //===----------------------------------------------------------------------===//
2345 
2346 void transform::SplitReductionOp::build(
2347  OpBuilder &builder, OperationState &result, Value target,
2348  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2349  bool useScalingAlgorithm, bool useAlloc) {
2350  MLIRContext *ctx = builder.getContext();
2351  result.addOperands(target);
2352  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2353  builder.getI64IntegerAttr(splitFactor));
2354  result.addAttribute(
2355  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2356  builder.getI64IntegerAttr(insertSplitDimension));
2357  if (innerParallel) {
2358  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2359  builder.getUnitAttr());
2360  }
2361  if (useScalingAlgorithm) {
2362  result.addAttribute(
2363  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2364  builder.getUnitAttr());
2365  }
2366  if (useAlloc) {
2367  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2368  builder.getUnitAttr());
2369  }
2370  auto resultType = transform::AnyOpType::get(ctx);
2371  result.addTypes({resultType, resultType, resultType, resultType});
2372 }
2373 
2374 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2375  transform::TransformRewriter &rewriter, LinalgOp target,
2376  transform::ApplyToEachResultList &results,
2377  transform::TransformState &state) {
2378  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2379  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2380  unsigned(getInsertSplitDimension()),
2381  bool(getInnerParallel())};
2382  };
2383  rewriter.setInsertionPoint(target);
2384  FailureOr<SplitReductionResult> splitResult =
2385  (getUseScalingAlgorithm())
2386  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2387  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2388  if (failed(splitResult))
2389  return emitDefaultDefiniteFailure(target);
2390 
2391  results.push_back(splitResult->initOrAlloc);
2392  results.push_back(splitResult->fillOp);
2393  results.push_back(splitResult->splitLinalgOp);
2394  results.push_back(splitResult->resultCombiningLinalgOp);
2395  return DiagnosedSilenceableFailure::success();
2396 }
2397 
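// Illustrative transform-IR sketch (handle %reduction is a placeholder;
// syntax may differ across versions):
//
//   %init, %fill, %split, %combine =
//     transform.structured.split_reduction %reduction
//       { split_factor = 4, insert_split_dimension = 1 }
//     : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op,
//           !transform.any_op)
//
// The four results map to initOrAlloc, fillOp, splitLinalgOp and
// resultCombiningLinalgOp returned by splitReduction(ByScaling) below.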
2398 //===----------------------------------------------------------------------===//
2399 // TileReductionUsingForOp
2400 //===----------------------------------------------------------------------===//
2401 
2402 void transform::TileReductionUsingForOp::build(
2403  OpBuilder &builder, OperationState &result, Value target,
2404  ArrayRef<int64_t> staticTileSizes) {
2405  // Call the default builder.
2406  // This is future-proof with respect to mixed static-dynamic operands and
2407  // sets up the proper operand segment size attributes for multiple variadic
2408  // operands. In the absence of this, horrible bugs ensue.
2409  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2410  MLIRContext *ctx = builder.getContext();
2411  auto opTy = transform::AnyOpType::get(ctx);
2412  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2413  build(builder, result,
2414  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2415  /*target=*/target,
2416  /*tile_sizes=*/staticTileSizesAttr);
2417 }
2418 
2419 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2420  transform::TransformRewriter &rewriter, LinalgOp target,
2421  transform::ApplyToEachResultList &results,
2422  transform::TransformState &state) {
2423  rewriter.setInsertionPoint(target);
2424  FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
2425  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2426  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2427 
2428  if (failed(result))
2429  return emitDefaultSilenceableFailure(target);
2430  results.push_back(result->initialOp);
2431  results.push_back(result->parallelTiledOp);
2432  results.push_back(result->mergeOp);
2433  results.push_back(result->loops.front());
2434  return DiagnosedSilenceableFailure::success();
2435 }
2436 
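// Illustrative transform-IR sketch (handle %reduction is a placeholder;
// syntax may differ across versions):
//
//   %init, %partial, %merge, %loop =
//     transform.structured.tile_reduction_using_for %reduction
//       by tile_sizes = [0, 128]
//     : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op,
//           !transform.any_op)
//
// A zero tile size leaves the corresponding dimension untiled; the tiled
// reduction produces a partial-result tensor that is combined by the returned
// merge op.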
2437 //===----------------------------------------------------------------------===//
2438 // TileReductionUsingForallOp
2439 //===----------------------------------------------------------------------===//
2440 
2441 void transform::TileReductionUsingForallOp::build(
2442  OpBuilder &builder, OperationState &result, Value target,
2443  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2444  ArrayAttr mapping) {
2445  // Call the default builder.
2446  // This is future-proof with respect to mixed static-dynamic operands and
2447  // sets up the proper operand segment size attributes for multiple variadic
2448  // operands. In the absence of this, horrible bugs ensue.
2449  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2450  MLIRContext *ctx = builder.getContext();
2451  auto opTy = transform::AnyOpType::get(ctx);
2452  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2453  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2454  build(builder, result,
2455  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2456  /*target=*/target,
2457  /*num_threads=*/staticNumThreadsAttr,
2458  /*tile_sizes=*/staticTileSizesAttr,
2459  /*mapping=*/mapping);
2460 }
2461 
2462 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2463  transform::TransformRewriter &rewriter, LinalgOp target,
2464  transform::ApplyToEachResultList &results,
2465  transform::TransformState &state) {
2466  rewriter.setInsertionPoint(target);
2467  SmallVector<OpFoldResult> numThreads =
2468  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2469  SmallVector<OpFoldResult> tileSizes =
2470  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2471  FailureOr<linalg::ForallReductionTilingResult> result =
2472  linalg::tileReductionUsingForall(
2473  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2474  numThreads, tileSizes, getMapping());
2475 
2476  if (failed(result)) {
2477  auto diag = emitSilenceableError() << "could not tile reduction";
2478  diag.attachNote(target.getLoc()) << "target operation";
2479  return diag;
2480  }
2481  results.push_back(result->initialOp);
2482  results.push_back(result->parallelTiledOp);
2483  results.push_back(result->mergeOp);
2484  results.push_back(result->loops);
2485  return DiagnosedSilenceableFailure::success();
2486 }
2487 
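// Illustrative transform-IR sketch (handle %reduction is a placeholder;
// syntax may differ across versions):
//
//   %init, %partial, %merge, %forall =
//     transform.structured.tile_reduction_using_forall %reduction
//       by num_threads = [0, 8], tile_sizes = []
//     : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op,
//           !transform.any_op)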
2488 //===----------------------------------------------------------------------===//
2489 // TileUsingForOp
2490 //===----------------------------------------------------------------------===//
2491 
2492 void transform::TileUsingForOp::build(
2493  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2494  Value target, ArrayRef<int64_t> staticTileSizes,
2495  ArrayRef<int64_t> interchange,
2496  std::optional<ArrayRef<bool>> scalableSizes) {
2497  return build(builder, result, loopTypes,
2498  /*target=*/target,
2499  /*mixedTileSizes=*/
2500  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2501  interchange, scalableSizes);
2502 }
2503 
2504 void transform::TileUsingForOp::build(
2505  OpBuilder &builder, OperationState &result, Value target,
2506  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2507  std::optional<ArrayRef<bool>> scalableSizes) {
2508  build(builder, result, target,
2509  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2510  interchange, scalableSizes);
2511 }
2512 
2513 void transform::TileUsingForOp::build(
2514  OpBuilder &builder, OperationState &result, Value target,
2515  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2516  std::optional<ArrayRef<bool>> scalableSizes) {
2517  // Loop types are automatically splat by the callee, so setting up one is
2518  // enough.
2519  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2520  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2521  scalableSizes);
2522 }
2523 
2524 void transform::TileUsingForOp::build(
2525  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2526  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2527  ArrayRef<int64_t> interchange,
2528  std::optional<ArrayRef<bool>> scalableSizes) {
2529  SmallVector<int64_t> staticTileSizes;
2530  SmallVector<Value> dynamicTileSizes;
2531  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2532  // Call the default builder which sets up the proper operands segment sizes
2533  // attributes for multiple variadic operands. In the absence of this,
2534  // horrible bugs ensue.
2535  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2536  unsigned numExpectedLoops =
2537  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2538  SmallVector<Type> resultTypes;
2539  resultTypes.reserve(numExpectedLoops);
2540  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2541  "expected one loop type or as many as loops");
2542  if (loopTypes.size() == 1)
2543  resultTypes.append(numExpectedLoops, loopTypes[0]);
2544  else
2545  llvm::append_range(resultTypes, loopTypes);
2546  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2547  if (scalableSizes.has_value())
2548  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2549  build(builder, result, /*tiled_linalg_op=*/target.getType(),
2550  /*loops=*/resultTypes,
2551  /*target=*/target,
2552  /*dynamic_sizes=*/dynamicTileSizes,
2553  /*static_sizes=*/staticTileSizesAttr,
2554  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2555  /*scalable_sizes=*/expandedScalableSizes);
2556 }
2557 
2558 LogicalResult transform::TileUsingForOp::verify() {
2559  if (getMixedSizes().size() != getScalableSizes().size())
2560  return emitOpError("expected same number of sizes (")
2561  << getMixedSizes().size() << ") and scalable sizes ("
2562  << getScalableSizes().size() << ")";
2563  return success();
2564 }
2565 
2566 DiagnosedSilenceableFailure
2567 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2568  TransformResults &transformResults,
2569  TransformState &state) {
2570  ArrayRef<int64_t> tileSizes = getStaticSizes();
2571 
2572  SmallVector<Operation *> targets =
2573  llvm::to_vector(state.getPayloadOps(getTarget()));
2574  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2575  SmallVector<SmallVector<int64_t>> paramSizes;
2576  dynamicSizeProducers.reserve(getDynamicSizes().size());
2577  paramSizes.reserve(getDynamicSizes().size());
2578  for (Value transformValue : getDynamicSizes()) {
2579  if (isa<ParamType>(transformValue.getType())) {
2580  dynamicSizeProducers.push_back({});
2581  ArrayRef<Attribute> params = state.getParams(transformValue);
2582  paramSizes.push_back(
2583  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
2584  return cast<IntegerAttr>(attr).getValue().getSExtValue();
2585  })));
2586 
2587  if (paramSizes.back().size() != targets.size()) {
2588  DiagnosedSilenceableFailure diag =
2589  emitSilenceableError()
2590  << "expected as many parameter values ("
2591  << dynamicSizeProducers.back().size() << ") as target ops ("
2592  << targets.size() << ")";
2593  diag.attachNote(transformValue.getLoc()) << "for this parameter";
2594  return diag;
2595  }
2596 
2597  continue;
2598  }
2599  paramSizes.push_back({});
2600  dynamicSizeProducers.push_back(
2601  llvm::to_vector(state.getPayloadOps(transformValue)));
2602 
2603  if (dynamicSizeProducers.back().size() != targets.size()) {
2604  DiagnosedSilenceableFailure diag =
2605  emitSilenceableError()
2606  << "expected as many dynamic size-producing operations ("
2607  << dynamicSizeProducers.back().size() << ") as target ops ("
2608  << targets.size() << ")";
2609  diag.attachNote(transformValue.getLoc()) << "for this handle";
2610  return diag;
2611  }
2612 
2613  for (Operation *op : dynamicSizeProducers.back()) {
2614  if (op->getNumResults() == 1 &&
2615  isa<IndexType>(op->getResult(0).getType())) {
2616  continue;
2617  }
2618 
2619  DiagnosedSilenceableFailure diag =
2620  emitSilenceableError() << "expected sizes to be produced by ops "
2621  "with a single index-type result";
2622  diag.attachNote(op->getLoc()) << "size producer op";
2623  diag.attachNote(transformValue.getLoc()) << "for this handle";
2624  return diag;
2625  }
2626  }
2627 
2628  SmallVector<Operation *> tiled;
2629  SmallVector<SmallVector<Operation *>> loops;
2630  loops.resize(getLoops().size());
2631  auto scalableSizes = getScalableSizes();
2632  for (auto [i, op] : llvm::enumerate(targets)) {
2633  auto tilingInterface = dyn_cast<TilingInterface>(op);
2634  if (!tilingInterface) {
2636  emitSilenceableError()
2637  << "only ops implementing TilingInterface are supported";
2638  diag.attachNote(op->getLoc()) << "target op";
2639  return diag;
2640  }
2641  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
2643  emitSilenceableError()
2644  << "too many tiles provided, expected at most "
2645  << tilingInterface.getLoopIteratorTypes().size() << " found "
2646  << tileSizes.size();
2647  diag.attachNote(op->getLoc()) << "target op";
2648  return diag;
2649  }
2650 
2651  scf::SCFTilingOptions tilingOptions;
2652  if (tileSizes.empty()) {
2653  tilingOptions.setTileSizeComputationFunction(
2654  [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
2655  return {};
2656  });
2657  } else {
2658  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
2659  Operation *) {
2660  SmallVector<OpFoldResult> sizes;
2661  sizes.reserve(tileSizes.size());
2662  unsigned dynamicIdx = 0;
2663 
2664  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
2665  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
2666  if (scalableSizes[ofrIdx]) {
2667  auto val = b.create<arith::ConstantIndexOp>(
2668  getLoc(), attr.cast<IntegerAttr>().getInt());
2669  Value vscale =
2670  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
2671  sizes.push_back(
2672  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
2673  } else {
2674  sizes.push_back(attr);
2675  }
2676  continue;
2677  }
2678  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
2679  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
2680  ++dynamicIdx;
2681  assert((dynamicSizes.empty() ^ params.empty()) &&
2682  "expected either dynamic sizes or parameters");
2683  if (!params.empty()) {
2684  sizes.push_back(b.getIndexAttr(params[index]));
2685  } else {
2686  sizes.push_back(dynamicSizes[index]->getResult(0));
2687  }
2688  }
2689  return sizes;
2690  });
2691  }
2692 
2693  tilingOptions.setInterchange(getInterchange());
2694  FailureOr<scf::SCFTilingResult> maybeTilingResult =
2695  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
2696  if (failed(maybeTilingResult))
2697  return DiagnosedSilenceableFailure::definiteFailure();
2698 
2699  rewriter.replaceOp(op, maybeTilingResult->replacements);
2700 
2701  tiled.append(maybeTilingResult->tiledOps);
2702  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
2703  loops[en2.index()].push_back(en2.value());
2704  }
2705 
2706  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
2707  for (const auto &en : llvm::enumerate(loops))
2708  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
2709 
2710  return DiagnosedSilenceableFailure::success();
2711 }
2712 
2713 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
2714  ValueRange dynamic = getDynamicSizes();
2715  ArrayRef<int64_t> tileSizes = getStaticSizes();
2716  SmallVector<OpFoldResult> results;
2717  results.reserve(tileSizes.size());
2718  unsigned dynamicPos = 0;
2719  Builder builder(getContext());
2720  for (int64_t size : tileSizes) {
2721  if (size == ShapedType::kDynamic) {
2722  results.push_back(dynamic[dynamicPos++]);
2723  } else {
2724  results.push_back(builder.getIndexAttr(size));
2725  }
2726  }
2727  return results;
2728 }
2729 
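// Example of the static/dynamic encoding handled above (values are
// hypothetical): with static sizes [8, kDynamic, 0] and one dynamic operand
// %sz, getMixedSizes() returns [IndexAttr(8), %sz, IndexAttr(0)], i.e. every
// kDynamic sentinel is substituted by the next dynamic size operand in order.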
2730 // We want to parse `DenseI64ArrayAttr` using the short form without the
2731 // `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
2732 static ParseResult parseOptionalInterchange(OpAsmParser &parser,
2733  OperationState &result) {
2734  if (failed(parser.parseOptionalKeyword("interchange")))
2735  return success();
2736  if (failed(parser.parseEqual()))
2737  return failure();
2738  result.addAttribute(
2739  transform::TileUsingForOp::getInterchangeAttrName(result.name),
2740  DenseI64ArrayAttr::parse(parser, Type{}));
2741  return success();
2742 }
2743 
2744 static void printOptionalInterchange(OpAsmPrinter &p,
2745  ArrayRef<int64_t> interchangeVals) {
2746  if (!interchangeVals.empty()) {
2747  p << " interchange = [";
2748  llvm::interleaveComma(interchangeVals, p,
2749  [&](int64_t integer) { p << integer; });
2750  p << "]";
2751  }
2752 }
2753 
2754 ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
2755  OperationState &result) {
2756  OpAsmParser::UnresolvedOperand target;
2757  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
2758  DenseI64ArrayAttr staticSizes;
2759  FunctionType functionalType;
2760  llvm::SMLoc operandLoc;
2761  DenseBoolArrayAttr scalableVals;
2762 
2763  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
2764  parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
2765  parseOptionalInterchange(parser, result) ||
2766  parser.parseOptionalAttrDict(result.attributes) ||
2767  parser.parseColonType(functionalType))
2768  return ParseResult::failure();
2769 
2770  size_t numExpectedLoops =
2771  staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
2772  if (functionalType.getNumResults() != numExpectedLoops + 1) {
2773  return parser.emitError(parser.getNameLoc())
2774  << "expected " << (numExpectedLoops + 1) << " result type(s)";
2775  }
2776  if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
2777  return parser.emitError(operandLoc)
2778  << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
2779  }
2780  if (parser.resolveOperand(target, functionalType.getInputs().front(),
2781  result.operands) ||
2782  parser.resolveOperands(dynamicSizes,
2783  functionalType.getInputs().drop_front(),
2784  operandLoc, result.operands)) {
2785  return failure();
2786  }
2787 
2788  result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
2789 
2790  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
2791  result.addTypes(functionalType.getResults());
2792  return success();
2793 }
2794 
2795 void transform::TileUsingForOp::print(OpAsmPrinter &p) {
2796  p << ' ' << getTarget();
2797  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
2798  /*valueTypes=*/{}, getScalableSizesAttr(),
2800  printOptionalInterchange(p, getInterchange());
2801  p.printOptionalAttrDict(
2802  (*this)->getAttrs(),
2803  /*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()),
2804  getScalableSizesAttrName(getOperation()->getName()),
2805  getStaticSizesAttrName(getOperation()->getName())});
2806  p << " : ";
2807  p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
2808 }
2809 
2810 void transform::TileUsingForOp::getEffects(
2811  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2812  consumesHandle(getTarget(), effects);
2813  onlyReadsHandle(getDynamicSizes(), effects);
2814  producesHandle(getTiledLinalgOp(), effects);
2815  producesHandle(getLoops(), effects);
2816  modifiesPayload(effects);
2817 }
2818 
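// Illustrative usage sketches matching the custom parser/printer above
// (handle names are placeholders; exact types may differ):
//
//   // Static tile sizes; a size of 0 means "do not tile" and produces no
//   // loop, so [8, 16, 0] yields the tiled op plus two loop handles.
//   %tiled, %l0, %l1 = transform.structured.tile_using_for %target [8, 16, 0]
//       interchange = [1, 0]
//     : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)
//
//   // One dynamic size taken from another handle or param; scalable sizes
//   // are written in nested brackets, e.g. [[4], 8].
//   %t2, %l2 = transform.structured.tile_using_for %target2 [%sz, 0]
//     : (!transform.any_op, !transform.any_op)
//       -> (!transform.any_op, !transform.any_op)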
2819 //===----------------------------------------------------------------------===//
2820 // TileUsingForallOp
2821 //===----------------------------------------------------------------------===//
2822 
2823 void transform::TileUsingForallOp::build(OpBuilder &builder,
2824  OperationState &result, Value target,
2825  ArrayRef<int64_t> staticTileSizes,
2826  transform::TileSizesSpec,
2827  ArrayAttr mapping) {
2828  return build(builder, result,
2829  /*target=*/target,
2830  /*mixedTileSizes=*/
2831  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2832  /*_=*/TileSizesSpec(),
2833  /*mapping=*/mapping);
2834 }
2835 
2836 void transform::TileUsingForallOp::build(OpBuilder &builder,
2837  OperationState &result, Value target,
2838  ArrayRef<OpFoldResult> mixedTileSizes,
2839  transform::TileSizesSpec,
2840  ArrayAttr mapping) {
2841  SmallVector<int64_t> staticTileSizes;
2842  SmallVector<Value> dynamicTileSizes;
2843  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2844  // Call the default builder which sets up the proper operands segment sizes
2845  // attributes for multiple variadic operands. In the absence of this,
2846  // horrible bugs ensue.
2847  MLIRContext *ctx = builder.getContext();
2848  auto operationType = transform::AnyOpType::get(ctx);
2849  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2850  build(builder, result,
2851  /*resultTypes=*/TypeRange{operationType, operationType},
2852  /*target=*/target,
2853  /*num_threads=*/ValueRange{},
2854  /*tile_sizes=*/dynamicTileSizes,
2855  /*packed_num_threads=*/Value(),
2856  /*packed_tile_sizes=*/Value(),
2857  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
2858  /*static_tile_sizes=*/staticTileSizesAttr,
2859  /*mapping=*/mapping);
2860 }
2861 
2862 void transform::TileUsingForallOp::build(OpBuilder &builder,
2863  OperationState &result, Value target,
2864  ArrayRef<int64_t> staticNumThreads,
2865  transform::NumThreadsSpec,
2866  ArrayAttr mapping) {
2867  return build(builder, result, target,
2868  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
2869  NumThreadsSpec(), mapping);
2870 }
2871 
2872 void transform::TileUsingForallOp::build(OpBuilder &builder,
2873  OperationState &result, Value target,
2874  ArrayRef<OpFoldResult> mixedNumThreads,
2875  transform::NumThreadsSpec,
2876  ArrayAttr mapping) {
2877  SmallVector<int64_t> staticNumThreads;
2878  SmallVector<Value> dynamicNumThreads;
2879  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
2880  staticNumThreads);
2881  // Call the default builder which sets up the proper operands segment sizes
2882  // attributes for multiple variadic operands. In the absence of this,
2883  // horrible bugs ensue.
2884  MLIRContext *ctx = builder.getContext();
2885  auto operationType = transform::AnyOpType::get(ctx);
2886  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2887  build(builder, result,
2888  /*resultTypes=*/TypeRange{operationType, operationType},
2889  /*target=*/target,
2890  /*num_threads=*/dynamicNumThreads,
2891  /*tile_sizes=*/ValueRange{},
2892  /*packed_num_threads=*/Value(),
2893  /*packed_tile_sizes=*/Value(),
2894  /*static_num_threads=*/staticNumThreadsAttr,
2895  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
2896  /*mapping=*/mapping);
2897 }
2898 
2899 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
2900  RewriterBase &rewriter, transform::TransformState &state,
2901  TransformOpInterface transformOp, Operation *target,
2902  ArrayRef<OpFoldResult> mixedNumThreads,
2903  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
2904  linalg::ForallTilingResult &tilingResult) {
2905  // Transform all targets one by one.
2906  auto tileableOp = dyn_cast<TilingInterface>(target);
2907  if (!tileableOp) {
2909  transformOp.emitSilenceableError()
2910  << "only TilingInterface ops are supported";
2911  diag.attachNote(target->getLoc()) << "target op";
2912  return diag;
2913  }
2914  rewriter.setInsertionPoint(tileableOp);
2915  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
2916  if (!mixedNumThreads.empty()) {
2917  maybeTilingResult =
2918  linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
2919  } else {
2920  maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
2921  rewriter, tileableOp, mixedTileSizes, mapping);
2922  }
2923 
2924  if (failed(maybeTilingResult))
2925  return transformOp.emitDefaultSilenceableFailure(tileableOp);
2926  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
2927 
2928  tilingResult = *maybeTilingResult;
2929  return DiagnosedSilenceableFailure::success();
2930 }
2931 
2932 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
2933  transform::TransformRewriter &rewriter,
2934  transform::TransformResults &transformResults,
2935  transform::TransformState &state) {
2936  auto transformOp = cast<TransformOpInterface>(getOperation());
2937 
2938  // Result payload ops.
2939  SmallVector<Operation *> tileOps;
2940  SmallVector<Operation *> tiledOps;
2941 
2942  // Unpack handles.
2943  SmallVector<OpFoldResult> mixedNumThreads;
2945  getPackedNumThreads()
2947  state, transformOp, mixedNumThreads, getPackedNumThreads())
2949  state, transformOp, mixedNumThreads, getMixedNumThreads());
2950  if (!status.succeeded())
2951  return status;
2952  SmallVector<OpFoldResult> mixedTileSizes;
2953  status = getPackedTileSizes()
2955  state, transformOp, mixedTileSizes, getPackedTileSizes())
2957  state, transformOp, mixedTileSizes, getMixedTileSizes());
2958  if (!status.succeeded())
2959  return status;
2960 
2961  for (Operation *target : state.getPayloadOps(getTarget())) {
2962  linalg::ForallTilingResult tilingResult;
2963  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
2964  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
2965  getMapping(), tilingResult);
2966  if (!diag.succeeded())
2967  return diag;
2968  tileOps.push_back(tilingResult.tileOp);
2969  tiledOps.push_back(tilingResult.tiledOp);
2970  }
2971 
2972  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
2973  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
2974 
2975  return DiagnosedSilenceableFailure::success();
2976 }
2977 
2978 void transform::TileUsingForallOp::getEffects(
2979  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2980  consumesHandle(getTarget(), effects);
2981  onlyReadsHandle(getTileSizes(), effects);
2982  onlyReadsHandle(getNumThreads(), effects);
2983  onlyReadsHandle(getPackedNumThreads(), effects);
2984  onlyReadsHandle(getPackedTileSizes(), effects);
2985  producesHandle(getResults(), effects);
2986  modifiesPayload(effects);
2987 }
2988 
2989 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
2990  Builder b(getContext());
2991  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
2992 }
2993 
2994 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
2995  Builder b(getContext());
2996  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
2997 }
2998 
2999 LogicalResult TileUsingForallOp::verify() {
3000  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3001  static_cast<int>(getPackedNumThreads() != Value());
3002  if (numThreadsSpec > 1)
3003  return emitOpError(
3004  "num_threads and packed_num_threads are mutually exclusive");
3005  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3006  static_cast<int>(getPackedTileSizes() != Value());
3007  if (tileSizesSpec > 1)
3008  return emitOpError(
3009  "tile_sizes and packed_tile_sizes are mutually exclusive");
3010  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3011  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3012  "must be specified");
3013  return success();
3014 }
3015 
3016 //===----------------------------------------------------------------------===//
3017 // VectorizeChildrenAndApplyPatternsOp
3018 //===----------------------------------------------------------------------===//
3019 
3020 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3021  OpBuilder &builder, OperationState &result, Value target,
3022  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3023  result.addOperands(target);
3024  if (vectorizePadding) {
3025  result.addAttribute(
3026  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3027  result.name),
3028  builder.getUnitAttr());
3029  }
3030  if (vectorizeExtract) {
3031  result.addAttribute(
3032  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3033  result.name),
3034  builder.getUnitAttr());
3035  }
3036  if (flatten1DDepthwiseConv) {
3037  result.addAttribute(
3038  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3039  result.name),
3040  builder.getUnitAttr());
3041  }
3042  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3043 }
3044 
3045 namespace {
3046 /// This is a helper used only to call vectorize via a pattern inside
3047 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3048 struct VectorizationPattern : public RewritePattern {
3049  explicit VectorizationPattern(MLIRContext *context,
3050  bool vectorizeExtract = false,
3051  bool flattenConv = false)
3052  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3053  vectorizeNDExtract(vectorizeExtract),
3054  flatten1DDepthwiseConv(flattenConv) {}
3055  LogicalResult matchAndRewrite(Operation *op,
3056  PatternRewriter &rewriter) const override {
3057  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
3058  if (!linalgOp)
3059  return rewriter.notifyMatchFailure(op, "expected Linalg Op");
3060  return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
3061  /*scalableVecDims=*/{}, vectorizeNDExtract,
3062  flatten1DDepthwiseConv);
3063  }
3064 
3065 private:
3066  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3067  /// rank >= 2.
3068  bool vectorizeNDExtract = false;
3069  /// Controls whether to "flatten" the channel dimension when vectorising 1D
3070  /// depthwise convolutions. This should lead to better vectorization for
3071  /// tensors with a low number of channel dimensions.
3072  bool flatten1DDepthwiseConv = false;
3073 };
3074 } // namespace
3075 
3076 DiagnosedSilenceableFailure
3077 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3078  transform::TransformRewriter &rewriter, Operation *target,
3079  transform::ApplyToEachResultList &results,
3080  transform::TransformState &state) {
3081  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3082  auto diag = this->emitOpError("requires isolated-from-above targets");
3083  diag.attachNote(target->getLoc()) << "non-isolated target";
3084  return DiagnosedSilenceableFailure::definiteFailure();
3085  }
3086 
3087  MLIRContext *ctx = getContext();
3088  RewritePatternSet patterns(ctx);
3089  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3090  getFlatten_1dDepthwiseConv());
3091 
3092  if (!getDisableTransferPermutationMapLoweringPatterns())
3093  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3094 
3095  if (!getDisableMultiReductionToContractPatterns())
3096  vector::populateVectorReductionToContractPatterns(patterns);
3097 
3098  vector::populateSinkVectorBroadcastPatterns(patterns);
3099 
3100  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3101  linalg::LinalgCopyVTWForwardingPattern>(ctx,
3102  /*benefit=*/2);
3103  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3104  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3105  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
3106 
3107  patterns.add<CopyVectorizationPattern>(ctx);
3108 
3109  if (getVectorizePadding())
3110  linalg::populatePadOpVectorizationPatterns(patterns);
3111 
3112  TrackingListener listener(state, *this);
3113  GreedyRewriteConfig config;
3114  config.listener = &listener;
3115  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
3116  return emitDefaultDefiniteFailure(target);
3117 
3118  results.push_back(target);
3119  return DiagnosedSilenceableFailure::success();
3120 }
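// A minimal usage sketch (assumed transform-IR assembly, not from this file):
//   %vectorized = transform.structured.vectorize_children_and_apply_patterns
//       %func {vectorize_padding} : (!transform.any_op) -> !transform.any_op
// This runs VectorizationPattern plus the cleanup patterns registered above
// over every op nested under the isolated-from-above target.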
3121 
3122 //===----------------------------------------------------------------------===//
3123 // VectorizeOp
3124 //===----------------------------------------------------------------------===//
3125 
3126 static const StringLiteral kVectorSizesKeyword = "vector_sizes";
3127 
3128 ParseResult transform::VectorizeOp::parse(OpAsmParser &parser,
3129  OperationState &result) {
3130  OpAsmParser::UnresolvedOperand target;
3131  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
3132  DenseI64ArrayAttr staticSizes;
3133  SmallVector<Type> operandTypes;
3134  llvm::SMLoc operandLoc;
3135  DenseBoolArrayAttr scalableVals;
3136 
3137  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
3138  return ParseResult::failure();
3139 
3140  if (succeeded(parser.parseOptionalKeyword(kVectorSizesKeyword))) {
3141  if (failed(parseDynamicIndexList(parser, dynamicSizes, staticSizes,
3142  scalableVals)))
3143  return ParseResult::failure();
3144  }
3145 
3146  if (succeeded(parser.parseOptionalKeyword(
3147  getVectorizeNdExtractAttrName(result.name))))
3148  result.addAttribute(getVectorizeNdExtractAttrName(result.name),
3149  parser.getBuilder().getUnitAttr());
3150 
3151  if (parser.parseOptionalAttrDict(result.attributes) ||
3152  parser.parseColonTypeList(operandTypes))
3153  return ParseResult::failure();
3154 
3155  if (operandTypes.size() != dynamicSizes.size() + 1) {
3156  return parser.emitError(operandLoc)
3157  << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
3158  }
3159  if (parser.resolveOperand(target, operandTypes.front(), result.operands) ||
3160  parser.resolveOperands(dynamicSizes, ArrayRef(operandTypes).drop_front(),
3161  operandLoc, result.operands)) {
3162  return failure();
3163  }
3164 
3165  if (scalableVals)
3166  result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
3167  if (staticSizes)
3168  result.addAttribute(getStaticVectorSizesAttrName(result.name), staticSizes);
3169 
3170  return success();
3171 }
3172 
3173 void transform::VectorizeOp::print(OpAsmPrinter &p) {
3174  p << ' ' << getTarget() << ' ';
3175  if (!getMixedVectorSizes().empty()) {
3176  p << kVectorSizesKeyword << ' ';
3177  printDynamicIndexList(p, getOperation(), getVectorSizes(),
3178  getStaticVectorSizesAttr(),
3179  /*valueTypes=*/{}, getScalableSizesAttr(),
3180  OpAsmParser::Delimiter::Square);
3181  }
3182 
3183  if (getVectorizeNdExtract())
3184  p << getVectorizeNdExtractAttrName() << ' ';
3185 
3186  p.printOptionalAttrDict(
3187  (*this)->getAttrs(),
3188  /*elidedAttrs=*/{
3189  getScalableSizesAttrName(getOperation()->getName()),
3190  getStaticVectorSizesAttrName(getOperation()->getName())});
3191  p << " : ";
3192  p << getTarget().getType();
3193  if (!getVectorSizes().empty()) {
3194  p << ", ";
3195  llvm::interleaveComma(getVectorSizes(), p,
3196  [&](Value operand) { p << operand.getType(); });
3197  }
3198 }
3199 
3200 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3201  transform::TransformRewriter &rewriter,
3202  mlir::transform::TransformResults &transformResults,
3203  mlir::transform::TransformState &state) {
3204  auto targets = state.getPayloadOps(getTarget());
3205  if (std::empty(targets))
3206  return DiagnosedSilenceableFailure::success();
3207 
3208  SmallVector<int64_t> vectorSizes;
3209  for (OpFoldResult sz : getMixedVectorSizes()) {
3210  if (sz.is<Attribute>()) {
3211  auto attr = sz.get<Attribute>();
3212  vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
3213  continue;
3214  } else if (sz.is<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
3215  ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
3216  if (params.size() != 1)
3217  return emitSilenceableFailure(getLoc()) << "expected a single param";
3218  vectorSizes.push_back(
3219  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
3220  continue;
3221  }
3222 
3223  auto szPayloads = state.getPayloadOps(sz.get<Value>());
3224  if (!llvm::hasSingleElement(szPayloads)) {
3225  auto diag = this->emitOpError(
3226  "requires vector size handle that is mapped to 1 payload op");
3227  diag.attachNote(sz.get<Value>().getLoc())
3228  << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
3229  return DiagnosedSilenceableFailure::definiteFailure();
3230  }
3231 
3232  Operation *szPayloadOp = *szPayloads.begin();
3233  if (szPayloadOp->getNumResults() != 1 ||
3234  !szPayloadOp->getResult(0).getType().isIndex()) {
3235  auto diag = this->emitOpError(
3236  "requires vector size payload op with 1 index result");
3237  diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
3238  return DiagnosedSilenceableFailure::definiteFailure();
3239  }
3240 
3241  IntegerAttr attr;
3242  if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
3243  auto diag = this->emitOpError("requires constant vector size");
3244  diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
3245  return DiagnosedSilenceableFailure::definiteFailure();
3246  }
3247 
3248  vectorSizes.push_back(attr.getInt());
3249  }
3250 
3251  // TODO: Check that the correct number of vectorSizes was provided.
3252  for (Operation *target : targets) {
3253  if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
3254  target)) {
3255  return mlir::emitSilenceableFailure(target->getLoc())
3256  << "Unsupported Op, cannot vectorize";
3257  }
3258 
3259  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3260  getScalableSizes(),
3261  getVectorizeNdExtract().has_value()
3262  ? getVectorizeNdExtract().value()
3263  : false))) {
3264  return mlir::emitSilenceableFailure(target->getLoc())
3265  << "Attempted to vectorize, but failed";
3266  }
3267  }
3268 
3269  return DiagnosedSilenceableFailure::success();
3270 }
3271 
3272 void transform::VectorizeOp::getEffects(
3273  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3274  consumesHandle(getTarget(), effects);
3275  onlyReadsHandle(getVectorSizes(), effects);
3276  modifiesPayload(effects);
3277 }
3278 
3279 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3280  OpBuilder b(getContext());
3281  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3282 }
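// A minimal sketch of how getMixedValues pairs the two size lists; the helper
// below is illustrative only and not part of this file. Static entries equal
// to ShapedType::kDynamic are replaced, in order, by the dynamic SSA values,
// and every other entry becomes an index IntegerAttr.
static SmallVector<OpFoldResult> exampleMixedSizes(Builder &b,
                                                   ValueRange dynamicSizes) {
  // With one dynamic value, the result is {IntegerAttr(4), dynamicSizes[0]}.
  return getMixedValues({4, ShapedType::kDynamic}, dynamicSizes, b);
}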
3283 
3284 LogicalResult transform::VectorizeOp::verify() {
3285  if (getStaticVectorSizes().size() != getScalableSizes().size())
3286  return emitOpError("expected same number of vector sizes (")
3287  << getStaticVectorSizes().size() << ") and scalable sizes ("
3288  << getScalableSizes().size() << ")";
3289  return success();
3290 }
3291 
3292 //===----------------------------------------------------------------------===//
3293 // HoistRedundantVectorTransfersOp
3294 //===----------------------------------------------------------------------===//
3295 
3296 DiagnosedSilenceableFailure
3297 transform::HoistRedundantVectorTransfersOp::applyToOne(
3298  transform::TransformRewriter &rewriter, func::FuncOp target,
3299  transform::ApplyToEachResultList &results,
3300  transform::TransformState &state) {
3301  // WARNING: This hoisting does not model parallelism and is generally
3302  // incorrect when used on distributed loops with memref semantics!
3303  // TODO: obsolete and should be retired.
3304  linalg::hoistRedundantVectorTransfers(target);
3305  results.push_back(target);
3306  return DiagnosedSilenceableFailure::success();
3307 }
3308 
3309 //===----------------------------------------------------------------------===//
3310 // ConvertConv2DToImg2ColOp.
3311 //===----------------------------------------------------------------------===//
3312 
3313 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3314  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3315  transform::ApplyToEachResultList &results,
3316  transform::TransformState &state) {
3317  rewriter.setInsertionPoint(target);
3318  auto maybeTransformed =
3319  TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3320  target)
3321  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3322  return rewriteInIm2Col(rewriter, op);
3323  })
3324  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3325  return rewriteInIm2Col(rewriter, op);
3326  })
3327  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3328  return rewriteInIm2Col(rewriter, op);
3329  })
3330  .Case([&](linalg::Conv2DNchwFchwOp op) {
3331  return rewriteInIm2Col(rewriter, op);
3332  })
3333  .Default([&](Operation *op) {
3334  return rewriter.notifyMatchFailure(op, "not supported");
3335  });
3336  if (failed(maybeTransformed))
3337  return emitDefaultSilenceableFailure(target);
3338  // Handle to the operation producing the img2col tensor.
3339  results.push_back(maybeTransformed->first);
3340  // Handle to the operation that replaces the original convolution.
3341  results.push_back(maybeTransformed->second);
3342  return DiagnosedSilenceableFailure::success();
3343 }
3344 
3345 //===----------------------------------------------------------------------===//
3346 // FlattenElementwiseLinalgOp.
3347 //===----------------------------------------------------------------------===//
3348 
3349 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3350  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3351  transform::ApplyToEachResultList &results,
3352  transform::TransformState &state) {
3353  rewriter.setInsertionPoint(target);
3354  if (!isElementwise(target))
3355  return mlir::emitSilenceableFailure(target->getLoc())
3356  << "only elementwise flattening is supported";
3357 
3358  // If rank <= 1, do nothing
3359  if (target.getNumLoops() <= 1) {
3360  results.push_back(target);
3361  return DiagnosedSilenceableFailure::success();
3362  }
3363 
3364  // Attempt to flatten all dims to one.
3365  ReassociationIndices reassociation(target.getNumLoops());
3366  std::iota(reassociation.begin(), reassociation.end(), 0);
3367  auto maybeFlattened =
3368  collapseOpIterationDims(target, reassociation, rewriter);
3369  if (failed(maybeFlattened))
3370  return mlir::emitSilenceableFailure(target->getLoc())
3371  << "attempted to flatten, but failed";
3372  results.push_back(maybeFlattened->collapsedOp);
3373  rewriter.replaceOp(target, maybeFlattened->results);
3374  return DiagnosedSilenceableFailure::success();
3375 }
3376 
3377 //===----------------------------------------------------------------------===//
3378 // TransposeConv2DOp
3379 //===----------------------------------------------------------------------===//
3380 
3381 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3382  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3383  transform::ApplyToEachResultList &results,
3384  transform::TransformState &state) {
3385  rewriter.setInsertionPoint(target);
3386  auto maybeTransformed =
3387  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3388  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3389  return transposeConv2D(rewriter, op);
3390  })
3391  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3392  return transposeConv2D(rewriter, op);
3393  })
3394  .Default([&](Operation *op) {
3395  return rewriter.notifyMatchFailure(op, "not supported");
3396  });
3397  if (failed(maybeTransformed))
3398  return emitDefaultSilenceableFailure(target);
3399  // Handle to the new Conv2D operation with transposed filters
3400  results.push_back(*maybeTransformed);
3401  return DiagnosedSilenceableFailure::success();
3402 }
3403 
3404 //===----------------------------------------------------------------------===//
3405 // InsertSliceToCopyOp
3406 //===----------------------------------------------------------------------===//
3407 template <typename OpTy>
3408 static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3409  transform::ApplyToEachResultList &results,
3410  transform::TransformState &state) {
3411  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3412  tensor::ParallelInsertSliceOp>() &&
3413  "wrong op type");
3414 
3415  if (auto copySource =
3416  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3417  results.push_back(copySource);
3418  return DiagnosedSilenceableFailure::success();
3419  }
3420 
3421  // If we are inside an InParallel region, temporarily set the insertion point
3422  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3423  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3424  rewriter.setInsertionPoint(
3425  target->template getParentOfType<scf::InParallelOp>());
3426  }
3427 
3428  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3429  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3430  target.getMixedSizes(), target.getMixedStrides());
3431  Value copied = rewriter
3432  .create<linalg::CopyOp>(target.getLoc(),
3433  target.getSource(), extracted)
3434  .getResult(0);
3435  // Reset the insertion point.
3436  rewriter.setInsertionPoint(target);
3437  rewriter.replaceOpWithNewOp<OpTy>(
3438  target, copied, target.getDest(), target.getMixedOffsets(),
3439  target.getMixedSizes(), target.getMixedStrides());
3440 
3441  results.push_back(copied.getDefiningOp());
3442  return DiagnosedSilenceableFailure::success();
3443 }
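// For illustration (assumed IR, not from this file), the rewrite above turns
//   %r = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<8x8xf32>
// into a copy on the destination slice:
//   %e = tensor.extract_slice %dest[0, 0] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
//   %c = linalg.copy ins(%src : tensor<4x4xf32>) outs(%e : tensor<4x4xf32>)
//       -> tensor<4x4xf32>
//   %r = tensor.insert_slice %c into %dest[0, 0] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<8x8xf32>
// and returns a handle to the newly created linalg.copy.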
3444 
3445 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3446  transform::TransformRewriter &rewriter, Operation *targetOp,
3447  transform::ApplyToEachResultList &results,
3448  transform::TransformState &state) {
3449 
3450  rewriter.setInsertionPoint(targetOp);
3451  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3452  return doit(rewriter, target, results, state);
3453  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3454  return doit(rewriter, target, results, state);
3455 
3456  DiagnosedSilenceableFailure diag =
3457  emitSilenceableError()
3458  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3459  diag.attachNote(targetOp->getLoc()) << "target op";
3460  return diag;
3461 }
3462 
3463 //===----------------------------------------------------------------------===//
3464 // MapCopyToThreadsOp
3465 //===----------------------------------------------------------------------===//
3466 
3467 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3468  transform::TransformRewriter &rewriter, Operation *target,
3469  transform::ApplyToEachResultList &results,
3470  transform::TransformState &state) {
3471  // Check if the op is supported.
3472  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3473  DiagnosedSilenceableFailure diag =
3474  emitSilenceableError()
3475  << "only linalg.copy and tensor.pad target ops are supported";
3476  diag.attachNote(target->getLoc()) << "target op";
3477  return diag;
3478  }
3479  assert(target->getNumResults() == 1 && "expected single result");
3480  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3481  if (!resultShapedType.hasStaticShape()) {
3482  DiagnosedSilenceableFailure diag =
3483  emitSilenceableError()
3484  << "only statically sized ops of rank <= 3 are supported";
3485  diag.attachNote(target->getLoc()) << "target op";
3486  return diag;
3487  }
3488 
3489  // Conservatively set the minimum viable desired bitwidth alignment.
3490  int64_t desiredBitAlignment = getDesiredBitAlignment();
3491  int64_t eltBitwidth =
3492  resultShapedType.getElementType().getIntOrFloatBitWidth();
3493  if (desiredBitAlignment % eltBitwidth != 0) {
3494  desiredBitAlignment = eltBitwidth;
3495  }
3496 
3497  gpu::CopyMappingInfo mapping(
3498  /*ctx=*/getContext(),
3499  /*totalNumThreads=*/getTotalNumThreads(),
3500  /*alignment=*/desiredBitAlignment,
3501  /*sizes=*/resultShapedType.getShape(),
3502  /*favorPredication=*/false,
3503  /*elementalBitwidth=*/
3504  resultShapedType.getElementType().getIntOrFloatBitWidth());
3505  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3506  DiagnosedSilenceableFailure diag =
3507  emitSilenceableError()
3508  << "too few threads to map copy op to threads on the most minor "
3509  "dimension, given alignment and vector size constraints, try "
3510  "smaller tile size of mapping to more threads";
3511  diag.attachNote(target->getLoc()) << "target op";
3512  return diag;
3513  }
3514 
3515  // OpBuilder only used to compute attributes.
3516  OpBuilder b(getContext());
3517  linalg::ForallTilingResult tilingResult;
3518  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3519  /*rewriter=*/rewriter,
3520  /*state=*/state,
3521  /*transformOp=*/*this,
3522  /*target=*/target,
3523  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3524  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3525  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3526  /*tilingResult=*/tilingResult);
3527  if (!diag.succeeded())
3528  return diag;
3529 
3530  results.push_back(tilingResult.tileOp);
3531  results.push_back(tilingResult.tiledOp);
3532  return DiagnosedSilenceableFailure::success();
3533 }
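// A minimal usage sketch (assumed transform-IR assembly, not from this file):
// distribute a statically sized copy over 32 threads, preferring 128-bit
// aligned accesses where the element type allows it.
//   %forall, %tiled = transform.structured.gpu.map_copy_to_threads %copy
//       total_num_threads = 32 desired_bit_alignment = 128
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)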
3534 
3535 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3536 
3537 #define GET_OP_CLASSES
3538 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"