MLIR  19.0.0git
Go to the documentation of this file.
1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Support/LLVM.h"
41 #include "mlir/Support/TypeID.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Debug.h"
47 #include <type_traits>
49 using namespace mlir;
50 using namespace mlir::linalg;
51 using namespace mlir::transform;
53 #define DEBUG_TYPE "linalg-transforms"
54 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 #define DBGSNL() (llvm::dbgs() << "\n")
56 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
58 /// Attempts to apply the pattern specified as template argument to the given
59 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
60 /// function that returns the "main" result or failure. Returns failure if the
61 /// pattern failed to apply. Extra arguments are forwarded to the pattern
62 /// constructor.
63 template <typename PatternTy, typename... Args>
64 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
65  // Check if the given operation has the type expected by the pattern.
66  using OpTy = typename llvm::function_traits<
67  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
68  auto op = dyn_cast<OpTy>(operation);
69  if (!op)
70  return failure();
72  // Apply the pattern directly to the op.
73  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
74  // We want to discourage direct use of PatternRewriter in APIs but In this
75  // very specific case, an IRRewriter is not enough.
76  struct TrivialPatternRewriter : public PatternRewriter {
77  public:
78  explicit TrivialPatternRewriter(MLIRContext *context)
79  : PatternRewriter(context) {}
80  };
81  TrivialPatternRewriter rewriter(operation->getContext());
82  rewriter.setInsertionPoint(operation);
83  auto result = pattern.returningMatchAndRewrite(op, rewriter);
84  if (failed(result))
85  return failure();
86  return cast<LinalgOp>(result->getOperation());
87 }
// NOTE(review): this listing is missing several lines of this function (the
// embedded numbering jumps over 92, 94 and 134-135), including the line that
// carries its name/return type, the declarations of `result` and `ofrs`, and
// the final return. The notes below describe only what the visible lines do.
//
// Each entry of `ofrs` contributes exactly one entry to `result`:
//  - an attribute entry must be an IntegerAttr and is forwarded unchanged;
//  - a transform param value must have exactly one associated parameter;
//  - a handle must map to exactly one payload op that has a single result of
//    index type; that result is appended.
89 /// Assuming that `ofr` is an index attr or a param of index type
90 /// or a transform dialect handle mapped to exactly one op
91 /// with one index result, return that value.
93  transform::TransformState &state, TransformOpInterface transformOp,
95  for (OpFoldResult ofr : ofrs) {
    // NOTE(review): the condition on the next line is truncated in this
    // listing; presumably it tests whether `ofr` holds an Attribute - confirm
    // against the real source.
96  if (<Attribute>()) {
    // Attribute case: anything other than an IntegerAttr is a definite
    // (non-recoverable) failure.
97  if (!isa<IntegerAttr>(ofr.get<Attribute>()))
98  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
99  result.push_back(ofr);
100  continue;
101  }
    // Value case: either a transform param or a handle to payload ops.
103  Value transformValue = ofr.get<Value>();
104  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
    // Param case: exactly one parameter must be associated with the value.
105  ArrayRef<Attribute> params = state.getParams(transformValue);
106  if (params.size() != 1)
107  return transformOp.emitDefiniteFailure()
108  << "requires exactly one parameter associated";
109  result.push_back(params[0]);
110  continue;
111  }
    // Handle case: the handle must be mapped to a single payload op.
113  auto payloadOps = state.getPayloadOps(transformValue);
114  if (!llvm::hasSingleElement(payloadOps)) {
    // Silenceable error; the note reports how many payload ops were mapped.
116  transformOp.emitSilenceableError()
117  << "handle must be mapped to exactly one payload op";
118  diag.attachNote(transformValue.getLoc())
119  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
120  return diag;
121  }
    // The single payload op must produce exactly one result of index type.
123  Operation *op = *payloadOps.begin();
124  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
126  transformOp.emitSilenceableError()
127  << "payload op must have exactly 1 index result";
128  diag.attachNote(op->getLoc())
129  << "has " << op->getNumResults() << " results";
130  return diag;
131  }
132  result.push_back(op->getResult(0));
133  }
136 }
// NOTE(review): this listing is missing lines 144, 156 and 170-171 of this
// function (the line with its name/return type and, presumably, the return
// statements that end the param branch and the function). Notes below cover
// only the visible lines.
138 // Given a list of params that are index attrs or a list of OpFoldResults
139 // that are either index attrs or op handles, return a list of OpFoldResults
140 // of index attrs or a list of OpFoldResults where all op handles are
141 // replaced with the first (and only) OpResult of that payload op.
142 // (There must be exactly one parameter associated with the AnyParamType or
143 // one mapped payload op which must have exactly one index result.)
145  transform::TransformState &state, TransformOpInterface transformOp,
146  SmallVector<OpFoldResult> &result, Value packedHandle) {
    // Param case: every parameter associated with `packedHandle` must be an
    // IntegerAttr; each one is appended to `result` in order. A non-integer
    // parameter is a definite failure.
147  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
148  ArrayRef<Attribute> params = state.getParams(packedHandle);
149  for (auto param : params) {
150  if (!isa<IntegerAttr>(param))
151  return transformOp.emitDefiniteFailure()
152  << "expected the parameter to be associated with an integer "
153  "attribute";
154  result.push_back(param);
155  }
157  }
    // Handle case: each mapped payload op must have exactly one result of
    // index type; that result is appended to `result`. Otherwise a silenceable
    // error is returned with a note on the offending op.
159  for (Operation *op : state.getPayloadOps(packedHandle)) {
160  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
162  transformOp.emitSilenceableError()
163  << "payload op must have exactly 1 index result";
164  diag.attachNote(op->getLoc())
165  << "has " << op->getNumResults() << " results";
166  return diag;
167  }
168  result.push_back(op->getResult(0));
169  }
172 }
174 //===----------------------------------------------------------------------===//
175 // Apply...PatternsOp
176 //===----------------------------------------------------------------------===//
// Receives the pattern set `patterns` by reference.
// NOTE(review): the body (line 180) is missing from this listing, so which
// patterns get added cannot be determined here.
178 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
179  RewritePatternSet &patterns) {
181 }
// Receives the pattern set `patterns` by reference.
// NOTE(review): the body (lines 185-186) is missing from this listing, so
// which patterns get added cannot be determined here.
183 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
184  RewritePatternSet &patterns) {
187 }
// Receives the pattern set `patterns` by reference.
// NOTE(review): lines 191 and 193-194 are missing from this listing: the
// declaration of `options`, the value assigned to its `rankReductionStrategy`
// field, and whatever uses `options` afterwards.
189 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
190  RewritePatternSet &patterns) {
    // Overrides the rank-reduction strategy on `options` (value elided here).
192  options.rankReductionStrategy =
195 }
// Receives the pattern set `patterns` by reference.
// NOTE(review): the body (line 199) is missing from this listing, so which
// patterns get added cannot be determined here.
197 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
198  RewritePatternSet &patterns) {
200 }
202 //===----------------------------------------------------------------------===//
203 // BufferizeToAllocationOp
204 //===----------------------------------------------------------------------===//
206 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
207  OperationState &result,
208  Value target,
209  Attribute memorySpace) {
210  SmallVector<Type> resultTypes;
211  resultTypes.push_back(b.getType<transform::AnyValueType>());
212  resultTypes.push_back(b.getType<transform::AnyOpType>());
213  return build(b, result,
214  /*resultTypes=*/resultTypes,
215  /*target=*/target,
216  /*memorySpace=*/memorySpace);
217 }
219 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
220  OperationState &result,
221  Value target,
222  int64_t memorySpace) {
223  SmallVector<Type> resultTypes;
224  resultTypes.push_back(b.getType<transform::AnyValueType>());
225  resultTypes.push_back(b.getType<transform::AnyOpType>());
226  return build(b, result,
227  /*resultTypes=*/resultTypes,
228  /*target=*/target,
229  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
230 }
232 namespace {
// Listener that forwards every notification to the wrapped listener and, in
// addition, keeps the set of operations that were newly created (and not
// erased since) while it was attached.
// NOTE(review): line 235 is missing from this listing; the class is
// constructed elsewhere in this file from the previous listener, so a
// constructor is presumably declared/inherited there - confirm.
233 class NewOpsListener : public RewriterBase::ForwardingListener {
234 public:
    // Returns a snapshot of the tracked ops as a vector. The backing container
    // is a DenseSet, so no particular ordering is established by this code.
237  SmallVector<Operation *> getNewOps() const {
238  return SmallVector<Operation *>(newOps.begin(), newOps.end());
239  }
241 private:
    // Records `op` when it was newly created. An insertion with a set
    // `previous` insert point is skipped (see the comment below); inserting
    // the same op twice as "new" trips the assertion.
242  void notifyOperationInserted(Operation *op,
243  OpBuilder::InsertPoint previous) override {
244  ForwardingListener::notifyOperationInserted(op, previous);
245  // We only care about newly created ops.
246  if (previous.isSet())
247  return;
248  auto inserted = newOps.insert(op);
249  (void)inserted;
250  assert(inserted.second && "expected newly created op");
251  }
    // Drops every op visited by walking the erased op from the tracked set, so
    // no dangling pointers are handed out by getNewOps().
253  void notifyOperationErased(Operation *op) override {
254  ForwardingListener::notifyOperationErased(op);
255  op->walk([&](Operation *op) { newOps.erase(op); });
256  }
    // Ops created while this listener was attached and not yet erased.
258  DenseSet<Operation *> newOps;
259 };
260 } // namespace
// Bufferizes every payload op mapped to the target handle and reports the
// resulting buffers and the ops created along the way.
// NOTE(review): many lines of this function are missing from this listing
// (parameter list, declaration of `options`, right-hand sides of the
// `options.memcpyOp` / `options.allocOp` assignments, final return). Notes
// below describe only the visible lines.
262 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
265  // Attach listener to keep track of newly created ops.
    // The previous listener is restored by the scope-exit object on every
    // return path; NewOpsListener wraps it so notifications still reach it.
266  OpBuilder::Listener *previousListener = rewriter.getListener();
267  auto resetListener =
268  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
269  NewOpsListener newOpsListener(previousListener);
270  rewriter.setListener(&newOpsListener);
    // Select the memcpy flavour from the op's string attribute. Any other
    // string is treated as unreachable here.
273  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
276  } else if (getMemcpyOp() == "memref.copy") {
277  options.memcpyOp =
279  } else if (getMemcpyOp() == "linalg.copy") {
280  options.memcpyOp =
282  } else {
283  llvm_unreachable("invalid memcpy op");
284  }
    // Same scheme for the allocation op: only the two names below are handled.
285  if (getAllocOp() == "memref.alloc") {
286  options.allocOp =
288  } else if (getAllocOp() == "memref.alloca") {
289  options.allocOp =
291  } else {
292  llvm_unreachable("invalid alloc op");
293  }
294  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
295  options.emitDealloc = getEmitDealloc();
297  // Bufferize ops.
    // A missing memory-space attribute is passed on as a null Attribute.
298  Attribute memorySpace =
299  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
300  SmallVector<Value> allocatedBuffers;
301  for (Operation *op : state.getPayloadOps(getTarget())) {
302  Value buffer =
303  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
    // A null buffer means bufferization failed for this payload op: stop with
    // a silenceable error that points at the op.
304  if (!buffer) {
305  DiagnosedSilenceableFailure diag = emitSilenceableError()
306  << "failed to bufferize operation";
307  diag.attachNote(op->getLoc()) << "target payload op";
308  return diag;
309  }
310  allocatedBuffers.push_back(buffer);
311  }
313  // Set results.
    // One buffer value per payload op, plus all ops recorded by the listener.
314  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
315  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
317 }
// Declares the transform-dialect side effects of this op: the target handle
// is only read when just the destination is bufferized and consumed
// otherwise; both result handles are produced; the payload IR is modified.
// NOTE(review): line 320 (the parameter declaring `effects`) is missing from
// this listing.
319 void transform::BufferizeToAllocationOp::getEffects(
321  if (getBufferizeDestinationOnly()) {
322  // The destination is replaced with a newly allocated buffer, but the op
323  // itself remains in place.
324  onlyReadsHandle(getTarget(), effects);
325  } else {
326  consumesHandle(getTarget(), effects);
327  }
328  producesHandle(getAllocatedBuffer(), effects);
329  producesHandle(getNewOps(), effects);
330  modifiesPayload(effects);
331 }
// NOTE(review): the signature of this function (line 333) is missing from this
// listing; given its position it is presumably the verifier of
// BufferizeToAllocationOp - confirm.
// Rejects any memcpy op name other than the three listed below and any alloc
// op name other than "memref.alloc" / "memref.alloca"; these are the same
// names that BufferizeToAllocationOp::apply dispatches on.
334  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
335  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
336  return emitOpError() << "unsupported memcpy op";
337  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
338  return emitOpError() << "unsupported alloc op";
339  return success();
340 }
342 //===----------------------------------------------------------------------===//
343 // DecomposeOp
344 //===----------------------------------------------------------------------===//
// Tries a fixed sequence of rewrite patterns on `target`; the first one that
// applies wins. If none applies, a default silenceable failure is returned.
// NOTE(review): several lines are missing from this listing (return type,
// the parameter declaring `results`, the definition of DOWNSCALE_NORMAL, and
// the pattern invocations on lines 372-375).
347 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
348  LinalgOp target,
350  transform::TransformState &state) {
// DOWNSCALE(trans): runs tryApply<trans> on `target`; on success records the
// rewritten op in `results` and returns success from the enclosing function,
// otherwise falls through to the next attempt.
351 #define DOWNSCALE(trans) \
352  { \
353  FailureOr<LinalgOp> res = tryApply<trans>(target); \
354  if (succeeded(res)) { \
355  results.push_back(*res); \
356  return DiagnosedSilenceableFailure::success(); \
357  } \
358  }
// DOWNSCALE_CALL(a, b): names the DownscaleSizeOneWindowed2DConvolution pattern
// instantiated for the op pair (a, b).
360 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
// Each entry pairs a 2-D named op with a 1-D named op (per the op names).
363  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
364  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
365  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
366  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
367  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
368  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
369  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
370  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
371  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
376 #undef DOWNSCALE
    // Reached only when no pattern above applied.
377  return emitDefaultSilenceableFailure(target);
378 }
380 //===----------------------------------------------------------------------===//
381 // DecomposeInterfaceOp
382 //===----------------------------------------------------------------------===//
384 // Decompose the target operation if it implements the AggregatedOpInterface.
385 // Push the decomposed operations (the ones that replaces the values produced by
386 // \p target) in the `results`.
// Returns a default silenceable failure when `target` does not implement the
// interface or when its decomposition fails.
// NOTE(review): lines 389 (the parameter declaring `results`) and 409
// (presumably the final return) are missing from this listing.
387 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
388  transform::TransformRewriter &rewriter, Operation *target,
390  transform::TransformState &state) {
391  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
392  if (!decomposableOp) {
    // The match failure is reported through the rewriter; the value computed
    // by `failed(...)` is not acted upon, the failure is returned just below.
393  failed(rewriter.notifyMatchFailure(target,
394  "payload is not a decomposable op"));
395  return emitDefaultSilenceableFailure(target);
396  }
398  FailureOr<SmallVector<Value>> maybeNewResults =
399  decomposableOp.decomposeOperation(rewriter);
400  if (failed(maybeNewResults))
401  return emitDefaultSilenceableFailure(target);
    // Replace the original op with the decomposed values, then report the ops
    // defining those values. Values without a defining op are skipped.
403  rewriter.replaceOp(decomposableOp, *maybeNewResults);
404  for (Value val : *maybeNewResults) {
405  Operation *definition = val.getDefiningOp();
406  if (definition)
407  results.push_back(definition);
408  }
410 }
412 //===----------------------------------------------------------------------===//
413 // EliminateLinalgOpAnchoredEmptyTensorsOp
414 //===----------------------------------------------------------------------===//
// Declares the transform-dialect side effects of this op: the target handle is
// only read (not consumed) and the payload IR is modified.
// NOTE(review): line 417 (the parameter declaring `effects`) is missing from
// this listing.
416 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
418  onlyReadsHandle(getTarget(), effects);
419  modifiesPayload(effects);
420 }
// For each payload op mapped to the target handle: run an analysis on it, then
// eliminate LinalgOp-anchored tensor.empty ops (per the error message below).
// Either step failing yields a silenceable failure at the op's location.
// NOTE(review): several lines are missing from this listing (return type,
// declaration of `options`, line 430 - which presumably declares the local
// `state` passed to analyzeOp and shadows the parameter of the same name -
// the callee on line 434, and the final return). Confirm against the real
// source before relying on these notes.
423 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
424  transform::TransformRewriter &rewriter, TransformResults &transformResults,
425  TransformState &state) {
427  options.allowReturnAllocsFromLoops = true;
429  for (Operation *target : state.getPayloadOps(getTarget())) {
431  if (failed(analyzeOp(target, state)))
432  return mlir::emitSilenceableFailure(target->getLoc())
433  << "failed to analyze op";
435  rewriter, target, state)))
436  return mlir::emitSilenceableFailure(target->getLoc())
437  << "failed to eliminate LinalgOp anchored tensor.empty ops";
438  }
440 }
442 //===----------------------------------------------------------------------===//
443 // FuseOp
444 //===----------------------------------------------------------------------===//
446 /// Apply a tiling transformation to all payload ops and store both the
447 /// tiled operation as well as the created tile loops.
// Result 0 of `transformOp` receives one tiled op per payload op; result
// i + 1 receives the i-th loop created for every payload op.
// NOTE(review): lines 449, 452 and 463 are missing from this listing (the
// function name/return type, the type of `applyFn`, and the declaration of
// `tiledResults`).
448 template <typename Range>
450  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
451  unsigned numLoops, transform::TransformResults &transformResults,
453  applyFn) {
454  SmallVector<Operation *> tiledLinalgOps;
    // loopOps[i] collects the i-th loop across all payload ops.
455  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
457  for (Operation *target : payloadOps) {
    // Payload ops that do not implement TilingInterface are a hard error.
458  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
459  if (!tilingInterfaceOp)
460  return transformOp->emitError("only TilingInterface ops are supported");
462  rewriter.setInsertionPoint(target);
464  applyFn(tilingInterfaceOp);
465  if (failed(tiledResults))
466  return failure();
468  // Perform the replacement of tiled and fused values.
    // Applies to the target and to every fused producer; an op is erased only
    // once it has no remaining uses.
469  SmallVector<Operation *> opsToReplace{target};
470  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
471  for (Operation *toReplace : opsToReplace) {
472  for (OpResult res : toReplace->getResults())
473  if (auto replacement = tiledResults->replacements.lookup(res))
474  rewriter.replaceAllUsesWith(res, replacement);
475  if (toReplace->use_empty()) {
476  rewriter.eraseOp(toReplace);
477  }
478  }
480  // Report back the relevant handles to the transform op.
    // Only the first tiled-and-fused op is reported for each payload op.
481  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
482  assert(tiledResults->loops.size() == numLoops &&
483  "Mismatched number of loops, tile and fuse transform should have "
484  "failed");
485  for (unsigned int i = 0; i < numLoops; ++i)
486  loopOps[i].push_back(tiledResults->loops[i]);
487  }
489  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
490  for (unsigned int i = 0; i < numLoops; ++i)
491  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
493  return success();
494 }
// Builds SCF tile-and-fuse options from the op's tile sizes and interchange
// attributes and applies tileConsumerAndFuseProducersUsingSCF to every payload
// op mapped to the target handle.
// NOTE(review): several lines are missing from this listing (return type, the
// parameter declaring `state`, the call on line 512 that receives the
// arguments below, the lambda's return type, and the final return).
497 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
498  mlir::transform::TransformResults &transformResults,
500  SmallVector<int64_t> tileSizes =
501  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
502  SmallVector<int64_t> tileInterchange =
503  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
505  scf::SCFTilingOptions tilingOptions;
506  tilingOptions.interchangeVector = tileInterchange;
    // Tile sizes are turned into index-typed OpFoldResults before being set.
507  SmallVector<OpFoldResult> tileSizesOfr =
508  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
509  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
510  scf::SCFTileAndFuseOptions tileAndFuseOptions;
511  tileAndFuseOptions.tilingOptions = tilingOptions;
    // The loop count passed here is the number of non-zero tile sizes.
513  rewriter, getOperation(), state.getPayloadOps(getTarget()),
514  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
515  [&](TilingInterface tilingInterfaceOp)
517  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
518  tileAndFuseOptions);
519  });
522 }
// NOTE(review): the signature of this function (line 524) is missing from this
// listing; given its position it is presumably the verifier of FuseOp -
// confirm.
// Checks two invariants:
//  1. the tile interchange is a permutation of 0 .. size-1 (an empty
//     interchange passes trivially);
//  2. the op has exactly one more result than there are non-zero tile sizes,
//     i.e. one loop result per non-zero tile size.
525  SmallVector<int64_t> permutation =
526  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
527  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
528  if (!std::is_permutation(sequence.begin(), sequence.end(),
529  permutation.begin(), permutation.end())) {
530  return emitOpError() << "expects interchange to be a permutation, found "
531  << getTileInterchange();
532  }
534  SmallVector<int64_t> sizes =
535  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
536  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
537  if (numExpectedLoops != getNumResults() - 1)
538  return emitOpError() << "expects " << numExpectedLoops << " loop results";
540  return success();
541 }
543 //===----------------------------------------------------------------------===//
544 // FuseIntoContainingOp
545 //===----------------------------------------------------------------------===//
547 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
548  OperationState &result,
549  Value producerOp,
550  Value containingOp) {
551  result.addOperands({producerOp, containingOp});
552  auto resultType = transform::AnyOpType::get(builder.getContext());
553  result.addTypes({resultType, resultType});
554 }
556 /// Add new operands to the forall op for users of the producerOp
557 /// that are dominated by the containing scf.forall op.
// Returns the newly created scf.forall op, or nullptr when there is nothing to
// do (no dominated users outside the containing op) or when `producerOp` is
// not a linalg::GenericOp.
// NOTE(review): line 558 (the line carrying the return type and function name)
// is missing from this listing; the name is taken from the call site elsewhere
// in this file.
559  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
560  Operation *containingOp, TilingResult &tileAndFuseResult,
561  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
562  SmallVector<OpFoldResult> &sizes) {
564  // Count number of users not including the containing op
    // Only users that are outside `containingOp` and dominated by it count.
565  SetVector<Operation *> dominatedUsers;
566  DominanceInfo domInfo(containingOp);
567  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
568  if (!containingOp->isAncestor(user) &&
569  (domInfo.dominates(containingOp, user))) {
570  dominatedUsers.insert(user);
571  }
572  }
573  if (dominatedUsers.empty())
574  return nullptr;
576  // Create new scf.forall op
    // `cast` asserts that `containingOp` really is an scf.forall. The
    // insertion guard restores the rewriter's insertion point on exit.
577  auto forallOp = cast<scf::ForallOp>(containingOp);
578  OpBuilder::InsertionGuard g(rewriter);
579  rewriter.setInsertionPoint(forallOp);
581  // Get new output
    // Only linalg.generic producers are handled; anything else bails out.
582  Location loc = forallOp.getLoc();
583  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
584  if (!genericOp)
585  return nullptr;
586  SmallVector<Value> outputs = genericOp.getOutputs();
587  SmallVector<Value> newOuts(forallOp.getOutputs());
588  newOuts.push_back(outputs[resultNumber]);
590  // Create new scf.forall op
    // Same bounds, steps and mapping as the old op, with one extra output. The
    // freshly built body is discarded and the old region is moved over.
591  auto newforallOp = rewriter.create<scf::ForallOp>(
592  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
593  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
594  rewriter.eraseBlock(newforallOp.getBody());
595  newforallOp.getRegion().takeBody(forallOp.getRegion());
597  // Add additional block argument for new value being returned
598  // and replaces all uses of the new output with corresponding bbArg
599  // inside the scf.forall to enable fusion into this new scf.forall.
600  newforallOp.getBody()->addArgument(newOuts.back().getType(),
601  newOuts.back().getLoc());
602  auto bbArgs = newforallOp.getBody()->getArguments();
603  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
604  [&](OpOperand &use) {
605  Operation *op = use.getOwner();
606  return newforallOp->isProperAncestor(op);
607  });
609  // Fix terminator
    // A parallel_insert_slice writing the tiled value into the new iter arg is
    // inserted before the first existing yielding op.
    // NOTE(review): `yieldingOps.front()` assumes the terminator already has
    // at least one yielding op - confirm this always holds for callers.
610  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
611  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
612  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
613  Operation *firstYieldOp = yieldingOps.front();
614  rewriter.setInsertionPoint(firstYieldOp);
615  Value src = tileAndFuseResult.tiledValues[0];
616  Value dst = newforallOp.getRegionIterArgs().back();
    // Unit strides, one per offset.
617  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
618  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
619  dst, offsets, sizes, strides);
    // Redirect uses of the old forall results to the matching new results.
621  for (auto result : llvm::enumerate(forallOp.getResults())) {
622  rewriter.replaceAllUsesWith(result.value(),
623  newforallOp->getResult(result.index()));
624  }
    // Dominated users of the producer result now read the extra (last) result
    // of the new forall instead.
625  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
626  newforallOp->getResults().back(),
627  [&](OpOperand &use) {
628  Operation *user = use.getOwner();
629  return dominatedUsers.contains(user);
630  });
631  return newforallOp;
632 }
634 /// Find the first "extract" user of `producerOp` and tile it right before its
635 /// use. The tiled op is fused under the `containingOp`.
636 /// Return this fused op on success or nullptr if anything fails.
637 /// If tiled op has uses that are dominated by `containingOp`, return
638 /// a new `containingOp` with results of the fused op appended to
639 /// results of the `containingOp` or nullptr if there are no dominated uses.
// On every failure path a note is attached to `diag` and a value-initialized
// tuple (empty op list, null op) is returned.
// NOTE(review): line 641 (function name and leading parameters, including
// `rewriter` and `diag`) is missing from this listing; the name is taken from
// the call site elsewhere in this file.
640 static std::tuple<SmallVector<Operation *>, Operation *>
642  Operation *producerOp, Operation *containingOp) {
643  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
644  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
645  if (!tileableProducer) {
646  diag.attachNote(producerOp->getLoc())
647  << "producer is not a TileableInterface: " << *producerOp;
648  return {};
649  }
651  // Search the producer slices accessed within the containing operation.
652  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
653  // evolve into an interface.
    // Picks the first tensor.extract_slice user nested strictly inside
    // `containingOp`.
654  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
655  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
656  return sliceOp && containingOp->isProperAncestor(sliceOp);
657  });
659  // Find a fusion opportunity.
660  if (it == tileableProducer->getUsers().end()) {
661  diag.attachNote(tileableProducer->getLoc())
662  << "could not find fusion opportunity for: " << *tileableProducer;
663  return {};
664  }
665  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
667  // Try to fuse the producer in-place.
    // The guard restores the rewriter's insertion point when this function
    // returns.
668  OpBuilder::InsertionGuard guard(rewriter);
669  rewriter.setInsertionPoint(sliceOpToTile);
671  // Tile the producer.
    // The slice source is cast to an OpResult to find which producer result it
    // reads.
672  int64_t resultNumber =
673  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
674  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
    // The tile is described by the slice's own offsets and sizes.
676  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
677  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
679  FailureOr<TilingResult> tileAndFuseResult =
680  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
681  sizes);
683  if (failed(tileAndFuseResult)) {
684  diag.attachNote(tileableProducer->getLoc())
685  << "failed to tile producer op: " << *tileableProducer;
686  return {};
687  }
    // Debug-only dump of the tiled ops.
689 #ifndef NDEBUG
690  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
691  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
692  }
693 #endif
695  // Replace the extract op.
    // The first tiled value is rank-reduced (if needed) to the slice's result
    // shape before it replaces the slice.
696  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
697  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
698  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
699  if (failed(maybeRankReduced)) {
    // NOTE(review): `sliceOpToTile.getOperation()` streams an `Operation *`,
    // whereas other diagnostics in this file stream the dereferenced op;
    // this may print a pointer value rather than the op - confirm intent.
700  diag.attachNote(producerOp->getLoc())
701  << "shape types don't match (missing canonicalization?):\nTiledOp: "
702  << tileAndFuseResult->tiledValues[0]
703  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
704  return {};
705  }
706  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
708  // Add new outputs to containing op, if required
    // May be nullptr; see replaceForAllWithNewSignature.
709  Operation *newContainingOp = replaceForAllWithNewSignature(
710  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
711  resultNumber, offsets, sizes);
713  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
714 }
716 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
717 /// it is exactly the `containingOp`, otherwise bail.
718 /// Then, find the first "extract" user of the tied block argument and tile it
719 /// right before its "extract" use. The tiled op is fused under the
720 /// `containingOp`.
721 /// Return this fused op on success or nullptr if anything fails.
// On every failure path a note is attached to `diag` and a value-initialized
// result is returned; on success the tiled ops are returned.
// NOTE(review): lines 722-723 (return type and function name), line 783 (the
// call whose arguments appear on 784-785) and the start of line 792 are
// missing or truncated in this listing.
724  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
725  Operation *containingOp) {
726  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
728  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
729  if (!tileableProducer) {
730  diag.attachNote(producerOp->getLoc())
731  << "producer is not a TileableInterface: " << *producerOp;
732  return {};
733  }
735  // Search the first use by a "scf::ForallOp" user.
    // The lambda assigns `forallOp` as a side effect, so after find_if it
    // holds the matching forall (or null if no use is owned by one).
736  scf::ForallOp forallOp;
737  auto itProducerUses =
738  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
739  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
740  return forallOp;
741  });
742  // If it's not from the containing op, return.
743  if (!forallOp || forallOp != containingOp) {
744  diag.attachNote(tileableProducer->getLoc())
745  << "could not find a use by the containing op: " << *tileableProducer;
746  return {};
747  }
749  // Search the producer slices accessed within the containing
750  // operation.
751  // TODO: Generalize to more extract/insert/parallel_insert triples.
752  // Maybe evolve into an interface.
    // `pUse` is the forall operand fed by the producer; `bbArg` is the block
    // argument tied to that operand.
753  OpOperand *pUse = &(*itProducerUses);
754  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
756  // Search the producer slices accessed within the containing operation.
757  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
758  // evolve into an interface.
759  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
760  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
761  return sliceOp && containingOp->isProperAncestor(sliceOp);
762  });
764  // Find a fusion opportunity.
765  if (itBBArgUsers == bbArg.getUsers().end()) {
766  diag.attachNote(containingOp->getLoc())
767  << "could not find fusion opportunity for bbArg: " << bbArg;
768  return {};
769  }
770  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
772  // Try to fuse the producer in-place.
773  OpBuilder::InsertionGuard guard(rewriter);
774  rewriter.setInsertionPoint(sliceOpToTile);
776  // Replace the use in the tileableProducer before tiling: clone, replace and
777  // then tile.
778  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
779  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
781  // Gather destination tensors.
782  SmallVector<Value> destinationTensors;
784  rewriter, tileableProducer->getLoc(), tileableProducer,
785  destinationTensors))) {
786  diag.attachNote(tileableProducer->getLoc())
787  << "failed to get destination tensors for: " << *tileableProducer;
788  return {};
789  }
    // NOTE(review): line 792 is truncated in this listing; presumably it maps
    // a destination value to `bbArg` in `bvm` so that the clone below uses the
    // block argument - confirm against the real source.
791  IRMapping bvm;
792[resultNumber], bbArg);
793  auto tileableProducerClone =
794  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
    // The clone is temporary: it is erased when this function returns,
    // regardless of which return path is taken.
795  auto scopeGuard =
796  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
798  // Tile the producer.
799  FailureOr<TilingResult> tileAndFuseResult =
800  tileableProducerClone.generateResultTileValue(
801  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
802  sliceOpToTile.getMixedSizes());
803  if (failed(tileAndFuseResult)) {
804  diag.attachNote(tileableProducer->getLoc())
805  << "failed to tile producer op: " << *tileableProducer;
806  return {};
807  }
809  // Replace the extract op.
    // Unlike the direct-extract variant, a rank-reduction failure here is only
    // guarded by an assertion.
810  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
811  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
812  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
813  assert(succeeded(maybeRankReduced) && "unexpected shape");
814  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
816  // Replace the use in containingOp.
    // The forall operand that used the producer result now takes the first
    // destination tensor instead.
817  rewriter.modifyOpInPlace(containingOp, [&]() {
818  containingOp->setOperand(pUse->getOperandNumber(),
819  destinationTensors.front());
820  });
822  return tileAndFuseResult->tiledOps;
823 }
// Clones `producerOp` right before its first use located strictly inside
// `containingOp` and redirects that single use to the clone. Returns the clone,
// or nullptr (with a note attached to `diag`) when the containing op itself
// uses the producer or when there is no use inside it.
// NOTE(review): lines 825 (return type, function name, leading parameters
// including `rewriter` and `diag`) and 831 (declaration of `uses`) are missing
// from this listing.
826  Operation *producerOp,
827  Operation *containingOp) {
828  LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
830  // Gather all uses inside the containing op.
832  for (OpResult result : producerOp->getOpResults()) {
833  for (OpOperand &use : result.getUses()) {
834  if (containingOp->isProperAncestor(use.getOwner())) {
835  uses.push_back(&use);
836  continue;
837  }
838  // Cannot clone and fuse if the use is by the containing op itself: fail
839  // immediately.
840  if (containingOp == use.getOwner()) {
841  diag.attachNote(producerOp->getLoc())
842  << "producer op use by containing op cannot be fused by cloning";
843  return nullptr;
844  }
845  }
846  }
848  // Check for a non-empty list of fusion opportunities.
849  if (uses.empty()) {
850  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
851  return nullptr;
852  }
854  // Clone and fuse inside the containing op.
    // Only the first gathered use is handled per call.
855  Operation *fusedOp = nullptr;
856  OpOperand *use = uses.front();
857  // Parallel insert slice is not a valid clone destination.
858  // TODO: Generalize to other type of ops.
859  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
860  "Parallel insert slice is not a valid clone destination");
861  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
862  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
    // Clone immediately before the user, then rewire that one operand to the
    // clone's matching result. The insertion point is restored on exit.
864  OpBuilder::InsertionGuard guard(rewriter);
865  rewriter.setInsertionPoint(use->getOwner());
866  fusedOp = rewriter.clone(*producerOp);
867  rewriter.modifyOpInPlace(
868  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
870  return fusedOp;
871 }
873 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
874  // Allow repeated handles since we are fusing everything anyway.
875  return true;
876 }
879 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
881  transform::TransformState &state) {
882  SmallVector<Operation *> fusedOps;
883  auto producerOps = state.getPayloadOps(getProducerOp());
884  auto containingOps = state.getPayloadOps(getContainingOp());
885  if (!llvm::hasSingleElement(containingOps)) {
886  return emitDefiniteFailure()
887  << "requires exactly one containing_op handle (got "
888  << llvm::range_size(containingOps) << ")";
889  }
890  Operation *containingOp = *containingOps.begin();
892  // If nothing to fuse, propagate success.
893  if (std::empty(producerOps)) {
894  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
895  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
897  }
899  // Helper function to find the next producer that should be fused. Take any
900  // producer that has a use inside the containing op.
901  SetVector<Operation *> remainingProducers(producerOps.begin(),
902  producerOps.end());
903  auto getNextProducer = [&]() -> FailureOr<Operation *> {
904  for (const auto &it : enumerate(remainingProducers)) {
905  Operation *producerOp = it.value();
906  // The containing op may be a user of producerOp: use isAncestor.
907  int64_t numUsesInContainingOp =
908  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
909  return containingOp->isAncestor(op);
910  });
911  // TODO: When resolving the TODO below (no duplicate ops), take an op
912  // that has no use among the remaining producers. This is a topological
913  // sorting.
914  if (numUsesInContainingOp > 0) {
915  if (numUsesInContainingOp == 1)
916  remainingProducers.erase(remainingProducers.begin() + it.index());
917  return producerOp;
918  }
919  }
920  return failure();
921  };
923  while (!remainingProducers.empty()) {
924  auto nextProducer = getNextProducer();
925  if (failed(nextProducer)) {
926  auto diag = mlir::emitSilenceableFailure(getLoc())
927  << "could not find next producer to fuse into container";
928  diag.attachNote(containingOp->getLoc()) << "containing op";
929  return diag;
930  }
932  Operation *producerOp = *nextProducer;
934  // Default diagnostic, to be complemented with more failure information.
936  diag << "could not fuse " << *producerOp << " into " << *containingOp;
938  // TODO: If there are multiple uses of the producer in the containing op,
939  // we currently tile/clone the op multiple times (once per use). In some
940  // cases, we can tile/clone once and reuse the value for each use.
941  // Futhermore, producers should then be traversed according to a
942  // topological sorting.
943  auto [tiledOps, newContainingOp] =
944  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
945  if (!tiledOps.empty()) {
946  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
947  fusedOps.append(tiledOps);
948  if (newContainingOp) {
949  // Update handles associated with the containing op so we don't need to
950  // invalidate them. This is a hack to support better composability
951  // between tiling and fusion while a proper mechanism is being
952  // investigated.
953  //
954  // DO NOT replicate this elsewhere unless you understand what you are
955  // doing.
956  LogicalResult replacementStatus =
957  rewriter.notifyPayloadOperationReplaced(containingOp,
958  newContainingOp);
959  (void)replacementStatus;
960  assert(succeeded(replacementStatus) &&
961  "unable to update transform state mapping");
962  rewriter.eraseOp(containingOp);
963  containingOp = newContainingOp;
964  }
965  continue;
966  }
968  SmallVector<Operation *> tiledContainingOpOperand =
970  rewriter, diag, producerOp, containingOp);
971  if (!tiledContainingOpOperand.empty()) {
972  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
973  << *containingOp);
974  fusedOps.append(tiledContainingOpOperand);
975  continue;
976  }
978  Operation *cloned =
979  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
980  if (cloned) {
981  LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
982  fusedOps.push_back(cloned);
983  continue;
984  }
986  }
988  results.set(cast<OpResult>(getFusedOp()), fusedOps);
989  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
991 }
993 void transform::FuseIntoContainingOp::getEffects(
995  consumesHandle(getProducerOp(), effects);
996  onlyReadsHandle(getContainingOp(), effects);
997  producesHandle(getResults(), effects);
998  modifiesPayload(effects);
999 }
1001 //===----------------------------------------------------------------------===//
1002 // GeneralizeOp
1003 //===----------------------------------------------------------------------===//
1006 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1007  LinalgOp target,
1009  transform::TransformState &state) {
1010  // Exit early if no transformation is needed.
1011  if (isa<GenericOp>(target)) {
1012  results.push_back(target);
1014  }
1015  rewriter.setInsertionPoint(target);
1016  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1017  if (succeeded(generic)) {
1018  results.push_back(generic->getOperation());
1020  }
1021  return emitDefaultSilenceableFailure(target);
1022 }
1024 //===----------------------------------------------------------------------===//
1025 // SpecializeOp
1026 //===----------------------------------------------------------------------===//
1029 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1030  LinalgOp target,
1032  transform::TransformState &state) {
1033  // Exit early if the operation is not a generic.
1034  if (!isa<GenericOp>(target)) {
1035  results.push_back(target);
1037  }
1038  rewriter.setInsertionPoint(target);
1039  FailureOr<LinalgOp> named =
1040  specializeGenericOp(rewriter, cast<GenericOp>(target));
1041  if (succeeded(named)) {
1042  results.push_back(named->getOperation());
1044  }
1045  return emitDefaultSilenceableFailure(target);
1046 }
1048 //===----------------------------------------------------------------------===//
1049 // InterchangeOp
1050 //===----------------------------------------------------------------------===//
1053 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1054  GenericOp target,
1056  transform::TransformState &state) {
1057  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1058  // Exit early if no transformation is needed.
1059  if (interchangeVector.empty()) {
1060  results.push_back(target);
1062  }
1064  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1065  if (interchangeVector.size() != numLoops) {
1066  return emitSilenceableError()
1067  << getIteratorInterchangeAttrName() << " has length ("
1068  << interchangeVector.size()
1069  << ") different from the number of loops in the target operation ("
1070  << numLoops << ")";
1071  }
1072  FailureOr<GenericOp> res =
1073  interchangeGenericOp(rewriter, target,
1074  SmallVector<unsigned>(interchangeVector.begin(),
1075  interchangeVector.end()));
1076  if (failed(res))
1077  return emitDefiniteFailure() << "failed to apply";
1078  results.push_back(res->getOperation());
1080 }
1083  ArrayRef<int64_t> permutation = getIteratorInterchange();
1084  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1085  if (!std::is_permutation(sequence.begin(), sequence.end(),
1086  permutation.begin(), permutation.end())) {
1087  return emitOpError()
1088  << "expects iterator_interchange to be a permutation, found "
1089  << getIteratorInterchange();
1090  }
1091  return success();
1092 }
1094 //===----------------------------------------------------------------------===//
1095 // LowerPackOp
1096 //===----------------------------------------------------------------------===//
1098 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1099  transform::TransformRewriter &rewriter, tensor::PackOp target,
1100  transform::ApplyToEachResultList &transformResults,
1101  transform::TransformState &state) {
1102  rewriter.setInsertionPoint(target);
1103  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
1104  if (failed(res)) {
1105  return mlir::emitSilenceableFailure(target->getLoc())
1106  << "cannot lower to pad + expand + transpose";
1107  }
1108  transformResults.push_back(res->padOp);
1109  transformResults.push_back(res->expandShapeOp);
1110  transformResults.push_back(res->transposeOp);
1112 }
1114 //===----------------------------------------------------------------------===//
1115 // LowerUnPackOp
1116 //===----------------------------------------------------------------------===//
1118 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1119  transform::TransformRewriter &rewriter, tensor::UnPackOp target,
1120  transform::ApplyToEachResultList &transformResults,
1121  transform::TransformState &state) {
1122  rewriter.setInsertionPoint(target);
1123  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
1124  if (failed(res)) {
1126  emitSilenceableError()
1127  << "cannot lower to transpose + collapse + extract";
1128  diag.attachNote(target->getLoc()) << "target payload op";
1129  return diag;
1130  }
1131  transformResults.push_back(res->emptyOp);
1132  transformResults.push_back(res->transposeOp);
1133  transformResults.push_back(res->collapseShapeOp);
1134  transformResults.push_back(res->extractSliceOp);
1136 }
1138 //===---------------------------------------------------------------------===//
1139 // MatchOp
1140 //===---------------------------------------------------------------------===//
1142 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1143  Value target, ArrayRef<StringRef> opNames) {
1144  result.addOperands(target);
1145  result.addAttribute(MatchOp::getOpsAttrName(,
1146  builder.getStrArrayAttr(opNames));
1147  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1148 }
1150 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1151  TypeRange resultTypes, Value target,
1152  ArrayRef<StringRef> opNames) {
1153  result.addOperands(target);
1154  result.addAttribute(MatchOp::getOpsAttrName(,
1155  builder.getStrArrayAttr(opNames));
1156  result.addTypes(resultTypes);
1157 }
1160 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1161  transform::TransformResults &results,
1162  transform::TransformState &state) {
1163  llvm::StringSet<> strs;
1164  if (getOps().has_value())
1165  strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1166  getOps()->getAsValueRange<StringAttr>().end());
1168  auto payloadOps = state.getPayloadOps(getTarget());
1169  if (!llvm::hasSingleElement(payloadOps)) {
1170  return emitDefiniteFailure("requires exactly one target handle");
1171  }
1174  bool incorrectNumOperandTypes = false;
1175  auto matchFun = [&](Operation *op) {
1176  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1177  return;
1179  // Interfaces cannot be matched by name, just by ID.
1180  // So we specifically encode the interfaces we care about for this op.
1181  if (getInterface().has_value()) {
1182  auto iface = getInterface().value();
1183  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1184  !isa<LinalgOp>(op))
1185  return;
1186  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1187  !isa<TilingInterface>(op))
1188  return;
1189  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1190  !isa<LoopLikeOpInterface>(op))
1191  return;
1192  }
1194  // Check if all specified attributes match.
1195  if (getOpAttrs().has_value()) {
1196  DictionaryAttr opAttrs = getOpAttrs().value();
1197  for (NamedAttribute attr : opAttrs) {
1198  if (attr.getName() == getInterfaceAttrName() ||
1199  attr.getName() == getOpsAttrName())
1200  continue;
1201  if (!op->hasAttr(attr.getName()))
1202  return;
1203  if (op->getAttr(attr.getName()) != attr.getValue())
1204  return;
1205  }
1206  }
1208  if (getFilterResultType().has_value()) {
1209  Type t = getFilterResultType().value();
1210  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1211  return;
1212  }
1214  if (getFilterOperandTypes().has_value()) {
1215  mlir::ArrayAttr types = getFilterOperandTypes().value();
1216  auto operandTypes = op->getOperandTypes();
1218  if (types.size() == 1) {
1219  // All the operands must must be equal to the specified type
1220  auto typeattr =
1221  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1222  Type t = typeattr.getValue().cast<::mlir::Type>();
1223  if (!llvm::all_of(op->getOperandTypes(),
1224  [&](Type operandType) { return operandType == t; }))
1225  return;
1226  } else {
1227  // The operand types must match all the types in the list (in the same
1228  // order in with they are specified)
1229  if (types.size() != operandTypes.size()) {
1230  incorrectNumOperandTypes = true;
1231  return;
1232  }
1234  for (auto [attr, operandType] :
1235  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1236  auto typeattr = cast<mlir::TypeAttr>(attr);
1237  Type type = typeattr.getValue().cast<::mlir::Type>();
1239  if (type != operandType)
1240  return;
1241  }
1242  }
1243  }
1245  // All constraints are satisfied.
1246  res.push_back(op);
1247  return;
1248  };
1250  (*payloadOps.begin())->walk(matchFun);
1251  if (incorrectNumOperandTypes)
1252  return emitDefiniteFailure("If filter_operand_types contains more than a "
1253  "type, then it must contain as much types as "
1254  "the number of operands in the target ops");
1255  results.set(cast<OpResult>(getResult()), res);
1257 }
1259 //===---------------------------------------------------------------------===//
1260 // MultiTileSizesOp
1261 //===---------------------------------------------------------------------===//
1264  Type targetType, Type lowSizeType, Type,
1265  Type) {
1266  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1267 }
1270  Type &targetType, Type &lowSizeType,
1271  Type &highSizeType,
1272  Type &splitPointType) {
1273  FunctionType funcType;
1274  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1275  if (failed(parser.parseType<FunctionType>(funcType)))
1276  return failure();
1278  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1279  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1280  "argument and one result";
1281  }
1282  targetType = funcType.getInput(0);
1283  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1285  return success();
1286 }
1288 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1289  transform::TransformRewriter &rewriter, LinalgOp target,
1291  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1292  if (target.hasDynamicShape()) {
1293  auto diag = emitSilenceableError()
1294  << "cannot compute parametric tile sizes for dynamically "
1295  "shaped payload op";
1296  diag.attachNote(target->getLoc()) << "payload op";
1297  return diag;
1298  }
1301  target, getDimension(), getTargetSize(), getDivisor());
1302  if (failed(spec)) {
1303  return emitSilenceableError()
1304  << "failed to compute multi-size tiling sizes";
1305  }
1307  Builder builder(target.getContext());
1308  results.assign(llvm::map_range(
1309  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1310  spec->lowTileSize * spec->lowTripCount}),
1311  [&builder, this](int64_t value) {
1312  return builder.getIntegerAttr(
1313  cast<ParamType>(getLowSize().getType()).getType(), value);
1314  }));
1316  }
1318  OpBuilder builder(target.getContext());
1319  builder.setInsertionPoint(target);
1320  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1321  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1323  builder, target, getDimension(), targetSize, divisor);
1324  if (failed(spec)) {
1325  return emitSilenceableError() << "could not generate tile size computation";
1326  }
1328  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1329  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1330  Operation *splitPoint =
1331  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1332  {spec->lowTileSize, spec->lowTripCount});
1333  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1334  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1335  assert(lowTileSize && highTileSize && splitPoint &&
1336  "tile sizes are not produced by operations");
1337  results.reserve(results.size() + 3);
1338  results.push_back(lowTileSize);
1339  results.push_back(highTileSize);
1340  results.push_back(splitPoint);
1342 }
1344 void transform::MultiTileSizesOp::getEffects(
1346  onlyReadsHandle(getTarget(), effects);
1347  producesHandle(getResults(), effects);
1348  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1349  onlyReadsPayload(effects);
1350  else
1351  modifiesPayload(effects);
1352 }
1355  if (getLowSize().getType() != getHighSize().getType() ||
1356  getLowSize().getType() != getSplitPoint().getType()) {
1357  return emitOpError() << "expects all results type to be the same";
1358  }
1359  return success();
1360 }
1362 //===---------------------------------------------------------------------===//
1363 // PackOp
1364 //===---------------------------------------------------------------------===//
1366 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1367  Value target,
1368  ArrayRef<OpFoldResult> mixedPackedSizes) {
1369  SmallVector<int64_t> staticPackedSizes;
1370  SmallVector<Value> dynamicPackedSizes;
1371  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1372  staticPackedSizes);
1373  // Call the default builder which sets up the proper operands segment sizes
1374  // attributes for multiple variadic operands. In the absence of this, horrible
1375  // bugs ensue.
1376  Type linalgOpHType = transform::OperationType::get(
1377  builder.getContext(), GenericOp::getOperationName());
1378  build(builder, result,
1379  /*resultType=*/linalgOpHType,
1380  /*target=*/target,
1381  /*dynamic_sizes=*/dynamicPackedSizes,
1382  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1383 }
1385 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1386  Builder b(getContext());
1387  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1388 }
1391 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1392  transform::TransformResults &transformResults,
1393  transform::TransformState &state) {
1394  auto targetOps = state.getPayloadOps(getTarget());
1395  // If nothing to pack, propagate success.
1396  if (std::empty(targetOps)) {
1397  transformResults.set(cast<OpResult>(getPackedOp()),
1398  ArrayRef<Operation *>({}));
1400  }
1401  // Fail on multi-op handles.
1402  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1403  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1404  return emitSilenceableError()
1405  << "requires target to map to exactly 1 LinalgOp (got "
1406  << llvm::range_size(targetOps) << ")";
1407  }
1408  // Fail on mismatched number of pack sizes.
1409  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1410  return emitSilenceableError()
1411  << "requires number of packed sizes match the number of loops ("
1412  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1413  << ")";
1414  }
1416  // Unpack handles to constants or actual SSA index values.
1417  SmallVector<OpFoldResult> packedSizes;
1419  state, *this, packedSizes, getMixedPackedSizes());
1421  rewriter.setInsertionPoint(linalgOp);
1422  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1423  if (failed(maybeResult))
1424  return emitDefiniteFailure("data tiling failed");
1426  transformResults.set(cast<OpResult>(getPackedOp()),
1427  {maybeResult->packedLinalgOp.getOperation()});
1429 }
1431 void transform::PackOp::getEffects(
1433  transform::consumesHandle(getTarget(), effects);
1434  transform::onlyReadsHandle(getPackedSizes(), effects);
1435  transform::producesHandle(getPackedOp(), effects);
1436  transform::modifiesPayload(effects);
1437 }
1439 //===---------------------------------------------------------------------===//
1440 // PackGreedilyOp.
1441 //===---------------------------------------------------------------------===//
1444  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1445  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1446  << " is not a valid permutation";
1447  }
1448  // TODO: relax to allow empty once we have another strategy than just matmul.
1449  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1450  for (auto [s, nmo] :
1451  llvm::zip_equal(getMixedMatmulPackedSizes(),
1452  getMatmulPaddedSizesNextMultipleOf())) {
1453  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1454  if (nmo != 0 &&
1455  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1456  return emitOpError() << "at most one of the packed_size and the "
1457  "padded_sizes_next_multiple_of can be nonzero "
1458  "for the matmul strategy";
1459  }
1460  }
1461  }
1462  return success();
1463 }
1466 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1467  transform::TransformResults &transformResults,
1468  transform::TransformState &state) {
1469  SmallVector<Operation *> results;
1470  for (Operation *op : state.getPayloadOps(getTarget())) {
1471  auto linalgOp = dyn_cast<LinalgOp>(op);
1472  if (!linalgOp)
1473  continue;
1474  // linalgOp will be replaced and the insertion point may be invalidated if
1475  // we set it before -> set it after.
1476  rewriter.setInsertionPointAfter(linalgOp);
1477  // Failing to pack greedily is perfectly fine.
1478  // In the future we will want to order packings according to some metric.
1480  /*rewriter=*/rewriter,
1481  /*linalgOp=*/linalgOp,
1482  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1483  /*mnkPaddedSizesNextMultipleOf=*/
1484  getMatmulPaddedSizesNextMultipleOf(),
1485  /*mnkOrder=*/getMatmulInnerDimsOrder());
1486  if (succeeded(packResult)) {
1487  results.push_back(packResult->packedLinalgOp);
1488  continue;
1489  }
1490  results.push_back(linalgOp);
1491  }
1492  transformResults.set(cast<OpResult>(getPackedOp()), results);
1494 }
1496 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1497  Builder b(getContext());
1498  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1499  b);
1500 }
1502 void transform::PackGreedilyOp::getEffects(
1504  transform::consumesHandle(getTarget(), effects);
1505  transform::onlyReadsHandle(getMatmulPackedSizes(), effects);
1506  transform::producesHandle(getPackedOp(), effects);
1507  transform::modifiesPayload(effects);
1508 }
1510 //===---------------------------------------------------------------------===//
1511 // PackTransposeOp
1512 //===---------------------------------------------------------------------===//
1515  if (!isPermutationVector(getInnerPerm())) {
1516  return emitOpError() << getInnerPermAttrName()
1517  << " is not a valid permutation";
1518  }
1519  if (!isPermutationVector(getOuterPerm())) {
1520  return emitOpError() << getOuterPermAttrName()
1521  << " is not a valid permutation";
1522  }
1523  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1524  return emitOpError() << " at least one of " << getInnerPermAttrName()
1525  << " or " << getOuterPermAttrName()
1526  << " must be specified";
1527  }
1528  return success();
1529 }
namespace {
/// Selects which of the two permutations of a pack/unpack op is meant.
enum class OuterOrInnerPerm {
  Outer = 0,
  Inner = 1,
};
} // namespace
1535 /// Return true if `permutation` is a valid permutation of the
1536 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1537 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op.
1538 /// This is the case when the `permutation` rank matches the rank expected by
1539 /// `op` and `permutation` is itself a permutation vector.
1540 /// Return true if either `op` or `permutation` are empty to allow a simpler
1541 /// polymorphic implementation.
1542 template <typename RelayoutOpTy>
1544  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1545  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1546  static_assert(
1547  llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1548  "applies to only pack or unpack operations");
1549  if (!op || permutation.empty())
1550  return true;
1551  size_t innerRank = op.getInnerDimsPos().size();
1552  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1553  return permutation.size() == innerRank && isPermutationVector(permutation);
1554  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1555  // Don't rely on it.
1556  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1557  return permutation.size() == op.getSourceRank() &&
1558  isPermutationVector(permutation);
1559  }
1560  return permutation.size() == op.getDestRank() &&
1561  isPermutationVector(permutation);
1562 }
1565 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1566  transform::TransformResults &transformResults,
1567  transform::TransformState &state) {
1568  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1569  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1570  // Step 1. If nothing to pack, propagate success.
1571  if (std::empty(packOrUnpackOps)) {
1572  transformResults.set(cast<OpResult>(getPackedOp()), {});
1573  transformResults.set(cast<OpResult>(getPackOp()), {});
1574  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1576  }
1578  // Step 2. Bunch of runtime sanity check and error messages.
1579  // Step 2.1. Fail on multi-op handles.
1580  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1581  !llvm::hasSingleElement(linalgOps)) {
1582  return emitSilenceableError()
1583  << "requires target to map to exactly 1 "
1584  "packing op and 1 packed op ("
1585  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1586  << llvm::range_size(linalgOps) << ")";
1587  }
1589  // Step 2.2. Fail on wrong type.
1590  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1591  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1592  if ((!packOp && !unPackOp)) {
1593  return emitSilenceableError() << "requires target to map to a "
1594  "tensor.pack or tensor.unpack";
1595  }
1596  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1597  if (!linalgOpTarget)
1598  return emitSilenceableError() << "requires a LinalgOp target";
1600  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1601  LinalgOp linalgOp;
1602  if (packOp && packOp.getResult().hasOneUse())
1603  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1604  else if (unPackOp)
1605  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1606  if (linalgOp != linalgOpTarget) {
1607  auto errorMsg =
1608  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1609  : StringLiteral{"not produced by the LinalgOp target"};
1610  return emitSilenceableError() << errorMsg;
1611  }
1613  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1614  // PackOp.
1615  if (unPackOp) {
1616  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1617  OpOperand *packUse = linalgOp.getDpsInitOperand(
1618  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1619  packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
1620  if (!packOp || !packOp.getResult().hasOneUse())
1621  return emitSilenceableError() << "could not find matching pack op";
1622  }
1624  // Step 2.5. Fail if any permutation does not validate.
1625  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1626  ArrayRef<int64_t> perm =
1627  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1628  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1629  ? StringLiteral{"invalid outer_perm"}
1630  : StringLiteral{"invalid inner_perm"};
1631  if (!isValidPackingPermutation(packOp, perm, permType) ||
1632  !isValidPackingPermutation(unPackOp, perm, permType)) {
1633  Operation *packOrUnpackOp =
1634  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1635  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1636  }
1637  }
1639  // From here on, packOp and linalgOp are always present, unPackOp may or may
1640  // not be present.
1641  assert(packOp && linalgOp && "unexpected null op");
1643  // Step 3. Actually transpose the ops.
1645  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1646  // Preconditions have been checked, it is an error to fail here.
1647  assert(succeeded(res) && "unexpected packTranspose failure");
1649  // Step 4. Return results.
1650  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1651  transformResults.set(cast<OpResult>(getPackedOp()),
1652  {res->transposedLinalgOp});
1653  if (unPackOp) {
1654  transformResults.set(cast<OpResult>(getUnPackOp()),
1655  {res->transposedUnPackOp});
1656  } else {
1657  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1658  }
1661 }
1663 //===---------------------------------------------------------------------===//
1664 // PadOp
1665 //===---------------------------------------------------------------------===//
1667 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1668  ArrayRef<int64_t> paddingDimensions,
1669  ArrayRef<int64_t> padToMultipleOf,
1670  ArrayRef<int64_t> packPaddings,
1671  ArrayRef<Attribute> transposePaddings,
1672  StringRef copyBackOp) {
1673  auto resultType = transform::AnyOpType::get(b.getContext());
1674  return build(/*builder=*/b,
1675  /*result=*/result,
1676  /*types=*/TypeRange{resultType, resultType},
1677  /*target=*/target,
1678  /*paddingValues=*/ArrayAttr(), // let inference handle this
1679  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1680  /*padToMultipleOf=*/
1681  (padToMultipleOf.empty() ? ArrayAttr()
1682  : b.getI64ArrayAttr(padToMultipleOf)),
1683  /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
1684  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1685  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1686 }
1689 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1690  transform::TransformResults &results,
1691  transform::TransformState &state) {
1692  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1694  for (Operation *target : state.getPayloadOps(getTarget())) {
1695  auto linalgTarget = dyn_cast<LinalgOp>(target);
1696  if (!linalgTarget) {
1697  auto diag = emitSilenceableError() << "expected LinalgOp target";
1698  diag.attachNote(target->getLoc()) << "target op";
1699  return diag;
1700  }
1702  // Convert the integer packing flags to booleans.
1703  SmallVector<bool> packPaddings;
1704  for (int64_t packPadding :
1705  extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
1706  packPaddings.push_back(static_cast<bool>(packPadding));
1708  // Convert the padding values to attributes.
1709  SmallVector<Attribute> paddingValues;
1710  for (auto const &it :
1711  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1712  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1713  if (!attr) {
1714  emitOpError("expects padding values to be typed attributes");
1716  }
1717  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1718  // Try to parse string attributes to obtain an attribute of element type.
1719  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1720  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1721  stringAttr, getContext(), elementType,
1722  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
1723  if (!parsedAttr || parsedAttr.getType() != elementType) {
1724  auto diag = this->emitOpError("expects a padding that parses to ")
1725  << elementType << ", got " << std::get<0>(it);
1726  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1728  }
1729  paddingValues.push_back(parsedAttr);
1730  continue;
1731  }
1732  // Otherwise, add the attribute directly.
1733  if (attr.getType() != elementType) {
1734  auto diag = this->emitOpError("expects a padding value of type ")
1735  << elementType << ", got " << attr;
1736  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1738  }
1739  paddingValues.push_back(attr);
1740  }
1742  // Extract the transpose vectors.
1743  SmallVector<SmallVector<int64_t>> transposePaddings;
1744  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1745  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1746  cast<ArrayAttr>(transposeVector)));
1748  LinalgOp paddedOp;
1750  options.paddingDimensions =
1751  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1752  SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
1753  if (getPadToMultipleOf().has_value())
1754  padToMultipleOf =
1755  extractFromIntegerArrayAttr<int64_t>(*getPadToMultipleOf());
1756  options.padToMultipleOf = padToMultipleOf;
1757  options.paddingValues = paddingValues;
1758  options.packPaddings = packPaddings;
1759  if (getCopyBackOp() ==
1760  bufferization::MaterializeInDestinationOp::getOperationName()) {
1763  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1765  } else if (getCopyBackOp() == kCopyOpNone) {
1767  } else {
1768  llvm_unreachable("unsupported copy_back op");
1769  }
1771  SmallVector<Value> replacements;
1772  SmallVector<tensor::PadOp> newPadOps;
1773  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
1774  replacements, newPadOps))) {
1775  auto diag = emitSilenceableError() << "failed to pad op";
1776  diag.attachNote(target->getLoc()) << "target op";
1777  return diag;
1778  }
1780  // We need to perform our own replacement here because this API is still
1781  // used in patterns that "pad and hoist", for which the replacement values
1782  // need to be different.
1783  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1784  // that we have more composable abstractions.
1785  rewriter.replaceOp(linalgTarget, replacements);
1786  paddedOps.push_back(paddedOp);
1787  padOps.append(newPadOps.begin(), newPadOps.end());
1788  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
1789  for (Value v : replacements) {
1790  Operation *copyBackOp = v.getDefiningOp();
1791  if (!llvm::is_contained(copyBackOps, copyBackOp))
1792  copyBackOps.push_back(copyBackOp);
1793  }
1794  }
1795  }
1797  results.set(cast<OpResult>(getPadded()), paddedOps);
1798  results.set(cast<OpResult>(getPad()), padOps);
1799  results.set(cast<OpResult>(getCopy()), copyBackOps);
1801 }
// Verifier body for the pad transform op.
// NOTE(review): the signature line is absent from this listing; the accessors
// used below suggest this is transform::PadOp::verify() -- confirm.
//
// Checks performed, in order:
//  * every entry of pack_paddings is 0 or 1;
//  * no entry of padding_dimensions is negative;
//  * pad_to_multiple_of, when present, has as many entries as
//    padding_dimensions;
//  * every entry of transpose_paddings is a permutation of [0, size);
//  * copy_back_op is one of the three accepted names.
// Returns the emitted op error on the first violated check, success otherwise.
1804  SmallVector<int64_t> packPaddings =
1805  extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
1806  if (any_of(packPaddings, [](int64_t packPadding) {
1807  return packPadding != 0 && packPadding != 1;
1808  })) {
1809  return emitOpError()
1810  << "expects pack_paddings to contain booleans (0/1), found "
1811  << getPackPaddings();
1812  }
1814  SmallVector<int64_t> paddingDimensions =
1815  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
// Only negative values are rejected here, so 0 passes this check even though
// the diagnostic text says "positive".
1816  if (any_of(paddingDimensions,
1817  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
1818  return emitOpError() << "expects padding_dimensions to contain positive "
1819  "integers, found "
1820  << getPaddingDimensions();
1821  }
1822  if (getPadToMultipleOf().has_value()) {
1823  if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
1824  return emitOpError() << "expects as many multiples as padding_dimensions";
1825  }
1826  }
1827  ArrayAttr transposes = getTransposePaddings();
1828  for (Attribute attr : transposes) {
1829  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
// Compare against the identity sequence 0..size-1 to detect duplicates and
// out-of-range indices.
1830  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1831  if (!std::is_permutation(sequence.begin(), sequence.end(),
1832  transpose.begin(), transpose.end())) {
1833  return emitOpError()
1834  << "expects transpose_paddings to be a permutation, found "
1835  << attr;
1836  }
1837  }
1838  if (getCopyBackOp() !=
1839  bufferization::MaterializeInDestinationOp::getOperationName() &&
1840  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1841  getCopyBackOp() != kCopyOpNone)
1842  return emitOpError() << "invalid copy_back_op";
1843  return success();
1844 }
1846 //===---------------------------------------------------------------------===//
1847 // HoistPadOp
1848 //===---------------------------------------------------------------------===//
// Builds the packing loop nest for one tensor.pad payload op relative to one
// scf.for payload op and publishes a single result handle:
//  * the hoisted pad op when no loop induction variables were cloned;
//  * otherwise the scf.for that owns the first cloned induction variable.
// Emits a definite failure when either handle does not map to exactly one
// payload op, when the payload ops are not tensor::PadOp / scf::ForOp, or when
// buildPackingLoopNest fails.
// NOTE(review): the declaration of `result` and the return statements that
// follow the two transformResults.set calls are not present in this listing.
1850 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
1851  transform::TransformRewriter &rewriter,
1852  transform::TransformResults &transformResults,
1853  transform::TransformState &state) {
1854  auto targetOps = state.getPayloadOps(getTarget());
1855  auto loopOps = state.getPayloadOps(getLoop());
1856  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1857  return emitDefiniteFailure()
1858  << "requires exactly one target and one loop handle (got "
1859  << llvm::range_size(targetOps) << " and "
1860  << llvm::range_size(loopOps) << ")";
1861  }
// dyn_cast_or_null yields a null op both for a null payload pointer and for a
// payload op of the wrong kind; the two cases share one diagnostic.
1863  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1864  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1865  if (!padOp || !loopOp)
1866  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
1869  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
1870  getTranspose());
1871  if (failed(result))
1872  return emitDefiniteFailure() << "could not build packing loop nest";
// No cloned induction variables: report the hoisted pad op itself.
1874  if (result->clonedLoopIvs.empty()) {
1875  transformResults.set(cast<OpResult>(getPackingLoop()),
1876  {result->hoistedPadOp.getOperation()});
1878  }
1879  auto outerPackedLoop =
1880  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
1881  transformResults.set(cast<OpResult>(getPackingLoop()),
1882  {outerPackedLoop.getOperation()});
1884 }
// Verifier body: the `transpose` attribute must be a permutation of
// [0, transpose.size()); an empty transpose trivially satisfies this.
// NOTE(review): the signature line is absent from this listing; given its
// position this is presumably HoistPadBuildPackingLoopNestOp::verify() --
// confirm.
1887  ArrayRef<int64_t> transpose = getTranspose();
1888  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1889  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
1890  transpose.end())) {
1891  return emitOpError() << "expects transpose to be a permutation, found "
1892  << getTranspose();
1893  }
1894  return success();
1895 }
// Declares the transform-dialect side effects of this op: both operand handles
// are only read, the packing-loop result handle is produced, and the payload
// IR is modified.
// NOTE(review): the parameter-list line declaring `effects` is absent from
// this listing.
1897 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
1899  transform::onlyReadsHandle(getTarget(), effects);
1900  transform::onlyReadsHandle(getLoop(), effects);
1901  transform::producesHandle(getPackingLoop(), effects);
1902  transform::modifiesPayload(effects);
1903 }
// Hoists one tensor.pad payload op by calling hoistPaddingOnTensors with the
// op's num_loops and transpose attributes. On success the original pad is
// replaced with the returned value and the hoisted pad op is appended to
// `results`; otherwise the default silenceable failure for `target` is
// returned.
// NOTE(review): the return type, the `results` parameter line and the return
// statement of the success branch are absent from this listing. `transposeOps`
// is passed to the helper but not read again in the visible lines.
1906 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
1907  tensor::PadOp target,
1909  transform::TransformState &state) {
1910  tensor::PadOp hoistedPadOp;
1911  SmallVector<GenericOp> transposeOps;
1912  FailureOr<Value> result =
1913  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
1914  hoistedPadOp, transposeOps);
1915  if (succeeded(result)) {
1916  // We need to perform our own replacement here because this API is still
1917  // used in patterns that "pad and hoist", for which the replacement values
1918  // need to be different.
1919  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1920  // that we have more composable abstractions.
1921  rewriter.replaceOp(target, *result);
1922  results.push_back(hoistedPadOp);
1924  }
1925  return emitDefaultSilenceableFailure(target);
1926 }
// Verifier body: the `transpose` attribute must be a permutation of
// [0, transpose.size()); an empty transpose trivially satisfies this.
// NOTE(review): the signature line is absent from this listing; given its
// position this is presumably HoistPadOp::verify() -- confirm.
1929  ArrayRef<int64_t> transpose = getTranspose();
1930  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1931  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
1932  transpose.end())) {
1933  return emitOpError() << "expects transpose to be a permutation, found "
1934  << getTranspose();
1935  }
1936  return success();
1937 }
1939 //===----------------------------------------------------------------------===//
1940 // PromoteOp
1941 //===----------------------------------------------------------------------===//
// Promotes the subviews of one LinalgOp payload op.
// The op's attributes are translated into a LinalgPromotionOptions object
// (only attributes that are set / non-empty override the defaults), an optional
// GPU memory-space mapping selects workgroup- or private-address-space
// settings, and then promoteSubviewsPrecondition + promoteSubViews are run with
// the insertion point set at `target`. Any failure is reported as the default
// definite failure for `target`.
// NOTE(review): the return type, the `results` parameter line, parts of the two
// chained option setters in the mapping branches, and the final return
// statement are absent from this listing.
1944 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
1945  LinalgOp target,
1947  transform::TransformState &state) {
1948  LinalgPromotionOptions promotionOptions;
1949  if (!getOperandsToPromote().empty())
1950  promotionOptions = promotionOptions.setOperandsToPromote(
1951  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
1952  if (getUseFullTilesByDefault())
1953  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
1954  getUseFullTilesByDefault());
1955  if (getUseAlloca())
1956  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
1957  if (!getUseFullTileBuffers().empty())
1958  promotionOptions = promotionOptions.setUseFullTileBuffers(
1959  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
1960  if (getAlignment().has_value())
1961  promotionOptions = promotionOptions.setAlignment(*getAlignment());
1962  if (getMemorySpace().has_value())
1963  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
1965  if (getMapping().has_value()) {
1966  // The mapping must not contain more than one element.
// NOTE(review): only size() > 1 is rejected; an empty mapping would reach
// mapping[0] below -- confirm that a verifier rules this out.
1967  auto mapping = *getMapping();
1968  if (mapping.size() > 1)
1969  return emitDefaultDefiniteFailure(target);
1971  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
// Both supported address spaces end by disabling full-tile buffers for two
// operands; any other address space is a definite failure.
1973  if (addressSpace.getAddressSpace() ==
1974  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
1975  promotionOptions =
1976  promotionOptions
1980  .setUseFullTileBuffers({false, false});
1981  } else if (addressSpace.getAddressSpace() ==
1982  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
1983  promotionOptions =
1984  promotionOptions
1988  .setUseFullTileBuffers({false, false});
1989  } else {
1990  return emitDefaultDefiniteFailure(target);
1991  }
1992  }
1994  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
1995  return emitDefaultDefiniteFailure(target);
1997  rewriter.setInsertionPoint(target);
1998  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
1999  if (failed(res))
2000  return emitDefaultDefiniteFailure(target);
// NOTE(review): the original `target` is reported, not `*res` -- presumably the
// promotion updates the op in place; confirm.
2001  results.push_back(target);
2003 }
2005 //===----------------------------------------------------------------------===//
2006 // ReplaceOp
2007 //===----------------------------------------------------------------------===//
// Replaces every payload op of the target handle with a clone of the single op
// held in this transform op's body region.
// All targets are validated first: a target must have no operands and, if it
// has regions, must be isolated from above; otherwise a definite failure is
// emitted before anything is rewritten. Targets nested inside this transform
// op itself are skipped. The result handle is set to the clones, in payload
// order.
// NOTE(review): the return type line and the final return statement are absent
// from this listing.
2010 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2011  TransformResults &transformResults,
2012  TransformState &state) {
2013  auto payload = state.getPayloadOps(getTarget());
2015  // Check for invalid targets.
2016  for (Operation *target : payload) {
2017  if (target->getNumOperands() > 0)
2018  return emitDefiniteFailure() << "expected target without operands";
2019  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2020  target->getNumRegions() > 0)
2021  return emitDefiniteFailure()
2022  << "expected target that is isolated from above";
2023  }
2025  // Clone and replace.
// The pattern is the first operation of the first block of the body region.
2026  Operation *pattern = &getBodyRegion().front().front();
2027  SmallVector<Operation *> replacements;
2028  for (Operation *target : payload) {
2029  if (getOperation()->isAncestor(target))
2030  continue;
2031  rewriter.setInsertionPoint(target);
2032  Operation *replacement = rewriter.clone(*pattern);
2033  rewriter.replaceOp(target, replacement->getResults());
2034  replacements.push_back(replacement);
2035  }
2036  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2038 }
// Declares the transform-dialect side effects of this op: the target handle is
// consumed, the replacement handle is produced, and the payload IR is modified.
// NOTE(review): the parameter-list line declaring `effects` is absent from
// this listing.
2040 void transform::ReplaceOp::getEffects(
2042  consumesHandle(getTarget(), effects);
2043  producesHandle(getReplacement(), effects);
2044  modifiesPayload(effects);
2045 }
// Verifier body: the body region must consist of exactly one block holding
// exactly one operation, and that operation must have no operands and, if it
// has regions, must be isolated from above. The last two diagnostics are
// attached to the replacement op rather than to this transform op.
// NOTE(review): the signature line is absent from this listing; given its
// position this is presumably transform::ReplaceOp::verify() -- confirm.
2048  if (!getBodyRegion().hasOneBlock())
2049  return emitOpError() << "expected one block";
2050  if (std::distance(getBodyRegion().front().begin(),
2051  getBodyRegion().front().end()) != 1)
2052  return emitOpError() << "expected one operation in block";
2053  Operation *replacement = &getBodyRegion().front().front();
2054  if (replacement->getNumOperands() > 0)
2055  return replacement->emitOpError()
2056  << "expected replacement without operands";
2057  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2058  replacement->getNumRegions() > 0)
2059  return replacement->emitOpError()
2060  << "expect op that is isolated from above";
2061  return success();
2062 }
2064 //===----------------------------------------------------------------------===//
2065 // ScalarizeOp
2066 //===----------------------------------------------------------------------===//
// Tiles one LinalgOp payload op with tileUsingSCF, using tile size 1 for every
// loop whose extent is not a compile-time constant and tile size 0 (no tiling)
// for every loop whose extent is constant. On success the original op is
// replaced (or erased when it has no results) and all tiled ops are appended to
// `results`; on failure the default definite failure for `target` is returned.
// NOTE(review): the return type, the `results` parameter line, the callee name
// on the line computing `shapeSizes`, and the final return statement are absent
// from this listing.
2069 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2070  LinalgOp target,
2072  transform::TransformState &state) {
2073  scf::SCFTilingOptions tilingOptions;
2074  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2075  SmallVector<OpFoldResult> tileSizes;
2076  Location loc = target.getLoc();
2077  SmallVector<OpFoldResult> allShapeSizes =
2078  target.createFlatListOfOperandDims(b, loc);
2079  AffineMap map = target.getShapesToLoopsMap();
// Without a shapes-to-loops map no tile sizes are produced (empty vector).
2080  if (!map)
2081  return tileSizes;
2082  SmallVector<OpFoldResult> shapeSizes =
2084  allShapeSizes);
2085  // If the shape size is dynamic, tile by 1.
2086  // Otherwise, do not tile (i.e. tile size 0).
2087  for (OpFoldResult shapeSize : shapeSizes) {
2088  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2089  : b.getIndexAttr(1));
2090  }
2091  return tileSizes;
2092  });
// NOTE(review): `emptyTileSizes` is not referenced again in the visible lines.
2093  SmallVector<int64_t> emptyTileSizes;
2094  rewriter.setInsertionPoint(target);
2095  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2096  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2097  if (failed(maybeTilingResult))
2098  return emitDefaultDefiniteFailure(target);
2100  if (target->getNumResults())
2101  rewriter.replaceOp(target, maybeTilingResult->replacements);
2102  else
2103  rewriter.eraseOp(target);
2105  results.reserve(maybeTilingResult->tiledOps.size());
2106  for (Operation *tiled : maybeTilingResult->tiledOps)
2107  results.push_back(tiled);
2109 }
2111 //===----------------------------------------------------------------------===//
2112 // ConvertToLoopsOp
2113 //===----------------------------------------------------------------------===//
2116 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2117  transform::TransformResults &results,
2118  transform::TransformState &state) {
2120  for (Operation *target : state.getPayloadOps(getTarget())) {
2121  auto tilingOp = dyn_cast<TilingInterface>(*target);
2122  if (!target) {
2124  emitSilenceableError()
2125  << "expected the payload to implement TilingInterface";
2126  diag.attachNote(target->getLoc()) << "payload op";
2127  return diag;
2128  }
2129  rewriter.setInsertionPoint(target);
2130  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2131  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2132  if (failed(generatedLoops))
2133  return emitDefaultDefiniteFailure(target);
2134  for (scf::ForOp &loop : *generatedLoops) {
2135  loops.push_back(loop.getOperation());
2136  }
2137  rewriter.eraseOp(target);
2138  }
2139  results.set(cast<OpResult>(getResult()), loops);
2141 }
2143 //===----------------------------------------------------------------------===//
2144 // RewriteInDestinationPassingStyleOp
2145 //===----------------------------------------------------------------------===//
// Rewrites one payload op into destination-passing style by dispatching on its
// concrete type (tensor.from_elements, tensor.generate or tensor.pad) to
// rewriteInDestinationPassingStyle, with the insertion point set at `target`.
// The resulting op is appended to `results`; a failed rewrite yields the
// default silenceable failure for `target`.
// NOTE(review): the return type, the `results` parameter line, the expression
// that starts the `.Case<...>` chain and the final return statement are absent
// from this listing, so the handling of other op types cannot be seen here.
2148 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2149  transform::TransformRewriter &rewriter, Operation *target,
2151  transform::TransformState &state) {
2153  rewriter.setInsertionPoint(target);
2154  FailureOr<Operation *> maybeResult =
2156  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2157  [&rewriter](auto op) {
2158  return rewriteInDestinationPassingStyle(rewriter, op);
2159  });
2160  if (failed(maybeResult))
2161  return emitDefaultSilenceableFailure(target);
2162  results.push_back(*maybeResult);
2164 }
2166 //===----------------------------------------------------------------------===//
2167 // SplitOp
2168 //===----------------------------------------------------------------------===//
2171 SplitOp::apply(transform::TransformRewriter &rewriter,
2172  TransformResults &results, TransformState &state) {
2173  // Collect the dynamic split points if provided.
2174  SmallVector<Operation *> payload =
2175  llvm::to_vector(state.getPayloadOps(getTarget()));
2176  SmallVector<OpFoldResult> splitPoints;
2177  splitPoints.reserve(payload.size());
2178  if (getDynamicSplitPoint()) {
2180  if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
2181  splitPoints = llvm::to_vector(llvm::map_range(
2182  state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
2183  if (op->getNumResults() != 1 ||
2184  !op->getResult(0).getType().isIndex()) {
2185  diag = emitSilenceableError()
2186  << "expected dynamic split point handle to point to a "
2187  "single-result index-typed op";
2188  diag.attachNote(op->getLoc()) << "dynamic split point";
2189  }
2190  return OpFoldResult(op->getResult(0));
2191  }));
2192  } else {
2193  splitPoints = llvm::to_vector(
2194  llvm::map_range(state.getParams(getDynamicSplitPoint()),
2195  [](Attribute attr) { return OpFoldResult(attr); }));
2196  }
2197  if (diag.isSilenceableFailure())
2198  return diag;
2200  if (splitPoints.size() != payload.size()) {
2201  return emitDefiniteFailure()
2202  << "expected the dynamic split point handle to point to as "
2203  "many operations ("
2204  << splitPoints.size() << ") as the target handle ("
2205  << payload.size() << ")";
2206  }
2207  } else {
2208  splitPoints.resize(payload.size(),
2209  rewriter.getIndexAttr(getStaticSplitPoint()));
2210  }
2212  // Split each target operation.
2213  SmallVector<Operation *> first, second;
2214  Operation *noSecondPart = nullptr;
2215  for (const auto &pair : llvm::zip(payload, splitPoints)) {
2216  Operation *target = std::get<0>(pair);
2217  auto linalgOp = dyn_cast<LinalgOp>(target);
2218  if (!linalgOp) {
2219  auto diag = emitSilenceableError() << "only applies to structured ops";
2220  diag.attachNote(target->getLoc()) << "target op";
2221  return diag;
2222  }
2224  if (getDimension() >= linalgOp.getNumLoops()) {
2225  auto diag = emitSilenceableError() << "dimension " << getDimension()
2226  << " does not exist in target op";
2227  diag.attachNote(target->getLoc()) << "target op";
2228  return diag;
2229  }
2231  rewriter.setInsertionPoint(linalgOp);
2232  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2233  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2234  getDimension(), std::get<1>(pair));
2236  // Propagate errors.
2237  if (!first.back() && !second.back()) {
2238  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2239  diag.attachNote(target->getLoc()) << "target op";
2240  return diag;
2241  }
2243  // Do not add null second parts.
2244  if (!second.back()) {
2245  noSecondPart = target;
2246  second.pop_back();
2247  }
2248  }
2250  if (second.size() != first.size() && !second.empty()) {
2251  auto diag = emitSilenceableError()
2252  << "splitting does not produce the second part for a subset "
2253  "of targets";
2254  diag.attachNote() << "expected splitting to produce the second part of all "
2255  "or none of the targets";
2256  diag.attachNote(noSecondPart->getLoc())
2257  << "first target with no second part";
2258  return diag;
2259  }
2261  results.set(cast<OpResult>(getFirst()), first);
2262  results.set(cast<OpResult>(getSecond()), second);
2264 }
// Declares the transform-dialect side effects of this op: the target handle is
// consumed, the dynamic split point (when present) is only read, all results
// are produced, and the payload IR is modified.
// NOTE(review): the parameter-list line declaring `effects` is absent from
// this listing.
2266 void SplitOp::getEffects(
2268  consumesHandle(getTarget(), effects);
2269  if (getDynamicSplitPoint())
2270  onlyReadsHandle(getDynamicSplitPoint(), effects);
2271  producesHandle(getResults(), effects);
2272  modifiesPayload(effects);
2273 }
// Custom parser body for the split op. Accepted form, as read by the calls
// below:
//   <target> `after` (<operand> | <integer>) attr-dict `:` <target-type>
//       (`,` <split-point-type>)?
// A static integer split point is stored in the static split point attribute.
// When an operand is given instead, the attribute is set to
// ShapedType::kDynamic and the operand is resolved with the trailing type.
// Two results of the target type are added.
// NOTE(review): the signature line is absent from this listing (the body uses
// `parser` and `result`), and the argument of getStaticSplitPointAttrName near
// the end appears to have been lost in extraction ("(,").
2276  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
2277  IntegerAttr staticSplitPoint;
2278  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2279  return failure();
// No value in the optional result means no operand was present, so a static
// integer split point is expected instead.
2281  OptionalParseResult dynamicPointParseResult =
2282  parser.parseOptionalOperand(dynamicSplitPoint);
2283  if (!dynamicPointParseResult.has_value()) {
2284  int64_t staticSplitPointValue;
2285  if (failed(parser.parseInteger(staticSplitPointValue)))
2286  return failure();
2288  staticSplitPoint =
2289  parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
2290  }
2292  Type targetType;
2293  if (parser.parseOptionalAttrDict(result.attributes) ||
2294  parser.parseColonType(targetType) ||
2295  parser.resolveOperand(target, targetType, result.operands)) {
2296  return failure();
2297  }
2298  if (dynamicPointParseResult.has_value()) {
2299  Type splitPointType;
2300  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2301  parser.parseType(splitPointType) ||
2302  parser.resolveOperand(dynamicSplitPoint, splitPointType,
2303  result.operands)) {
2304  return failure();
2305  }
// Sentinel marking that the split point is provided dynamically.
2307  staticSplitPoint =
2308  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2309  }
2311  result.addAttribute(
2312  SplitOp::getStaticSplitPointAttrName(,
2313  staticSplitPoint);
2314  result.addTypes({targetType, targetType});
2315  return success();
2316 }
2318 void SplitOp::print(OpAsmPrinter &printer) {
2319  printer << " " << getTarget() << " after ";
2320  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
2321  if (staticSplitSize != ShapedType::kDynamic)
2322  printer << staticSplitSize;
2323  else
2324  printer << getDynamicSplitPoint();
2325  printer << " ";
2326  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2327  {getStaticSplitPointAttrName()});
2328  printer << " : " << getTarget().getType();
2329  if (staticSplitSize == ShapedType::kDynamic)
2330  printer << ", " << getDynamicSplitPoint().getType();
2331 }
// Verifier body: exactly one of the static and dynamic split points must be
// provided. The XOR below is true (and an error is emitted) when a static
// value is set together with a dynamic operand, or when the static value is
// the kDynamic sentinel and no dynamic operand is present.
// NOTE(review): the signature line is absent from this listing; given its
// position this is presumably SplitOp::verify() -- confirm.
2334  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
2335  (getDynamicSplitPoint() == nullptr)) {
2336  return emitOpError() << "expects either a dynamic or a static split "
2337  "point to be provided";
2338  }
2339  return success();
2340 }
2342 //===----------------------------------------------------------------------===//
2343 // SplitReductionOp
2344 //===----------------------------------------------------------------------===//
2346 void transform::SplitReductionOp::build(
2347  OpBuilder &builder, OperationState &result, Value target,
2348  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2349  bool useScalingAlgorithm, bool useAlloc) {
2350  MLIRContext *ctx = builder.getContext();
2351  result.addOperands(target);
2352  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(,
2353  builder.getI64IntegerAttr(splitFactor));
2354  result.addAttribute(
2355  SplitReductionOp::getInsertSplitDimensionAttrName(,
2356  builder.getI64IntegerAttr(insertSplitDimension));
2357  if (innerParallel) {
2358  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(,
2359  builder.getUnitAttr());
2360  }
2361  if (useScalingAlgorithm) {
2362  result.addAttribute(
2363  SplitReductionOp::getUseScalingAlgorithmAttrName(,
2364  builder.getUnitAttr());
2365  }
2366  if (useAlloc) {
2367  result.addAttribute(SplitReductionOp::getUseAllocAttrName(,
2368  builder.getUnitAttr());
2369  }
2370  auto resultType = transform::AnyOpType::get(ctx);
2371  result.addTypes({resultType, resultType, resultType, resultType});
2372 }
// Splits the reduction of one LinalgOp payload op.
// The control function returns the same options for every op, taken from this
// transform op's split_factor, insert_split_dimension and inner_parallel
// attributes. use_scaling_algorithm selects splitReductionByScaling over
// splitReduction; use_alloc is forwarded to either. On failure the default
// definite failure for `target` is returned; on success four ops are appended
// to `results` in this order: init-or-alloc, fill, split linalg op,
// result-combining linalg op.
// NOTE(review): the `results` parameter line and the final return statement
// are absent from this listing.
2374 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2375  transform::TransformRewriter &rewriter, LinalgOp target,
2377  transform::TransformState &state) {
2378  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2379  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2380  unsigned(getInsertSplitDimension()),
2381  bool(getInnerParallel())};
2382  };
2383  rewriter.setInsertionPoint(target);
2384  FailureOr<SplitReductionResult> splitResult =
2385  (getUseScalingAlgorithm())
2386  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2387  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2388  if (failed(splitResult))
2389  return emitDefaultDefiniteFailure(target);
2391  results.push_back(splitResult->initOrAlloc);
2392  results.push_back(splitResult->fillOp);
2393  results.push_back(splitResult->splitLinalgOp);
2394  results.push_back(splitResult->resultCombiningLinalgOp);
2396 }
2398 //===----------------------------------------------------------------------===//
2399 // TileReductionUsingForOp
2400 //===----------------------------------------------------------------------===//
2402 void transform::TileReductionUsingForOp::build(
2403  OpBuilder &builder, OperationState &result, Value target,
2404  ArrayRef<int64_t> staticTileSizes) {
2405  // Call the default builder.
2406  // This is future-proof re mixed static-dynamic and setting up the proper
2407  // operands segment sizes attributes for multiple variadic operands.
2408  // In the absence of this, horrible bugs ensue.
2409  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2410  MLIRContext *ctx = builder.getContext();
2411  auto opTy = transform::AnyOpType::get(ctx);
2412  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2413  build(builder, result,
2414  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2415  /*target=*/target,
2416  /*tile_sizes=*/staticTileSizesAttr);
2417 }
// Tiles the reduction of one LinalgOp payload op, viewed through
// PartialReductionOpInterface, using the op's tile_sizes attribute. On failure
// the default silenceable failure for `target` is returned; on success four
// entries are appended to `results`: initial op, parallel tiled op, merge op
// and the first generated loop.
// NOTE(review): the `results` parameter line, the line declaring `result` and
// naming the tiling function, and the final return statement are absent from
// this listing.
2419 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2420  transform::TransformRewriter &rewriter, LinalgOp target,
2422  transform::TransformState &state) {
2423  rewriter.setInsertionPoint(target);
2425  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2426  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2428  if (failed(result))
2429  return emitDefaultSilenceableFailure(target);
2430  results.push_back(result->initialOp);
2431  results.push_back(result->parallelTiledOp);
2432  results.push_back(result->mergeOp);
// NOTE(review): front() presumes at least one loop was generated -- confirm
// the tiling helper guarantees this on success.
2433  results.push_back(result->loops.front());
2435 }
2437 //===----------------------------------------------------------------------===//
2438 // TileReductionUsingForallOp
2439 //===----------------------------------------------------------------------===//
2441 void transform::TileReductionUsingForallOp::build(
2442  OpBuilder &builder, OperationState &result, Value target,
2443  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2444  ArrayAttr mapping) {
2445  // Call the default builder.
2446  // This is future-proof re mixed static-dynamic and setting up the proper
2447  // operands segment sizes attributes for multiple variadic operands.
2448  // In the absence of this, horrible bugs ensue.
2449  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2450  MLIRContext *ctx = builder.getContext();
2451  auto opTy = transform::AnyOpType::get(ctx);
2452  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2453  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2454  build(builder, result,
2455  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2456  /*target=*/target,
2457  /*num_threads=*/staticNumThreadsAttr,
2458  /*tile_sizes=*/staticTileSizesAttr,
2459  /*mapping=*/mapping);
2460 }
// Tiles the reduction of one LinalgOp payload op, viewed through
// PartialReductionOpInterface, using the op's num_threads, tile_sizes and
// mapping attributes. On failure a silenceable error with a note on the target
// is returned; on success the initial op, parallel tiled op, merge op and
// `result->loops` are appended to `results`.
// NOTE(review): the `results` parameter line, the line declaring `result` and
// naming the tiling function, and the final return statement are absent from
// this listing.
2462 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2463  transform::TransformRewriter &rewriter, LinalgOp target,
2465  transform::TransformState &state) {
2466  rewriter.setInsertionPoint(target);
2467  SmallVector<OpFoldResult> numThreads =
2468  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2469  SmallVector<OpFoldResult> tileSizes =
2470  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2473  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2474  numThreads, tileSizes, getMapping());
2476  if (failed(result)) {
2477  auto diag = emitSilenceableError() << "could not tile reduction";
2478  diag.attachNote(target.getLoc()) << "target operation";
2479  return diag;
2480  }
2481  results.push_back(result->initialOp);
2482  results.push_back(result->parallelTiledOp);
2483  results.push_back(result->mergeOp);
2484  results.push_back(result->loops);
2486 }
2488 //===----------------------------------------------------------------------===//
2489 // TileUsingForOp
2490 //===----------------------------------------------------------------------===//
2492 void transform::TileUsingForOp::build(
2493  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2494  Value target, ArrayRef<int64_t> staticTileSizes,
2495  ArrayRef<int64_t> interchange,
2496  std::optional<ArrayRef<bool>> scalableSizes) {
2497  return build(builder, result, loopTypes,
2498  /*target=*/target,
2499  /*mixedTileSizes=*/
2500  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2501  interchange, scalableSizes);
2502 }
2504 void transform::TileUsingForOp::build(
2505  OpBuilder &builder, OperationState &result, Value target,
2506  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2507  std::optional<ArrayRef<bool>> scalableSizes) {
2508  build(builder, result, target,
2509  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2510  interchange, scalableSizes);
2511 }
2513 void transform::TileUsingForOp::build(
2514  OpBuilder &builder, OperationState &result, Value target,
2515  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2516  std::optional<ArrayRef<bool>> scalableSizes) {
2517  // Loop types are automaticaly splat by the callee, setting up one is
2518  // enough.
2519  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2520  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2521  scalableSizes);
2522 }
2524 void transform::TileUsingForOp::build(
2525  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2526  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2527  ArrayRef<int64_t> interchange,
2528  std::optional<ArrayRef<bool>> scalableSizes) {
2529  SmallVector<int64_t> staticTileSizes;
2530  SmallVector<Value> dynamicTileSizes;
2531  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2532  // Call the default builder which sets up the proper operands segment sizes
2533  // attributes for multiple variadic operands. In the absence of this,
2534  // horrible bugs ensue.
2535  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2536  unsigned numExpectedLoops =
2537  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2538  SmallVector<Type> resultTypes;
2539  resultTypes.reserve(numExpectedLoops);
2540  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2541  "expected one loop type or as many as loops");
2542  if (loopTypes.size() == 1)
2543  resultTypes.append(numExpectedLoops, loopTypes[0]);
2544  else
2545  llvm::append_range(resultTypes, loopTypes);
2546  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2547  if (scalableSizes.has_value())
2548  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2549  build(builder, result, /*tiled_linalg_op=*/target.getType(),
2550  /*loops=*/resultTypes,
2551  /*target=*/target,
2552  /*dynamic_sizes=*/dynamicTileSizes,
2553  /*static_sizes=*/staticTileSizesAttr,
2554  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2555  /*scalable_sizes=*/expandedScalableSizes);
2556 }
2559  if (getMixedSizes().size() != getScalableSizes().size())
2560  return emitOpError("expected same number of sizes (")
2561  << getMixedSizes().size() << ") and scalable sizes ()"
2562  << getScalableSizes().size() << ")";
2563  return success();
2564 }
2567 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2568  TransformResults &transformResults,
2569  TransformState &state) {
2570  ArrayRef<int64_t> tileSizes = getStaticSizes();
2572  SmallVector<Operation *> targets =
2573  llvm::to_vector(state.getPayloadOps(getTarget()));
2574  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2576  dynamicSizeProducers.reserve(getDynamicSizes().size());
2577  paramSizes.reserve(getDynamicSizes().size());
2578  for (Value transformValue : getDynamicSizes()) {
2579  if (isa<ParamType>(transformValue.getType())) {
2580  dynamicSizeProducers.push_back({});
2581  ArrayRef<Attribute> params = state.getParams(transformValue);
2582  paramSizes.push_back(
2583  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
2584  return cast<IntegerAttr>(attr).getValue().getSExtValue();
2585  })));
2587  if (paramSizes.back().size() != targets.size()) {
2589  emitSilenceableError()
2590  << "expected as many parameter values ("
2591  << dynamicSizeProducers.back().size() << ") as target ops ("
2592  << targets.size() << ")";
2593  diag.attachNote(transformValue.getLoc()) << "for this parameter";
2594  return diag;
2595  }
2597  continue;
2598  }
2599  paramSizes.push_back({});
2600  dynamicSizeProducers.push_back(
2601  llvm::to_vector(state.getPayloadOps(transformValue)));
2603  if (dynamicSizeProducers.back().size() != targets.size()) {
2605  emitSilenceableError()
2606  << "expected as many dynamic size-producing operations ("
2607  << dynamicSizeProducers.back().size() << ") as target ops ("
2608  << targets.size() << ")";
2609  diag.attachNote(transformValue.getLoc()) << "for this handle";
2610  return diag;
2611  }
2613  for (Operation *op : dynamicSizeProducers.back()) {
2614  if (op->getNumResults() == 1 &&
2615  isa<IndexType>(op->getResult(0).getType())) {
2616  continue;
2617  }
2620  emitSilenceableError() << "expected sizes to be produced by ops "
2621  "with a single index-type result";
2622  diag.attachNote(op->getLoc()) << "size producer op";
2623  diag.attachNote(transformValue.getLoc()) << "for this handle";
2624  return diag;
2625  }
2626  }
2630  loops.resize(getLoops().size());
2631  auto scalableSizes = getScalableSizes();
2632  for (auto [i, op] : llvm::enumerate(targets)) {
2633  auto tilingInterface = dyn_cast<TilingInterface>(op);
2634  if (!tilingInterface) {
2636  emitSilenceableError()
2637  << "only ops implementing TilingInterface are supported";
2638  diag.attachNote(op->getLoc()) << "target op";
2639  return diag;
2640  }
2641  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
2643  emitSilenceableError()
2644  << "too many tiles provided, expected at most "
2645  << tilingInterface.getLoopIteratorTypes().size() << " found "
2646  << tileSizes.size();
2647  diag.attachNote(op->getLoc()) << "target op";
2648  return diag;
2649  }
2651  scf::SCFTilingOptions tilingOptions;
2652  if (tileSizes.empty()) {
2653  tilingOptions.setTileSizeComputationFunction(
2655  return {};
2656  });
2657  } else {
2658  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
2659  Operation *) {
2661  sizes.reserve(tileSizes.size());
2662  unsigned dynamicIdx = 0;
2664  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
2665  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
2666  if (scalableSizes[ofrIdx]) {
2667  auto val = b.create<arith::ConstantIndexOp>(
2668  getLoc(), attr.cast<IntegerAttr>().getInt());
2669  Value vscale =
2670  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
2671  sizes.push_back(
2672  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
2673  } else {
2674  sizes.push_back(attr);
2675  }
2676  continue;
2677  }
2678  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
2679  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
2680  ++dynamicIdx;
2681  assert((dynamicSizes.empty() ^ params.empty()) &&
2682  "expected either dynamic sizes or parameters");
2683  if (!params.empty()) {
2684  sizes.push_back(b.getIndexAttr(params[index]));
2685  } else {
2686  sizes.push_back(dynamicSizes[index]->getResult(0));
2687  }
2688  }
2689  return sizes;
2690  });
2691  }
2693  tilingOptions.setInterchange(getInterchange());
2694  FailureOr<scf::SCFTilingResult> maybeTilingResult =
2695  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
2696  if (failed(maybeTilingResult))
2699  rewriter.replaceOp(op, maybeTilingResult->replacements);
2701  tiled.append(maybeTilingResult->tiledOps);
2702  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
2703  loops[en2.index()].push_back(en2.value());
2704  }
2706  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
2707  for (const auto &en : llvm::enumerate(loops))
2708  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
2711 }
// Body of the mixed-sizes accessor (the signature line is not visible in this
// listing). Walks the static sizes in order: every `ShapedType::kDynamic`
// entry is replaced by the next dynamic-size operand, every other entry
// becomes an index attribute.
2714  ValueRange dynamic = getDynamicSizes();
2715  ArrayRef<int64_t> tileSizes = getStaticSizes();
2716  SmallVector<OpFoldResult> results;
2717  results.reserve(tileSizes.size());
 // Position of the next unconsumed dynamic operand.
2718  unsigned dynamicPos = 0;
2719  Builder builder(getContext());
2720  for (int64_t size : tileSizes) {
2721  if (size == ShapedType::kDynamic) {
2722  results.push_back(dynamic[dynamicPos++]);
2723  } else {
2724  results.push_back(builder.getIndexAttr(size));
2725  }
2726  }
2727  return results;
2728 }
// Parses an optional `interchange = <array>` clause into `result`. Succeeds
// without adding anything when the `interchange` keyword is absent; fails
// when the keyword is present but `=` is missing. The first signature line
// and part of the attribute-name call are not visible in this listing.
2730 // We want to parse `DenseI64ArrayAttr` using the short form without the
2731 // `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
2733  OperationState &result) {
2734  if (failed(parser.parseOptionalKeyword("interchange")))
2735  return success();
2736  if (failed(parser.parseEqual()))
2737  return failure();
 // NOTE(review): the attribute returned by `DenseI64ArrayAttr::parse` is
 // added without a visible null check here; confirm how a parse failure is
 // surfaced.
2738  result.addAttribute(
2739  transform::TileUsingForOp::getInterchangeAttrName(,
2740  DenseI64ArrayAttr::parse(parser, Type{}));
2741  return success();
2742 }
// Prints ` interchange = [a, b, ...]` when `interchangeVals` is non-empty and
// nothing otherwise. The first signature line is not visible in this listing.
2745  ArrayRef<int64_t> interchangeVals) {
2746  if (!interchangeVals.empty()) {
2747  p << " interchange = [";
2748  llvm::interleaveComma(interchangeVals, p,
2749  [&](int64_t integer) { p << integer; });
2750  p << "]";
2751  }
2752 }
// Custom parser: target operand, dynamic index list of sizes, optional
// interchange clause, optional attribute dictionary, then a functional type.
// Some declarations and the first signature line are not visible in this
// listing.
2755  OperationState &result) {
2758  DenseI64ArrayAttr staticSizes;
2759  FunctionType functionalType;
2760  llvm::SMLoc operandLoc;
2761  DenseBoolArrayAttr scalableVals;
2763  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
2764  parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
2765  parseOptionalInterchange(parser, result) ||
2766  parser.parseOptionalAttrDict(result.attributes) ||
2767  parser.parseColonType(functionalType))
2768  return ParseResult::failure();
 // One loop result is expected per non-zero static size entry, plus one
 // result for the tiled op.
2770  size_t numExpectedLoops =
2771  staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
2772  if (functionalType.getNumResults() != numExpectedLoops + 1) {
2773  return parser.emitError(parser.getNameLoc())
2774  << "expected " << (numExpectedLoops + 1) << " result type(s)";
2775  }
 // Input types: the target followed by one type per dynamic size operand.
2776  if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
2777  return parser.emitError(operandLoc)
2778  << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
2779  }
2780  if (parser.resolveOperand(target, functionalType.getInputs().front(),
2781  result.operands) ||
2782  parser.resolveOperands(dynamicSizes,
2783  functionalType.getInputs().drop_front(),
2784  operandLoc, result.operands)) {
2785  return failure();
2786  }
2788  result.addAttribute(getScalableSizesAttrName(, scalableVals);
2790  result.addAttribute(getStaticSizesAttrName(, staticSizes);
2791  result.addTypes(functionalType.getResults());
2792  return success();
2793 }
// Custom printer body (signature and a few lines are not visible in this
// listing): target, dynamic index list of sizes, optional interchange, the
// remaining attributes, then the functional type. The interchange, scalable
// sizes and static sizes attributes are elided from the attribute dictionary
// because they are printed by the custom syntax above.
2796  p << ' ' << getTarget();
2797  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
2798  /*valueTypes=*/{}, getScalableSizesAttr(),
2800  printOptionalInterchange(p, getInterchange());
2802  (*this)->getAttrs(),
2803  /*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()),
2804  getScalableSizesAttrName(getOperation()->getName()),
2805  getStaticSizesAttrName(getOperation()->getName())});
2806  p << " : ";
2807  p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
2808 }
// Declares the transform-dialect effects of the op: the target handle is
// consumed, the dynamic size operands are only read, both result groups are
// produced, and the payload IR is modified. The parameter line is not visible
// in this listing.
2810 void transform::TileUsingForOp::getEffects(
2812  consumesHandle(getTarget(), effects);
2813  onlyReadsHandle(getDynamicSizes(), effects);
2814  producesHandle(getTiledLinalgOp(), effects);
2815  producesHandle(getLoops(), effects);
2816  modifiesPayload(effects);
2817 }
2819 //===----------------------------------------------------------------------===//
2820 // TileUsingForallOp
2821 //===----------------------------------------------------------------------===//
// Builder taking static tile sizes only: converts them to OpFoldResults via an
// I64 array attribute and forwards to the mixed tile-size builder, selected
// with a `TileSizesSpec()` tag. One parameter line is not visible in this
// listing.
2823 void transform::TileUsingForallOp::build(OpBuilder &builder,
2824  OperationState &result, Value target,
2825  ArrayRef<int64_t> staticTileSizes,
2827  ArrayAttr mapping) {
2828  return build(builder, result,
2829  /*target=*/target,
2830  /*mixedTileSizes=*/
2831  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2832  /*_=*/TileSizesSpec(),
2833  /*mapping=*/mapping);
2834 }
// Builder taking mixed (static or dynamic) tile sizes: splits them into
// static values and dynamic operands, then calls the default builder with no
// num_threads and no packed operands. Both results use `AnyOpType`. One
// parameter line is not visible in this listing.
2836 void transform::TileUsingForallOp::build(OpBuilder &builder,
2837  OperationState &result, Value target,
2838  ArrayRef<OpFoldResult> mixedTileSizes,
2840  ArrayAttr mapping) {
2841  SmallVector<int64_t> staticTileSizes;
2842  SmallVector<Value> dynamicTileSizes;
2843  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2844  // Call the default builder which sets up the proper operands segment sizes
2845  // attributes for multiple variadic operands. In the absence of this,
2846  // horrible bugs ensue.
2847  MLIRContext *ctx = builder.getContext();
2848  auto operationType = transform::AnyOpType::get(ctx);
2849  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2850  build(builder, result,
2851  /*resultTypes=*/TypeRange{operationType, operationType},
2852  /*target=*/target,
2853  /*num_threads=*/ValueRange{},
2854  /*tile_sizes=*/dynamicTileSizes,
2855  /*packed_num_threads=*/Value(),
2856  /*packed_tile_sizes=*/Value(),
2857  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
2858  /*static_tile_sizes=*/staticTileSizesAttr,
2859  /*mapping=*/mapping);
2860 }
// Builder taking static thread counts only: converts them to OpFoldResults via
// an I64 array attribute and forwards to the mixed num-threads builder,
// selected with a `NumThreadsSpec()` tag. One parameter line is not visible
// in this listing.
2862 void transform::TileUsingForallOp::build(OpBuilder &builder,
2863  OperationState &result, Value target,
2864  ArrayRef<int64_t> staticNumThreads,
2866  ArrayAttr mapping) {
2867  return build(builder, result, target,
2868  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
2869  NumThreadsSpec(), mapping);
2870 }
// Builder taking mixed (static or dynamic) thread counts: splits them into
// static values and dynamic operands, then calls the default builder with no
// tile sizes and no packed operands. Both results use `AnyOpType`. One
// parameter line is not visible in this listing.
2872 void transform::TileUsingForallOp::build(OpBuilder &builder,
2873  OperationState &result, Value target,
2874  ArrayRef<OpFoldResult> mixedNumThreads,
2876  ArrayAttr mapping) {
2877  SmallVector<int64_t> staticNumThreads;
2878  SmallVector<Value> dynamicNumThreads;
2879  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
2880  staticNumThreads);
2881  // Call the default builder which sets up the proper operands segment sizes
2882  // attributes for multiple variadic operands. In the absence of this,
2883  // horrible bugs ensue.
2884  MLIRContext *ctx = builder.getContext();
2885  auto operationType = transform::AnyOpType::get(ctx);
2886  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2887  build(builder, result,
2888  /*resultTypes=*/TypeRange{operationType, operationType},
2889  /*target=*/target,
2890  /*num_threads=*/dynamicNumThreads,
2891  /*tile_sizes=*/ValueRange{},
2892  /*packed_num_threads=*/Value(),
2893  /*packed_tile_sizes=*/Value(),
2894  /*static_num_threads=*/staticNumThreadsAttr,
2895  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
2896  /*mapping=*/mapping);
2897 }
// Tiles a single `target` to a forall op and stores the outcome in
// `tilingResult`. Uses `mixedNumThreads` when it is non-empty and
// `mixedTileSizes` otherwise. Emits a silenceable error on `transformOp` when
// the target does not implement TilingInterface or tiling fails. The first
// signature line and the final return are not visible in this listing.
2900  RewriterBase &rewriter, transform::TransformState &state,
2901  TransformOpInterface transformOp, Operation *target,
2902  ArrayRef<OpFoldResult> mixedNumThreads,
2903  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
2904  linalg::ForallTilingResult &tilingResult) {
2905  // Transform all targets one by one.
2906  auto tileableOp = dyn_cast<TilingInterface>(target);
2907  if (!tileableOp) {
2909  transformOp.emitSilenceableError()
2910  << "only TilingInterface ops are supported";
2911  diag.attachNote(target->getLoc()) << "target op";
2912  return diag;
2913  }
 // New IR is created at the position of the op being tiled.
2914  rewriter.setInsertionPoint(tileableOp);
2915  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
2916  if (!mixedNumThreads.empty()) {
2917  maybeTilingResult =
2918  linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
2919  } else {
2920  maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
2921  rewriter, tileableOp, mixedTileSizes, mapping);
2922  }
2924  if (failed(maybeTilingResult))
2925  return transformOp.emitDefaultSilenceableFailure(tileableOp);
 // The original op is replaced by the results of the generated tile op.
2926  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
2928  tilingResult = *maybeTilingResult;
2930 }
// Resolves the thread-count and tile-size operands (packed or mixed form) to
// OpFoldResults, tiles every payload op of the target handle through
// `tilingResult`, and maps the collected tile ops / tiled ops to the
// `forall_op` and `tiled_op` results. Several lines (the unpacking helper
// calls and the final return) are not visible in this listing.
2932 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
2933  transform::TransformRewriter &rewriter,
2934  transform::TransformResults &transformResults,
2935  transform::TransformState &state) {
2936  auto transformOp = cast<TransformOpInterface>(getOperation());
2938  // Result payload ops.
2939  SmallVector<Operation *> tileOps;
2940  SmallVector<Operation *> tiledOps;
2942  // Unpack handles.
2943  SmallVector<OpFoldResult> mixedNumThreads;
 // The packed operand, when present, selects the packed unpacking path;
 // otherwise the mixed static/dynamic list is used.
2945  getPackedNumThreads()
2947  state, transformOp, mixedNumThreads, getPackedNumThreads())
2949  state, transformOp, mixedNumThreads, getMixedNumThreads());
2950  if (!status.succeeded())
2951  return status;
2952  SmallVector<OpFoldResult> mixedTileSizes;
2953  status = getPackedTileSizes()
2955  state, transformOp, mixedTileSizes, getPackedTileSizes())
2957  state, transformOp, mixedTileSizes, getMixedTileSizes());
2958  if (!status.succeeded())
2959  return status;
 // Tile each target; stop at the first failure.
2961  for (Operation *target : state.getPayloadOps(getTarget())) {
2962  linalg::ForallTilingResult tilingResult;
2964  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
2965  getMapping(), tilingResult);
2966  if (!diag.succeeded())
2967  return diag;
2968  tileOps.push_back(tilingResult.tileOp);
2969  tiledOps.push_back(tilingResult.tiledOp);
2970  }
2972  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
2973  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
2976 }
// Declares the transform-dialect effects of the op: the target handle is
// consumed, all size / thread-count operands (plain and packed) are only
// read, every result is produced, and the payload IR is modified. The
// parameter line is not visible in this listing.
2978 void transform::TileUsingForallOp::getEffects(
2980  consumesHandle(getTarget(), effects);
2981  onlyReadsHandle(getTileSizes(), effects);
2982  onlyReadsHandle(getNumThreads(), effects);
2983  onlyReadsHandle(getPackedNumThreads(), effects);
2984  onlyReadsHandle(getPackedTileSizes(), effects);
2985  producesHandle(getResults(), effects);
2986  modifiesPayload(effects);
2987 }
2989 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
2990  Builder b(getContext());
2991  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
2992 }
2994 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
2995  Builder b(getContext());
2996  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
2997 }
// Verifier body (the signature line is not visible in this listing). Counts
// how many ways each quantity is specified: the mixed list being non-empty
// and the packed operand being present each count as one. At most one form is
// allowed per quantity, and at least one of thread counts / tile sizes must
// be given.
3000  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3001  static_cast<int>(getPackedNumThreads() != Value());
3002  if (numThreadsSpec > 1)
3003  return emitOpError(
3004  "num_threads and packed_num_threads are mutually exclusive");
3005  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3006  static_cast<int>(getPackedTileSizes() != Value());
3007  if (tileSizesSpec > 1)
3008  return emitOpError(
3009  "tile_sizes and packed_tile_sizes are mutually exclusive");
3010  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3011  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3012  "must be specified");
3013  return success();
3014 }
3016 //===----------------------------------------------------------------------===//
3017 // VectorizeChildrenAndApplyPatternsOp
3018 //===----------------------------------------------------------------------===//
// Builds the op with `target` as its only operand and a single `AnyOpType`
// result. Each boolean flag that is set adds the corresponding unit
// attribute; unset flags add nothing. The argument lines of the
// attribute-name calls are not visible in this listing.
3020 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3021  OpBuilder &builder, OperationState &result, Value target,
3022  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3023  result.addOperands(target);
3024  if (vectorizePadding) {
3025  result.addAttribute(
3026  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3028  builder.getUnitAttr());
3029  }
3030  if (vectorizeExtract) {
3031  result.addAttribute(
3032  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3034  builder.getUnitAttr());
3035  }
3036  if (flatten1DDepthwiseConv) {
3037  result.addAttribute(
3038  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3040  builder.getUnitAttr());
3041  }
3042  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3043 }
3045 namespace {
3046 /// This is an helper only to call vectorize via a pattern inside of
3047 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3048 struct VectorizationPattern : public RewritePattern {
3049  explicit VectorizationPattern(MLIRContext *context,
3050  bool vectorizeExtract = false,
3051  bool flattenConv = false)
3052  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3053  vectorizeNDExtract(vectorizeExtract),
3054  flatten1DDepthwiseConv(flattenConv) {}
3055  LogicalResult matchAndRewrite(Operation *op,
3056  PatternRewriter &rewriter) const override {
3057  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
3058  if (!linalgOp)
3059  return rewriter.notifyMatchFailure(op, "expected Linalg Op");
3060  return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
3061  /*scalableVecDims=*/{}, vectorizeNDExtract,
3062  flatten1DDepthwiseConv);
3063  }
3065 private:
3066  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3067  /// rank >= 2.
3068  bool vectorizeNDExtract = false;
3069  /// Controls whether to "flatten" the channel dimension when vectorising 1D
3070  /// depthwise convolutions. This should lead to bette vectorization for
3071  /// tensors with a low number of channel dimensions.
3072  bool flatten1DDepthwiseConv = false;
3073 };
3074 } // namespace
// Vectorizes the Linalg ops nested under `target` by greedily applying a
// pattern set that contains VectorizationPattern plus vector transfer
// canonicalization and copy vectorization patterns, then pushes `target`
// itself to `results`. `target` must be isolated from above. Several lines
// (return type, optional pattern registrations, returns) are not visible in
// this listing.
3077 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3078  transform::TransformRewriter &rewriter, Operation *target,
3080  transform::TransformState &state) {
3081  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3082  auto diag = this->emitOpError("requires isolated-from-above targets");
3083  diag.attachNote(target->getLoc()) << "non-isolated target";
3085  }
3087  MLIRContext *ctx = getContext();
3088  RewritePatternSet patterns(ctx);
3089  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3090  getFlatten_1dDepthwiseConv());
 // The statements guarded by the two conditions below are not visible here.
3092  if (!getDisableTransferPermutationMapLoweringPatterns())
3095  if (!getDisableMultiReductionToContractPatterns())
3102  /*benefit=*/2);
3103  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3104  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3107  patterns.add<CopyVectorizationPattern>(ctx);
3109  if (getVectorizePadding())
 // The rewrite driver is configured with a TrackingListener built from the
 // transform state and this op.
3112  TrackingListener listener(state, *this);
3113  GreedyRewriteConfig config;
3114  config.listener = &listener;
3115  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
3116  return emitDefaultDefiniteFailure(target);
3118  results.push_back(target);
3120 }
3122 //===----------------------------------------------------------------------===//
3123 // VectorizeOp
3124 //===----------------------------------------------------------------------===//
// Keyword that VectorizeOp's custom printer emits before the vector size
// list.
3126 static const StringLiteral kVectorSizesKeyword = "vector_sizes";
// Custom parser: target operand, a dynamic index list of vector sizes, an
// optional keyword that sets the vectorize-nd-extract unit attribute, an
// optional attribute dictionary, then a colon type list with one type for the
// target plus one per dynamic size. Some declarations and conditions are not
// visible in this listing.
3129  OperationState &result) {
3132  DenseI64ArrayAttr staticSizes;
3133  SmallVector<Type> operandTypes;
3134  llvm::SMLoc operandLoc;
3135  DenseBoolArrayAttr scalableVals;
3137  if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
3138  return ParseResult::failure();
3141  if (failed(parseDynamicIndexList(parser, dynamicSizes, staticSizes,
3142  scalableVals)))
3143  return ParseResult::failure();
3144  }
3146  if (succeeded(parser.parseOptionalKeyword(
3147  getVectorizeNdExtractAttrName(
3148  result.addAttribute(getVectorizeNdExtractAttrName(,
3149  parser.getBuilder().getUnitAttr());
3151  if (parser.parseOptionalAttrDict(result.attributes) ||
3152  parser.parseColonTypeList(operandTypes))
3153  return ParseResult::failure();
3155  if (operandTypes.size() != dynamicSizes.size() + 1) {
3156  return parser.emitError(operandLoc)
3157  << "expected " << dynamicSizes.size() + 1 << " operand type(s)";
3158  }
3159  if (parser.resolveOperand(target, operandTypes.front(), result.operands) ||
3160  parser.resolveOperands(dynamicSizes, ArrayRef(operandTypes).drop_front(),
3161  operandLoc, result.operands)) {
3162  return failure();
3163  }
 // The size attributes are only attached when they are non-null.
3165  if (scalableVals)
3166  result.addAttribute(getScalableSizesAttrName(, scalableVals);
3167  if (staticSizes)
3168  result.addAttribute(getStaticVectorSizesAttrName(, staticSizes);
3170  return success();
3171 }
// Custom printer body (signature and a few lines are not visible in this
// listing): target, then the `vector_sizes` keyword and size list when any
// sizes exist, the vectorize-nd-extract attribute name when it is set, the
// remaining attributes, and finally the target type followed by the types of
// the dynamic size operands.
3174  p << ' ' << getTarget() << ' ';
3175  if (!getMixedVectorSizes().empty()) {
3176  p << kVectorSizesKeyword << ' ';
3177  printDynamicIndexList(p, getOperation(), getVectorSizes(),
3178  getStaticVectorSizesAttr(),
3179  /*valueTypes=*/{}, getScalableSizesAttr(),
3181  }
3183  if (getVectorizeNdExtract())
3184  p << getVectorizeNdExtractAttrName() << ' ';
 // Scalable and static size attributes are elided from the dictionary; they
 // are printed through the index list above.
3187  (*this)->getAttrs(),
3188  /*elidedAttrs=*/{
3189  getScalableSizesAttrName(getOperation()->getName()),
3190  getStaticVectorSizesAttrName(getOperation()->getName())});
3191  p << " : ";
3192  p << getTarget().getType();
3193  if (!getVectorSizes().empty()) {
3194  p << ", ";
3195  llvm::interleaveComma(getVectorSizes(), p,
3196  [&](Value operand) { p << operand.getType(); });
3197  }
3198 }
// Resolves every mixed vector size to an integer and then vectorizes each
// payload op of the target handle with `linalg::vectorize`. A size may be a
// static attribute, a single-valued parameter, or a handle mapped to exactly
// one payload op with a single constant index result. Several lines (state
// parameter, early returns, parts of two conditions) are not visible in this
// listing.
3200 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3201  transform::TransformRewriter &rewriter,
3202  mlir::transform::TransformResults &transformResults,
3204  auto targets = state.getPayloadOps(getTarget());
3205  if (std::empty(targets))
3208  SmallVector<int64_t> vectorSizes;
3209  for (OpFoldResult sz : getMixedVectorSizes()) {
 // Case 1: static size stored as an integer attribute.
3210  if (<Attribute>()) {
3211  auto attr = sz.get<Attribute>();
3212  vectorSizes.push_back(cast<IntegerAttr>(attr).getInt());
3213  continue;
 // Case 2: transform parameter; exactly one value is required.
3214  } else if (<Value>() && isa<ParamType>(sz.get<Value>().getType())) {
3215  ArrayRef<Attribute> params = state.getParams(sz.get<Value>());
3216  if (params.size() != 1)
3217  return emitSilenceableFailure(getLoc()) << "expected a single param";
3218  vectorSizes.push_back(
3219  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
3220  continue;
3221  }
 // Case 3: handle to a payload op producing the size.
3223  auto szPayloads = state.getPayloadOps(sz.get<Value>());
3224  if (!llvm::hasSingleElement(szPayloads)) {
3225  auto diag = this->emitOpError(
3226  "requires vector size handle that is mapped to 1 payload op");
3227  diag.attachNote(sz.get<Value>().getLoc())
3228  << "mapped to " << llvm::range_size(szPayloads) << " payload ops";
3230  }
3232  Operation *szPayloadOp = *szPayloads.begin();
3233  if (szPayloadOp->getNumResults() != 1 ||
3234  !szPayloadOp->getResult(0).getType().isIndex()) {
3235  auto diag = this->emitOpError(
3236  "requires vector size payload op with 1 index result");
3237  diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
3239  }
 // The payload result must match a constant so a static size can be read.
3241  IntegerAttr attr;
3242  if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
3243  auto diag = this->emitOpError("requires constant vector size");
3244  diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
3246  }
3248  vectorSizes.push_back(attr.getInt());
3249  }
3251  // TODO: Check that the correct number of vectorSizes was provided.
3252  for (Operation *target : targets) {
3253  if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
3254  target)) {
3255  return mlir::emitSilenceableFailure(target->getLoc())
3256  << "Unsupported Op, cannot vectorize";
3257  }
 // An absent vectorize-nd-extract value is treated as false.
3259  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3260  getScalableSizes(),
3261  getVectorizeNdExtract().has_value()
3262  ? getVectorizeNdExtract().value()
3263  : false))) {
3264  return mlir::emitSilenceableFailure(target->getLoc())
3265  << "Attempted to vectorize, but failed";
3266  }
3267  }
3270 }
// Declares the transform-dialect effects of the op: the target handle is
// consumed, the vector size operands are only read, and the payload IR is
// modified. The parameter line is not visible in this listing.
3272 void transform::VectorizeOp::getEffects(
3274  consumesHandle(getTarget(), effects);
3275  onlyReadsHandle(getVectorSizes(), effects);
3276  modifiesPayload(effects);
3277 }
3279 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3280  OpBuilder b(getContext());
3281  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3282 }
// Verifier body (the signature line is not visible in this listing): the
// number of static vector size entries must equal the number of scalable
// flags.
3285  if (getStaticVectorSizes().size() != getScalableSizes().size())
3286  return emitOpError("expected same number of vector sizes (")
3287  << getStaticVectorSizes().size() << ") and scalable sizes ("
3288  << getScalableSizes().size() << ")";
3289  return success();
3290 }
3292 //===----------------------------------------------------------------------===//
3293 // HoistRedundantVectorTransfersOp
3294 //===----------------------------------------------------------------------===//
// Applies the op to one `func::FuncOp` and pushes that same function to
// `results`. The return type, the statement that performs the hoisting and
// the final return are not visible in this listing.
3297 transform::HoistRedundantVectorTransfersOp::applyToOne(
3298  transform::TransformRewriter &rewriter, func::FuncOp target,
3300  transform::TransformState &state) {
3301  // WARNING: This hoisting does not model parallelism and is generally
3302  // incorrect when used on distributed loops with memref semantics!
3303  // TODO: obsolete and should be retired.
3305  results.push_back(target);
3307 }
3309 //===----------------------------------------------------------------------===//
3310 // ConvertConv2DToImg2ColOp.
3311 //===----------------------------------------------------------------------===//
// Rewrites one supported 2-D convolution through `rewriteInIm2Col` and pushes
// two ops to `results`: the first and second members of the returned pair.
// Any other op, or a failed rewrite, yields the default silenceable failure.
// The results parameter, the type-switch header and the final return are not
// visible in this listing.
3313 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3314  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3316  transform::TransformState &state) {
3317  rewriter.setInsertionPoint(target);
 // Dispatch on the concrete convolution op; each supported kind calls the
 // matching `rewriteInIm2Col` overload.
3318  auto maybeTransformed =
3320  target)
3321  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3322  return rewriteInIm2Col(rewriter, op);
3323  })
3324  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3325  return rewriteInIm2Col(rewriter, op);
3326  })
3327  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3328  return rewriteInIm2Col(rewriter, op);
3329  })
3330  .Case([&](linalg::Conv2DNchwFchwOp op) {
3331  return rewriteInIm2Col(rewriter, op);
3332  })
3333  .Default([&](Operation *op) {
3334  return rewriter.notifyMatchFailure(op, "not supported");
3335  });
3336  if (failed(maybeTransformed))
3337  return emitDefaultSilenceableFailure(target);
3338  // Handle to the operation producing the img2col tensor.
3339  results.push_back(maybeTransformed->first);
3340  // Handle to the operation that replaces the original convolution.
3341  results.push_back(maybeTransformed->second);
3343 }
3345 //===----------------------------------------------------------------------===//
3346 // FlattenElementwiseLinalgOp.
3347 //===----------------------------------------------------------------------===//
// Collapses all iteration dimensions of an elementwise Linalg op into one.
// Non-elementwise ops and failed collapses produce a silenceable failure; ops
// with at most one loop are pushed to `results` unchanged. On success the
// collapsed op is pushed to `results` and the original op is replaced. The
// results parameter and the return statements are not visible in this
// listing.
3349 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3350  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3352  transform::TransformState &state) {
3353  rewriter.setInsertionPoint(target);
3354  if (!isElementwise(target))
3355  return mlir::emitSilenceableFailure(target->getLoc())
3356  << "only elementwise flattening is supported";
3358  // If rank <= 1, do nothing
3359  if (target.getNumLoops() <= 1) {
3360  results.push_back(target);
3362  }
3364  // Attempt to flatten all dims to one.
 // The single reassociation group lists every loop index: 0, 1, ..., N-1.
3365  ReassociationIndices reassociation(target.getNumLoops());
3366  std::iota(reassociation.begin(), reassociation.end(), 0);
3367  auto maybeFlattened =
3368  collapseOpIterationDims(target, reassociation, rewriter);
3369  if (failed(maybeFlattened))
3370  return mlir::emitSilenceableFailure(target->getLoc())
3371  << "attempted to flatten, but failed";
3372  results.push_back(maybeFlattened->collapsedOp);
3373  rewriter.replaceOp(target, maybeFlattened->results);
3375 }
3377 //===----------------------------------------------------------------------===//
3378 // TransposeConv2DOp
3379 //===----------------------------------------------------------------------===//
// Rewrites a `Conv2DNhwcFhwcOp` or `Conv2DNhwcFhwcQOp` through
// `transposeConv2D` and pushes the resulting op to `results`. Any other op,
// or a failed rewrite, yields the default silenceable failure. The results
// parameter, the type-switch header and the final return are not visible in
// this listing.
3381 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3382  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3384  transform::TransformState &state) {
3385  rewriter.setInsertionPoint(target);
3386  auto maybeTransformed =
3388  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3389  return transposeConv2D(rewriter, op);
3390  })
3391  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3392  return transposeConv2D(rewriter, op);
3393  })
3394  .Default([&](Operation *op) {
3395  return rewriter.notifyMatchFailure(op, "not supported");
3396  });
3397  if (failed(maybeTransformed))
3398  return emitDefaultSilenceableFailure(target);
3399  // Handle to the new Conv2D operation with transposed filters
3400  results.push_back(*maybeTransformed);
3402 }
3404 //===----------------------------------------------------------------------===//
3405 // InsertSliceToCopyOp
3406 //===----------------------------------------------------------------------===//
// Shared implementation for tensor.insert_slice and
// tensor.parallel_insert_slice targets. If the inserted source is already
// produced by a linalg.copy, that copy is pushed to `results` as-is.
// Otherwise the destination slice is extracted, the source is copied into it
// with linalg.copy, the target is recreated with the copy result as its
// source, and the new copy op is pushed to `results`. The signature lines and
// the return statements are not visible in this listing.
3407 template <typename OpTy>
3410  transform::TransformState &state) {
 // NOTE(review): the message is joined with `&&` rather than passed as the
 // second argument of static_assert, so it is part of the asserted expression.
3411  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3412  tensor::ParallelInsertSliceOp>() &&
3413  "wrong op type");
3415  if (auto copySource =
3416  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3417  results.push_back(copySource);
3419  }
3421  // If we are inside an InParallel region, temporarily set the insertion point
3422  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3423  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3424  rewriter.setInsertionPoint(
3425  target->template getParentOfType<scf::InParallelOp>());
3426  }
 // Extract the destination region using the target's own offsets, sizes and
 // strides, then copy the source into it.
3428  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3429  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3430  target.getMixedSizes(), target.getMixedStrides());
3431  Value copied = rewriter
3432  .create<linalg::CopyOp>(target.getLoc(),
3433  target.getSource(), extracted)
3434  .getResult(0);
3435  // Reset the insertion point.
3436  rewriter.setInsertionPoint(target);
3437  rewriter.replaceOpWithNewOp<OpTy>(
3438  target, copied, target.getDest(), target.getMixedOffsets(),
3439  target.getMixedSizes(), target.getMixedStrides());
3441  results.push_back(copied.getDefiningOp());
3443 }
// Dispatches to `doit` for tensor.insert_slice and tensor.parallel_insert_slice
// targets and emits a silenceable error for every other op. The results
// parameter and the declaration of `diag` are not visible in this listing.
3445 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3446  transform::TransformRewriter &rewriter, Operation *targetOp,
3448  transform::TransformState &state) {
3450  rewriter.setInsertionPoint(targetOp);
3451  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3452  return doit(rewriter, target, results, state);
3453  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3454  return doit(rewriter, target, results, state);
3457  emitSilenceableError()
3458  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3459  diag.attachNote(targetOp->getLoc()) << "target op";
3460  return diag;
3461 }
3463 //===----------------------------------------------------------------------===//
3464 // MapCopyToThreadsOp
3465 //===----------------------------------------------------------------------===//
// Maps a linalg.copy or tensor.pad with a statically shaped result onto
// threads: computes a `gpu::CopyMappingInfo` from the result shape, the total
// thread count and the bit alignment, tiles the op using the computed
// per-dimension thread counts and thread mapping, and pushes the resulting
// tile op and tiled op to `results`. Several lines (results parameter, `diag`
// declarations, the tiling helper call and the final return) are not visible
// in this listing.
3467 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3468  transform::TransformRewriter &rewriter, Operation *target,
3470  transform::TransformState &state) {
3471  // Check if the op is supported.
3472  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3474  emitSilenceableError()
3475  << "only linalg.copy and tensor.pad target ops are supported";
3476  diag.attachNote(target->getLoc()) << "target op";
3477  return diag;
3478  }
3479  assert(target->getNumResults() == 1 && "expected single result");
3480  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
 // NOTE(review): the message mentions "rank <= 3" but only the static shape
 // is checked here; confirm whether the rank is validated elsewhere.
3481  if (!resultShapedType.hasStaticShape()) {
3483  emitSilenceableError()
3484  << "only statically sized ops of rank <= 3 are supported";
3485  diag.attachNote(target->getLoc()) << "target op";
3486  return diag;
3487  }
3489  // Conservatively set the minimum viable desired bitwidth alignment.
 // Fall back to the element bitwidth when the requested alignment is not a
 // multiple of it.
3490  int64_t desiredBitAlignment = getDesiredBitAlignment();
3491  int64_t eltBitwidth =
3492  resultShapedType.getElementType().getIntOrFloatBitWidth();
3493  if (desiredBitAlignment % eltBitwidth != 0) {
3494  desiredBitAlignment = eltBitwidth;
3495  }
3497  gpu::CopyMappingInfo mapping(
3498  /*ctx=*/getContext(),
3499  /*totalNumThreads=*/getTotalNumThreads(),
3500  /*alignment=*/desiredBitAlignment,
3501  /*sizes=*/resultShapedType.getShape(),
3502  /*favorPredication=*/false,
3503  /*elementalBitwidth=*/
3504  resultShapedType.getElementType().getIntOrFloatBitWidth());
3505  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3507  emitSilenceableError()
3508  << "too few threads to map copy op to threads on the most minor "
3509  "dimension, given alignment and vector size constraints, try "
3510  "smaller tile size of mapping to more threads";
3511  diag.attachNote(target->getLoc()) << "target op";
3512  return diag;
3513  }
3515  // OpBuilder only used to compute attributes.
3516  OpBuilder b(getContext());
3517  linalg::ForallTilingResult tilingResult;
 // Tiling is driven by thread counts only; the tile-size list is empty.
3519  /*rewriter=*/rewriter,
3520  /*state=*/state,
3521  /*transformOp=*/*this,
3522  /*target=*/target,
3523  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3524  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3525  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3526  /*tilingResult=*/tilingResult);
3527  if (!diag.succeeded())
3528  return diag;
3530  results.push_back(tilingResult.tileOp);
3531  results.push_back(tilingResult.tiledOp);
3533 }
3535 #include "mlir/Dialect/Linalg/TransformOps/"
3537 #define GET_OP_CLASSES
3538 #include "mlir/Dialect/Linalg/TransformOps/"
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
void printOptionalInterchange(OpAsmPrinter &p, ArrayRef< int64_t > interchangeVals)
#define DOWNSCALE(trans)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static const StringLiteral kVectorSizesKeyword
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
ParseResult parseOptionalInterchange(OpAsmParser &parser, OperationState &result)
#define DBGS()
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static FailureOr< ForallTilingResult > tileToForallOpImpl(RewriterBase &b, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayRef< OpFoldResult >> nominalTileSizes, std::optional< ArrayAttr > mapping, bool omitTileOffsetBoundsCheck)
Rewrite a TilingInterface op to a tiled scf.forall.
Definition: Tiling.cpp:343
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U cast() const
Definition: Attributes.h:189
This class represents an argument of a Block.
Definition: Value.h:315
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:375
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:93
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:313
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:156
A class for computing basic dominance information.
Definition: Dominance.h:136
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:156
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:93
This class represents a saved insertion point.
Definition: Builders.h:329
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:339
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:553
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:318
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:340
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
user_range getUsers() const
Definition: Value.h:224
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
State for analysis-enabled bufferization.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1235
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1138
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
Definition: Padding.cpp:153
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
FailureOr< ForallTilingResult > tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, std::optional< ArrayAttr > mapping)
Same as tileToForallOp, but calculate the number of threads required using the given tileSizes.
Definition: Tiling.cpp:467
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite unpack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:355
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:470
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:24
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:911
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:511
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:495
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< ForallTilingResult > tileToForallOp(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayAttr > mapping)
Definition: Tiling.cpp:458
void hoistRedundantVectorTransfers(Operation *root)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:76
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:486
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< GenericOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:399
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
Definition: Promotion.cpp:503
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:50
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:111
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition: Tiling.cpp:647
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:777
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:488
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:219
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:421
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:443
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
Definition: Promotion.cpp:479
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:686
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:137
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:105
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, linalg::ForallTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant vector broadcasts.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR attribute into the given MLIR context if it is valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, TypeRange valueTypes=TypeRange(), ArrayRef< bool > scalables={}, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hook for custom directive in assemblyFormat.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hook for custom directive in assemblyFormat.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static LogicalResult failure(bool isFailure=true)
If isFailure is true a failure result is generated, otherwise a 'success' result is generated.
Definition: LogicalResult.h:36
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:287
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:463
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:464
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:474
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1331
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1277
Rewrite a TilingInterface op to a tiled scf.forall, applying tiling by numThreads.
Definition: Transforms.h:857
Match and rewrite for the pattern:
Definition: Transforms.h:1404
Match and rewrite for the pattern:
Definition: Transforms.h:1432
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:380
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:386
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:399
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:419
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:369
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:393
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:409
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:358
Split Reduction options.
Definition: Transforms.h:428
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > ts)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.