// MLIR 21.0.0git — Doxygen source listing of LinalgTransformOps.cpp.
// (Page furniture from the documentation generator, kept as a comment.)
1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
41 #include "mlir/Support/LLVM.h"
42 #include "mlir/Support/TypeID.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Support/LogicalResult.h"
49 #include <type_traits>
50 
51 using namespace mlir;
52 using namespace mlir::linalg;
53 using namespace mlir::transform;
54 
55 #define DEBUG_TYPE "linalg-transforms"
56 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
57 #define DBGSNL() (llvm::dbgs() << "\n")
58 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
59 
60 /// Attempts to apply the pattern specified as template argument to the given
61 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
62 /// function that returns the "main" result or failure. Returns failure if the
63 /// pattern failed to apply. Extra arguments are forwarded to the pattern
64 /// constructor.
65 template <typename PatternTy, typename... Args>
66 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
67  // Check if the given operation has the type expected by the pattern.
68  using OpTy = typename llvm::function_traits<
69  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
70  auto op = dyn_cast<OpTy>(operation);
71  if (!op)
72  return failure();
73 
74  // Apply the pattern directly to the op.
75  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
76  // We want to discourage direct use of PatternRewriter in APIs but In this
77  // very specific case, an IRRewriter is not enough.
78  struct TrivialPatternRewriter : public PatternRewriter {
79  public:
80  explicit TrivialPatternRewriter(MLIRContext *context)
81  : PatternRewriter(context) {}
82  };
83  TrivialPatternRewriter rewriter(operation->getContext());
84  rewriter.setInsertionPoint(operation);
85  auto result = pattern.returningMatchAndRewrite(op, rewriter);
86  if (failed(result))
87  return failure();
88  return cast<LinalgOp>(result->getOperation());
89 }
90 
91 /// Assuming that `ofr` is an index attr or a param of index type
92 /// or a transform dialect handle mapped to exactly one op
93 /// with one index result, return that value.
// NOTE(review): this extracted listing elides several original lines of this
// function (the signature line introducing `ofrs` and the out-vector
// `result`, the `diag` declarations ahead of emitSilenceableError, and the
// trailing success return). Comments below annotate only the visible code.
95  transform::TransformState &state, TransformOpInterface transformOp,
97  for (OpFoldResult ofr : ofrs) {
// Case 1: a static attribute. It must be an IntegerAttr and is forwarded
// unchanged into `result`.
98  if (auto attr = dyn_cast<Attribute>(ofr)) {
99  if (!isa<IntegerAttr>(attr))
100  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
101  result.push_back(ofr);
102  continue;
103  }
104 
// Case 2: a transform dialect param. Exactly one attribute must be
// associated with it; that attribute is forwarded.
105  Value transformValue = cast<Value>(ofr);
106  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
107  ArrayRef<Attribute> params = state.getParams(transformValue);
108  if (params.size() != 1)
109  return transformOp.emitDefiniteFailure()
110  << "requires exactly one parameter associated";
111  result.push_back(params[0]);
112  continue;
113  }
114 
// Case 3: a transform handle. It must map to exactly one payload op ...
115  auto payloadOps = state.getPayloadOps(transformValue);
116  if (!llvm::hasSingleElement(payloadOps)) {
118  transformOp.emitSilenceableError()
119  << "handle must be mapped to exactly one payload op";
120  diag.attachNote(transformValue.getLoc())
121  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
122  return diag;
123  }
124 
// ... and that payload op must produce a single index-typed result, which is
// what gets forwarded into `result`.
125  Operation *op = *payloadOps.begin();
126  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
128  transformOp.emitSilenceableError()
129  << "payload op must have exactly 1 index result";
130  diag.attachNote(op->getLoc())
131  << "has " << op->getNumResults() << " results";
132  return diag;
133  }
134  result.push_back(op->getResult(0));
135  }
136 
138 }
139 
140 // Given a list of params that are index attrs or a list of OpFoldResults
141 // that are either index attrs or op handles, return a list of OpFoldResults
142 // of index attrs or a list of OpFoldResults where all op handles are
143 // replaced with the first (and only) OpResult of that payload op.
144 // (There must be exactly one parameter associated with the AnyParamType or
145 // one mapped payload op which must have exactly one index result.)
// NOTE(review): this extracted listing elides the signature line and two
// return/diag lines of this function; annotations below cover only what is
// visible.
147  transform::TransformState &state, TransformOpInterface transformOp,
148  SmallVector<OpFoldResult> &result, Value packedHandle) {
// Param case: every associated attribute must be an IntegerAttr; each one is
// forwarded into `result`. (An early success return appears elided after the
// closing brace of this branch.)
149  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
150  ArrayRef<Attribute> params = state.getParams(packedHandle);
151  for (auto param : params) {
152  if (!isa<IntegerAttr>(param))
153  return transformOp.emitDefiniteFailure()
154  << "expected the parameter to be associated with an integer "
155  "attribute";
156  result.push_back(param);
157  }
159  }
160 
// Handle case: every mapped payload op must have exactly one index-typed
// result, which is forwarded into `result`.
161  for (Operation *op : state.getPayloadOps(packedHandle)) {
162  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
164  transformOp.emitSilenceableError()
165  << "payload op must have exactly 1 index result";
166  diag.attachNote(op->getLoc())
167  << "has " << op->getNumResults() << " results";
168  return diag;
169  }
170  result.push_back(op->getResult(0));
171  }
172 
174 }
175 
176 /// When possible, converts each `OpFoldResult` in `mixedResult` to
177 /// an integer if the value can be statically inferred. If a result
178 /// is a `Value` then it must be either a `ParamType` or a handle
179 /// to an a constant like op.
// NOTE(review): the signature line and the final return are elided in this
// extracted listing; the reified integers accumulate into `reified`.
181  TransformState &state, TransformOpInterface &transformOp,
182  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
183  for (OpFoldResult paramOrHandle : mixedResults) {
// Case 1: a static attribute — extract the integer directly.
184  if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
185  reified.push_back(cast<IntegerAttr>(attr).getInt());
186  continue;
// Case 2: a transform param — it must carry exactly one IntegerAttr.
187  } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
188  ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
189  if (params.size() != 1)
190  return transformOp.emitSilenceableError() << "expected a single param";
191  reified.push_back(
192  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
193  continue;
194  }
195 
// Case 3: an op handle — it must map to exactly one payload op that has a
// single index-typed result produced by a constant-like op.
196  Value handle = cast<Value>(paramOrHandle);
197  if (!isa<TransformHandleTypeInterface>(handle.getType()))
198  return transformOp.emitSilenceableError() << "unexpected value handle";
199  auto payload = state.getPayloadOps(handle);
200  if (!llvm::hasSingleElement(payload))
201  return transformOp.emitSilenceableError()
202  << "requires param or handle that is mapped to 1 payload op";
203 
204  Operation *paramOrHandlePayloadOp = *payload.begin();
205  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
206  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
207  return transformOp.emitSilenceableError()
208  << "requires param or handle to be result of op with 1 index "
209  "result";
210  }
211 
// matchPattern + m_Constant extracts the constant's integer attribute.
212  IntegerAttr attr;
213  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
214  return transformOp.emitSilenceableError()
215  << "requires param or handle to be the result of a constant like "
216  "op";
217 
218  reified.push_back(attr.getInt());
219  }
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // Apply...PatternsOp
225 //===----------------------------------------------------------------------===//
226 
// NOTE(review): the parameter lists and bodies of these populatePatterns
// hooks are elided in this extracted listing — only the signatures' first
// lines (and one stray statement fragment) survive. Nothing about the
// elided bodies is asserted here; consult the full source before editing.
227 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
230 }
231 
232 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
235 }
236 
237 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
240 }
241 
242 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
246 }
247 
248 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
251  options.rankReductionStrategy =
254 }
255 
256 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
259 }
260 
261 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
264 }
265 
266 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
269 }
270 
271 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
274 }
275 
276 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // BufferizeToAllocationOp
283 //===----------------------------------------------------------------------===//
284 
285 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
286  OperationState &result,
287  Value target,
288  Attribute memorySpace) {
289  SmallVector<Type> resultTypes;
290  resultTypes.push_back(b.getType<transform::AnyValueType>());
291  resultTypes.push_back(b.getType<transform::AnyOpType>());
292  return build(b, result,
293  /*resultTypes=*/resultTypes,
294  /*target=*/target,
295  /*memorySpace=*/memorySpace);
296 }
297 
298 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
299  OperationState &result,
300  Value target,
301  int64_t memorySpace) {
302  SmallVector<Type> resultTypes;
303  resultTypes.push_back(b.getType<transform::AnyValueType>());
304  resultTypes.push_back(b.getType<transform::AnyOpType>());
305  return build(b, result,
306  /*resultTypes=*/resultTypes,
307  /*target=*/target,
308  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
309 }
310 
311 namespace {
// Rewriter listener that records every op created while it is attached and
// forgets ops that get erased again. Used to populate the `new_ops` result
// of BufferizeToAllocationOp.
// NOTE(review): one line (original 314) is elided in this extracted listing —
// presumably a constructor or inherited-constructor declaration; confirm
// against the full source.
312 class NewOpsListener : public RewriterBase::ForwardingListener {
313 public:
315 
// Returns the surviving newly created ops as a vector snapshot.
316  SmallVector<Operation *> getNewOps() const {
317  return SmallVector<Operation *>(newOps.begin(), newOps.end());
318  }
319 
320 private:
321  void notifyOperationInserted(Operation *op,
322  OpBuilder::InsertPoint previous) override {
323  ForwardingListener::notifyOperationInserted(op, previous);
324  // We only care about newly created ops.
325  if (previous.isSet())
326  return;
327  auto inserted = newOps.insert(op);
328  (void)inserted;
329  assert(inserted.second && "expected newly created op");
330  }
331 
// On erasure, drop the erased op and everything nested inside it from the
// tracked set so getNewOps() never returns dangling pointers.
332  void notifyOperationErased(Operation *op) override {
333  ForwardingListener::notifyOperationErased(op);
334  op->walk([&](Operation *op) { newOps.erase(op); });
335  }
336 
337  DenseSet<Operation *> newOps;
338 };
339 } // namespace
340 
// NOTE(review): this extracted listing elides the parameter lines, the
// `options` declaration, the right-hand sides of the memcpyOp/allocOp
// assignments, and the final success return. Annotations below cover only
// the visible code.
341 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
344  // Attach listener to keep track of newly created ops.
345  OpBuilder::Listener *previousListener = rewriter.getListener();
346  auto resetListener =
347  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
348  NewOpsListener newOpsListener(previousListener);
349  rewriter.setListener(&newOpsListener);
350 
// Translate the string attributes into bufferization options. The verifier
// guarantees only the three memcpy and two alloc spellings below can occur,
// hence llvm_unreachable on the fallback branches.
352  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
355  } else if (getMemcpyOp() == "memref.copy") {
356  options.memcpyOp =
358  } else if (getMemcpyOp() == "linalg.copy") {
359  options.memcpyOp =
361  } else {
362  llvm_unreachable("invalid memcpy op");
363  }
364  if (getAllocOp() == "memref.alloc") {
365  options.allocOp =
367  } else if (getAllocOp() == "memref.alloca") {
368  options.allocOp =
370  } else {
371  llvm_unreachable("invalid alloc op");
372  }
373  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
374  options.emitDealloc = getEmitDealloc();
375 
376  // Bufferize ops.
377  Attribute memorySpace =
378  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
379  SmallVector<Value> allocatedBuffers;
380  for (Operation *op : state.getPayloadOps(getTarget())) {
381  Value buffer =
382  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
// A null buffer signals failure for this payload op; report it silenceably
// with a note pointing at the offending op.
383  if (!buffer) {
384  DiagnosedSilenceableFailure diag = emitSilenceableError()
385  << "failed to bufferize operation";
386  diag.attachNote(op->getLoc()) << "target payload op";
387  return diag;
388  }
389  allocatedBuffers.push_back(buffer);
390  }
391 
392  // Set results.
393  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
394  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
396 }
397 
// Declares the transform-dialect side effects of BufferizeToAllocationOp:
// the target handle is consumed unless only the destination is bufferized
// (in which case the payload op survives and the handle stays valid).
// NOTE(review): the effects-parameter line (original 399) is elided in this
// extracted listing.
398 void transform::BufferizeToAllocationOp::getEffects(
400  if (getBufferizeDestinationOnly()) {
401  // The destination is replaced with a newly allocated buffer, but the op
402  // itself remains in place.
403  onlyReadsHandle(getTargetMutable(), effects);
404  } else {
405  consumesHandle(getTargetMutable(), effects);
406  }
407  producesHandle(getOperation()->getOpResults(), effects);
408  modifiesPayload(effects);
409 }
410 
// Verifier body: restricts memcpy_op / alloc_op to the spellings handled in
// apply() above.
// NOTE(review): the `verify()` signature line (original 411) is elided in
// this extracted listing.
412  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
413  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
414  return emitOpError() << "unsupported memcpy op";
415  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
416  return emitOpError() << "unsupported alloc op";
417  return success();
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // DecomposeOp
422 //===----------------------------------------------------------------------===//
423 
// Tries each registered conv/pooling downscaling pattern in turn via
// tryApply; the first one that succeeds records the rewritten op and returns
// success. If none applies, a default silenceable failure is emitted.
// NOTE(review): this extracted listing elides the return-type line, one
// parameter line, and two DOWNSCALE_NORMAL entries (originals 450-451).
425 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
426  LinalgOp target,
428  transform::TransformState &state) {
// DOWNSCALE attempts one pattern and early-returns on success.
429 #define DOWNSCALE(trans) \
430  { \
431  FailureOr<LinalgOp> res = tryApply<trans>(target); \
432  if (succeeded(res)) { \
433  results.push_back(*res); \
434  return DiagnosedSilenceableFailure::success(); \
435  } \
436  }
437 
438 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
439 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
440 
// Each entry maps a 2-D conv/pooling op to its 1-D counterpart.
441  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
442  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
443  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
444  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
445  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
446  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
447  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
448  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
449  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
452 #undef DOWNSCALE_NORMAL
453 #undef DOWNSCALE_CALL
454 #undef DOWNSCALE
455  return emitDefaultSilenceableFailure(target);
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // DecomposeInterfaceOp
460 //===----------------------------------------------------------------------===//
461 
462 // Decompose the target operation if it implements the AggregatedOpInterface.
463 // Push the decomposed operations (the ones that replaces the values produced by
464 // \p target) in the `results`.
// NOTE(review): this extracted listing elides one parameter line (original
// 467, presumably the ApplyToEachResultList &results — confirm) and the
// final success return (original 487).
465 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
466  transform::TransformRewriter &rewriter, Operation *target,
468  transform::TransformState &state) {
469  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
470  if (!decomposableOp) {
// failed(...) is used only to consume the LogicalResult returned by
// notifyMatchFailure; its value is intentionally unused.
471  failed(rewriter.notifyMatchFailure(target,
472  "payload is not a decomposable op"));
473  return emitDefaultSilenceableFailure(target);
474  }
475 
476  FailureOr<SmallVector<Value>> maybeNewResults =
477  decomposableOp.decomposeOperation(rewriter);
478  if (failed(maybeNewResults))
479  return emitDefaultSilenceableFailure(target);
480 
// Replace the aggregate op and report each defining op of the replacement
// values back as a result handle (block arguments have no defining op and
// are skipped).
481  rewriter.replaceOp(decomposableOp, *maybeNewResults);
482  for (Value val : *maybeNewResults) {
483  Operation *definition = val.getDefiningOp();
484  if (definition)
485  results.push_back(definition);
486  }
488 }
489 
490 //===----------------------------------------------------------------------===//
491 // EliminateLinalgOpAnchoredEmptyTensorsOp
492 //===----------------------------------------------------------------------===//
493 
// Effects: the target handle is only read (payload ops stay mapped), but the
// payload IR is modified.
// NOTE(review): the effects-parameter line (original 495) is elided in this
// extracted listing.
494 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
496  onlyReadsHandle(getTargetMutable(), effects);
497  modifiesPayload(effects);
498 }
499 
// For each payload op: run one-shot bufferization analysis, then eliminate
// linalg-anchored tensor.empty ops; either step failing yields a silenceable
// failure anchored at the payload op.
// NOTE(review): this extracted listing elides the return-type line, the
// options/analysis-state declarations, the elimination call line, and the
// final success return.
501 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
502  transform::TransformRewriter &rewriter, TransformResults &transformResults,
503  TransformState &state) {
505  options.allowReturnAllocsFromLoops = true;
506 
507  for (Operation *target : state.getPayloadOps(getTarget())) {
509  if (failed(analyzeOp(target, state)))
510  return mlir::emitSilenceableFailure(target->getLoc())
511  << "failed to analyze op";
513  rewriter, target, state)))
514  return mlir::emitSilenceableFailure(target->getLoc())
515  << "failed to eliminate LinalgOp anchored tensor.empty ops";
516  }
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // FuseOp
522 //===----------------------------------------------------------------------===//
523 
524 /// Apply a tiling transformation to all payload ops and store both the
525 /// tiled operation as well as the created tile loops.
526 template <typename Range>
527 static LogicalResult applyTilingToAll(
528  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
529  unsigned numLoops, transform::TransformResults &transformResults,
530  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
531  applyFn) {
532  SmallVector<Operation *> tiledLinalgOps;
533  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
534 
535  for (Operation *target : payloadOps) {
536  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
537  if (!tilingInterfaceOp)
538  return transformOp->emitError("only TilingInterface ops are supported");
539 
540  rewriter.setInsertionPoint(target);
541  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
542  applyFn(tilingInterfaceOp);
543  if (failed(tiledResults))
544  return failure();
545 
546  // Perform the replacement of tiled and fused values.
547  SmallVector<Operation *> opsToReplace{target};
548  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
549  for (Operation *toReplace : opsToReplace) {
550  for (OpResult res : toReplace->getResults())
551  if (auto replacement = tiledResults->replacements.lookup(res))
552  rewriter.replaceAllUsesWith(res, replacement);
553  if (toReplace->use_empty()) {
554  rewriter.eraseOp(toReplace);
555  }
556  }
557 
558  // Report back the relevant handles to the transform op.
559  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
560  assert(tiledResults->loops.size() == numLoops &&
561  "Mismatched number of loops, tile and fuse transform should have "
562  "failed");
563  for (unsigned int i = 0; i < numLoops; ++i)
564  loopOps[i].push_back(tiledResults->loops[i]);
565  }
566 
567  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
568  for (unsigned int i = 0; i < numLoops; ++i)
569  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
570 
571  return success();
572 }
573 
// Tiles every payload op with the static tile sizes / interchange from the
// op's attributes and fuses producers via tileConsumerAndFuseProducersUsingSCF,
// delegating replacement bookkeeping to applyTilingToAll.
// NOTE(review): this extracted listing elides the return-type line, the
// TransformState parameter line, and two pattern-population lines inside the
// apply_cleanup branch (originals 595-596).
575 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
576  mlir::transform::TransformResults &transformResults,
578  SmallVector<int64_t> tileSizes =
579  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
580  SmallVector<int64_t> tileInterchange =
581  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
582 
583  scf::SCFTilingOptions tilingOptions;
584  tilingOptions.interchangeVector = tileInterchange;
585  SmallVector<OpFoldResult> tileSizesOfr =
586  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
587  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
588  scf::SCFTileAndFuseOptions tileAndFuseOptions;
589  tileAndFuseOptions.tilingOptions = tilingOptions;
590 
// Optionally register cleanup canonicalization patterns that run between
// fusion steps.
591  if (getApplyCleanup()) {
592  MLIRContext *context = rewriter.getContext();
593  RewritePatternSet patterns(context);
594  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
597  tileAndFuseOptions.cleanupPatterns = std::move(patterns);
598  }
599 
// A tile size of 0 means "do not tile this dimension", so the number of
// generated loops is the count of non-zero sizes.
600  LogicalResult result = applyTilingToAll(
601  rewriter, getOperation(), state.getPayloadOps(getTarget()),
602  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
603  [&](TilingInterface tilingInterfaceOp)
604  -> FailureOr<scf::SCFTileAndFuseResult> {
605  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
606  tileAndFuseOptions);
607  });
608  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
609  : DiagnosedSilenceableFailure::success();
610 }
611 
612 LogicalResult transform::FuseOp::verify() {
613  SmallVector<int64_t> permutation =
614  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
615  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
616  if (!std::is_permutation(sequence.begin(), sequence.end(),
617  permutation.begin(), permutation.end())) {
618  return emitOpError() << "expects interchange to be a permutation, found "
619  << getTileInterchange();
620  }
621 
622  SmallVector<int64_t> sizes =
623  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
624  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
625  if (numExpectedLoops != getNumResults() - 1)
626  return emitOpError() << "expects " << numExpectedLoops << " loop results";
627 
628  return success();
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // FuseIntoContainingOp
633 //===----------------------------------------------------------------------===//
634 
635 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
636  OperationState &result,
637  Value producerOp,
638  Value containingOp) {
639  result.addOperands({producerOp, containingOp});
640  auto resultType = transform::AnyOpType::get(builder.getContext());
641  result.addTypes({resultType, resultType});
642 }
643 
644 /// Add new operands to the forall op for users of the producerOp
645 /// that are dominated by the containing scf.forall op.
// NOTE(review): the function's signature line (original 646, introducing
// `replaceForAllWithNewSignature` — see the call site in the fusion helper
// below) is elided in this extracted listing.
647  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
648  Operation *containingOp, TilingResult &tileAndFuseResult,
649  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
650  SmallVector<OpFoldResult> &sizes) {
651 
652  // Count number of users not including the containing op
653  SetVector<Operation *> dominatedUsers;
654  DominanceInfo domInfo(containingOp);
655  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
656  if (!containingOp->isAncestor(user) &&
657  (domInfo.dominates(containingOp, user))) {
658  dominatedUsers.insert(user);
659  }
660  }
// Nothing outside the containing op needs the value: no signature change.
661  if (dominatedUsers.empty())
662  return nullptr;
663 
664  // Create new scf.forall op
665  auto forallOp = cast<scf::ForallOp>(containingOp);
666  OpBuilder::InsertionGuard g(rewriter);
667  rewriter.setInsertionPoint(forallOp);
668 
669  // Get new output
670  Location loc = forallOp.getLoc();
// Only linalg.generic producers are supported for appending an output.
671  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
672  if (!genericOp)
673  return nullptr;
674  SmallVector<Value> outputs = genericOp.getOutputs();
675  SmallVector<Value> newOuts(forallOp.getOutputs());
676  newOuts.push_back(outputs[resultNumber]);
677 
678  // Create new scf.forall op
679  auto newforallOp = rewriter.create<scf::ForallOp>(
680  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
681  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
// Steal the old body wholesale rather than cloning it.
682  rewriter.eraseBlock(newforallOp.getBody());
683  newforallOp.getRegion().takeBody(forallOp.getRegion());
684 
685  // Add additional block argument for new value being returned
686  // and replaces all uses of the new output with corresponding bbArg
687  // inside the scf.forall to enable fusion into this new scf.forall.
688  newforallOp.getBody()->addArgument(newOuts.back().getType(),
689  newOuts.back().getLoc());
690  auto bbArgs = newforallOp.getBody()->getArguments();
691  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
692  [&](OpOperand &use) {
693  Operation *op = use.getOwner();
694  return newforallOp->isProperAncestor(op);
695  });
696 
697  // Fix terminator
698  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
699  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
700  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
701  Operation *firstYieldOp = yieldingOps.front();
702  rewriter.setInsertionPoint(firstYieldOp);
// Yield the tiled value into the newly appended shared output using the
// same offsets/sizes as the extract slice, with unit strides.
703  Value src = tileAndFuseResult.tiledValues[0];
704  Value dst = newforallOp.getRegionIterArgs().back();
705  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
706  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
707  dst, offsets, sizes, strides);
708 
// Rewire all old forall results to the new forall, and point dominated users
// of the producer's result at the newly appended forall result.
709  for (auto result : llvm::enumerate(forallOp.getResults())) {
710  rewriter.replaceAllUsesWith(result.value(),
711  newforallOp->getResult(result.index()));
712  }
713  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
714  newforallOp->getResults().back(),
715  [&](OpOperand &use) {
716  Operation *user = use.getOwner();
717  return dominatedUsers.contains(user);
718  });
719  return newforallOp;
720 }
721 
722 /// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
723 /// if the operand 'src' is equal to 'dst' or equal to a iter arg present in a
724 /// outer loop. To determine the second condition, this function iterates
725 /// using a worklist over the enclosing loops, trying to find 'src' in any of
726 /// the parent loop's iter args.
727 static bool sameOrEquivalentIterArg(Value src, Value dst) {
728  // Stack like vector containing possible iterArgs candidates. The first one
729  // is dst, and we will transverse the IR from there.
730  SmallVector<Value> destWorklist;
731  destWorklist.push_back(dst);
732 
733  while (!destWorklist.empty()) {
734  Value currentDst = destWorklist.pop_back_val();
735 
736  // We have found the same operand in some iter arg in the loop structure,
737  // so src and dst are equivalent.
738  if (src == currentDst)
739  return true;
740 
741  // The operands are not equivalent, look for enclosing loops over
742  // currentDst.
743  auto bbArg = dyn_cast<BlockArgument>(currentDst);
744  if (!bbArg)
745  continue;
746 
747  Block *parentBlock = bbArg.getOwner();
748  assert(parentBlock && "unlinked block argument");
749 
750  Operation *parentOp = parentBlock->getParentOp();
751  assert(parentOp && "expected block argument with parent operation");
752 
753  // Check if parent is loop-like. If it's not, do not add it to the worklist.
754  auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
755  if (!parentLoop)
756  continue;
757 
758  for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
759  // No need to check for null as innerIterArg is tied to parentLoop.
760  OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
761  Value loopBlockArgument =
762  parentLoop->getOperand(operand->getOperandNumber());
763  destWorklist.push_back(loopBlockArgument);
764  }
765  }
766 
767  return false;
768 }
769 
770 /// Find the first "extract" user of `producerOp` and tile it right before its
771 /// use. The tiled op is fused under the `containingOp`.
772 /// Return this fused op on success or nullptr if anything fails.
773 /// If tiled op has uses that are dominated by `containingOp`, return
774 /// a new `containingOp` with results of the fused op appended to
775 /// results of the `containingOp` or nullptr if there are no dominated uses.
// NOTE(review): the line carrying this function's name and leading
// parameters (original 777) is elided in this extracted listing.
776 static std::tuple<SmallVector<Operation *>, Operation *>
778  Operation *producerOp, Operation *containingOp) {
779  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
780  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
781  if (!tileableProducer) {
782  diag.attachNote(producerOp->getLoc())
783  << "producer is not a TileableInterface: " << *producerOp;
784  return {};
785  }
786 
787  // Search the producer slices accessed within the containing operation.
788  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
789  // evolve into an interface.
790  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
791  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
792  return sliceOp && containingOp->isProperAncestor(sliceOp);
793  });
794 
795  // Find a fusion opportunity.
796  if (it == tileableProducer->getUsers().end()) {
797  diag.attachNote(tileableProducer->getLoc())
798  << "could not find fusion opportunity for: " << *tileableProducer;
799  return {};
800  }
801  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
802 
803  // Try to fuse the producer in-place.
804  OpBuilder::InsertionGuard guard(rewriter);
805  rewriter.setInsertionPoint(sliceOpToTile);
806 
807  // Clone the producer inside the consumer and try to update the producer init
808  // operands using the loop bbArgs if applicable. More precisely, if the bbArg
809  // of the container loop points to a value that it is used by the consumer op,
810  // then, instead of using such value on the consumer, use the value coming
811  // from the bbArg instead. This allows to reuse the output tensor (instead of
812  // creating a new one) of the container when both producer and container write
813  // to the same output.
814  if (LoopLikeOpInterface containerLoop =
815  dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
816  Operation *clone = rewriter.clone(*producerOp);
817  rewriter.modifyOpInPlace(clone, [&]() {
818  // Iterate over the outputs of the producer and over the loop bbArgs and
819  // check if any bbArg points to the same value as the producer output. In
820  // such case, make the producer output point to the bbArg directly.
821  for (OpOperand &initOperandPtr :
822  cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
823  Value producerOperand =
824  clone->getOperand(initOperandPtr.getOperandNumber());
825  for (BlockArgument containerIterArg :
826  containerLoop.getRegionIterArgs()) {
827  OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
828  Value consumerOperand =
829  containerLoop->getOperand(bbArg->getOperandNumber());
830  // The producer has the same init as the loop bbArg, use it.
831  if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
832  initOperandPtr.set(containerIterArg);
833  }
834  }
835  }
836  });
837 
// From here on, tiling operates on the (possibly rewired) clone.
838  tileableProducer = dyn_cast<TilingInterface>(clone);
839  }
840 
841  // Tile the producer.
842  int64_t resultNumber =
843  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
844  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
845 
// Tile exactly the region the extract_slice reads.
846  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
847  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
848 
849  FailureOr<TilingResult> tileAndFuseResult =
850  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
851  sizes);
852 
853  if (failed(tileAndFuseResult)) {
854  diag.attachNote(tileableProducer->getLoc())
855  << "failed to tile producer op: " << *tileableProducer;
856  return {};
857  }
858 
859 #ifndef NDEBUG
860  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
861  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
862  }
863 #endif
864 
865  // Replace the extract op.
866  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
867  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
868  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
869  if (failed(maybeRankReduced)) {
870  diag.attachNote(producerOp->getLoc())
871  << "shape types don't match (missing canonicalization?):\nTiledOp: "
872  << tileAndFuseResult->tiledValues[0]
873  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
874  return {};
875  }
876  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
877 
878  // Add new outputs to containing op, if required
879  Operation *newContainingOp = replaceForAllWithNewSignature(
880  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
881  resultNumber, offsets, sizes);
882 
883  // Cleanup clone.
884  if (dyn_cast<LoopLikeOpInterface>(containingOp))
885  rewriter.eraseOp(tileableProducer);
886 
887  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
888 }
889 
890 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
891 /// it is exactly the `containingOp`, otherwise bail.
892 /// Then, find the first "extract" user of the tied block argument and tile it
893 /// right before its "extract" use. The tiled op is fused under the
894 /// `containingOp`.
895 /// Return this fused op on success or nullptr if anything fails.
898  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
899  Operation *containingOp) {
900  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
901 
902  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
903  if (!tileableProducer) {
904  diag.attachNote(producerOp->getLoc())
905  << "producer is not a TileableInterface: " << *producerOp;
906  return {};
907  }
908 
909  // Search the first use by a "scf::ForallOp" user.
910  scf::ForallOp forallOp;
911  auto itProducerUses =
912  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
913  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
914  return forallOp;
915  });
916  // If it's not from the containing op, return.
917  if (!forallOp || forallOp != containingOp) {
918  diag.attachNote(tileableProducer->getLoc())
919  << "could not find a use by the containing op: " << *tileableProducer;
920  return {};
921  }
922 
923  // Search the producer slices accessed within the containing
924  // operation.
925  // TODO: Generalize to more extract/insert/parallel_insert triples.
926  // Maybe evolve into an interface.
927  OpOperand *pUse = &(*itProducerUses);
928  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
929 
930  // Search the producer slices accessed within the containing operation.
931  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
932  // evolve into an interface.
933  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
934  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
935  return sliceOp && containingOp->isProperAncestor(sliceOp);
936  });
937 
938  // Find a fusion opportunity.
939  if (itBBArgUsers == bbArg.getUsers().end()) {
940  diag.attachNote(containingOp->getLoc())
941  << "could not find fusion opportunity for bbArg: " << bbArg;
942  return {};
943  }
944  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
945 
946  // Try to fuse the producer in-place.
947  OpBuilder::InsertionGuard guard(rewriter);
948  rewriter.setInsertionPoint(sliceOpToTile);
949 
950  // Replace the use in the tileableProducer before tiling: clone, replace and
951  // then tile.
952  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
953  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
954 
955  // Gather destination tensors.
956  SmallVector<Value> destinationTensors;
958  rewriter, tileableProducer->getLoc(), tileableProducer,
959  destinationTensors))) {
960  diag.attachNote(tileableProducer->getLoc())
961  << "failed to get destination tensors for: " << *tileableProducer;
962  return {};
963  }
964 
965  IRMapping bvm;
966  bvm.map(destinationTensors[resultNumber], bbArg);
967  auto tileableProducerClone =
968  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
969  auto scopeGuard =
970  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
971 
972  // Tile the producer.
973  FailureOr<TilingResult> tileAndFuseResult =
974  tileableProducerClone.generateResultTileValue(
975  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
976  sliceOpToTile.getMixedSizes());
977  if (failed(tileAndFuseResult)) {
978  diag.attachNote(tileableProducer->getLoc())
979  << "failed to tile producer op: " << *tileableProducer;
980  return {};
981  }
982 
983  // Replace the extract op.
984  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
985  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
986  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
987  assert(succeeded(maybeRankReduced) && "unexpected shape");
988  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
989 
990  // Replace the use in containingOp.
991  rewriter.modifyOpInPlace(containingOp, [&]() {
992  containingOp->setOperand(pUse->getOperandNumber(),
993  destinationTensors.front());
994  });
995 
996  return tileAndFuseResult->tiledOps;
997 }
998 
1000  Operation *producerOp,
1001  Operation *containingOp) {
1002  LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
1003 
1004  // Gather all uses inside the containing op.
1006  for (OpResult result : producerOp->getOpResults()) {
1007  for (OpOperand &use : result.getUses()) {
1008  if (containingOp->isProperAncestor(use.getOwner())) {
1009  uses.push_back(&use);
1010  continue;
1011  }
1012  // Cannot clone and fuse if the use is by the containing op itself: fail
1013  // immediately.
1014  if (containingOp == use.getOwner()) {
1015  diag.attachNote(producerOp->getLoc())
1016  << "producer op use by containing op cannot be fused by cloning";
1017  return nullptr;
1018  }
1019  }
1020  }
1021 
1022  // Check for a non-empty list of fusion opportunities.
1023  if (uses.empty()) {
1024  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1025  return nullptr;
1026  }
1027 
1028  // Clone and fuse inside the containing op.
1029  Operation *fusedOp = nullptr;
1030  OpOperand *use = uses.front();
1031  // Parallel insert slice is not a valid clone destination.
1032  // TODO: Generalize to other type of ops.
1033  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1034  "Parallel insert slice is not a valid clone destination");
1035  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1036  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
1037 
1038  OpBuilder::InsertionGuard guard(rewriter);
1039  rewriter.setInsertionPoint(use->getOwner());
1040  fusedOp = rewriter.clone(*producerOp);
1041  rewriter.modifyOpInPlace(
1042  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1043 
1044  return fusedOp;
1045 }
1046 
1047 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1048  // Allow repeated handles since we are fusing everything anyway.
1049  return true;
1050 }
1051 
1053 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1054  transform::TransformResults &results,
1055  transform::TransformState &state) {
1056  SmallVector<Operation *> fusedOps;
1057  auto producerOps = state.getPayloadOps(getProducerOp());
1058  auto containingOps = state.getPayloadOps(getContainingOp());
1059  if (!llvm::hasSingleElement(containingOps)) {
1060  return emitDefiniteFailure()
1061  << "requires exactly one containing_op handle (got "
1062  << llvm::range_size(containingOps) << ")";
1063  }
1064  Operation *containingOp = *containingOps.begin();
1065 
1066  // If nothing to fuse, propagate success.
1067  if (std::empty(producerOps)) {
1068  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1069  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1071  }
1072 
1073  // Helper function to find the next producer that should be fused. Take any
1074  // producer that has a use inside the containing op.
1075  SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1076  auto getNextProducer = [&]() -> FailureOr<Operation *> {
1077  for (const auto &it : enumerate(remainingProducers)) {
1078  Operation *producerOp = it.value();
1079  // The containing op may be a user of producerOp: use isAncestor.
1080  int64_t numUsesInContainingOp =
1081  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1082  return containingOp->isAncestor(op);
1083  });
1084  // TODO: When resolving the TODO below (no duplicate ops), take an op
1085  // that has no use among the remaining producers. This is a topological
1086  // sorting.
1087  if (numUsesInContainingOp > 0) {
1088  if (numUsesInContainingOp == 1)
1089  remainingProducers.erase(remainingProducers.begin() + it.index());
1090  return producerOp;
1091  }
1092  }
1093  return failure();
1094  };
1095 
1096  while (!remainingProducers.empty()) {
1097  auto nextProducer = getNextProducer();
1098  if (failed(nextProducer)) {
1099  auto diag = mlir::emitSilenceableFailure(getLoc())
1100  << "could not find next producer to fuse into container";
1101  diag.attachNote(containingOp->getLoc()) << "containing op";
1102  return diag;
1103  }
1104 
1105  Operation *producerOp = *nextProducer;
1106 
1107  // Default diagnostic, to be complemented with more failure information.
1109  diag << "could not fuse " << *producerOp << " into " << *containingOp;
1110 
1111  // TODO: If there are multiple uses of the producer in the containing op,
1112  // we currently tile/clone the op multiple times (once per use). In some
1113  // cases, we can tile/clone once and reuse the value for each use.
1114  // Futhermore, producers should then be traversed according to a
1115  // topological sorting.
1116  auto [tiledOps, newContainingOp] =
1117  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1118  if (!tiledOps.empty()) {
1119  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
1120  fusedOps.append(tiledOps);
1121  if (newContainingOp) {
1122  // Update handles associated with the containing op so we don't need to
1123  // invalidate them. This is a hack to support better composability
1124  // between tiling and fusion while a proper mechanism is being
1125  // investigated.
1126  //
1127  // DO NOT replicate this elsewhere unless you understand what you are
1128  // doing.
1129  LogicalResult replacementStatus =
1130  rewriter.notifyPayloadOperationReplaced(containingOp,
1131  newContainingOp);
1132  (void)replacementStatus;
1133  assert(succeeded(replacementStatus) &&
1134  "unable to update transform state mapping");
1135  rewriter.eraseOp(containingOp);
1136  containingOp = newContainingOp;
1137  }
1138  continue;
1139  }
1140 
1141  SmallVector<Operation *> tiledContainingOpOperand =
1143  rewriter, diag, producerOp, containingOp);
1144  if (!tiledContainingOpOperand.empty()) {
1145  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
1146  << *containingOp);
1147  fusedOps.append(tiledContainingOpOperand);
1148  continue;
1149  }
1150 
1151  Operation *cloned =
1152  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1153  if (cloned) {
1154  LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
1155  fusedOps.push_back(cloned);
1156  continue;
1157  }
1159  }
1160 
1161  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1162  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1164 }
1165 
1166 void transform::FuseIntoContainingOp::getEffects(
1168  consumesHandle(getProducerOpMutable(), effects);
1169  onlyReadsHandle(getContainingOpMutable(), effects);
1170  producesHandle(getOperation()->getOpResults(), effects);
1171  modifiesPayload(effects);
1172 }
1173 
1174 //===----------------------------------------------------------------------===//
1175 // GeneralizeOp
1176 //===----------------------------------------------------------------------===//
1177 
1179 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1180  LinalgOp target,
1182  transform::TransformState &state) {
1183  // Exit early if no transformation is needed.
1184  if (isa<GenericOp>(target)) {
1185  results.push_back(target);
1187  }
1188  rewriter.setInsertionPoint(target);
1189  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1190  if (succeeded(generic)) {
1191  results.push_back(generic->getOperation());
1193  }
1194  return emitDefaultSilenceableFailure(target);
1195 }
1196 
1197 //===----------------------------------------------------------------------===//
1198 // SpecializeOp
//===----------------------------------------------------------------------===//
1200 
1202 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1203  LinalgOp target,
1205  transform::TransformState &state) {
1206  // Exit early if the operation is not a generic.
1207  if (!isa<GenericOp>(target)) {
1208  results.push_back(target);
1210  }
1211  rewriter.setInsertionPoint(target);
1212  FailureOr<LinalgOp> named =
1213  specializeGenericOp(rewriter, cast<GenericOp>(target));
1214  if (succeeded(named)) {
1215  results.push_back(named->getOperation());
1217  }
1218  return emitDefaultSilenceableFailure(target);
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // InterchangeOp
1223 //===----------------------------------------------------------------------===//
1224 
1226 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1227  GenericOp target,
1229  transform::TransformState &state) {
1230  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1231  // Exit early if no transformation is needed.
1232  if (interchangeVector.empty()) {
1233  results.push_back(target);
1235  }
1236 
1237  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1238  if (interchangeVector.size() != numLoops) {
1239  return emitSilenceableError()
1240  << getIteratorInterchangeAttrName() << " has length ("
1241  << interchangeVector.size()
1242  << ") different from the number of loops in the target operation ("
1243  << numLoops << ")";
1244  }
1245  FailureOr<GenericOp> res = interchangeGenericOp(
1246  rewriter, target, SmallVector<unsigned>(interchangeVector));
1247  if (failed(res))
1248  return emitDefiniteFailure() << "failed to apply";
1249  results.push_back(res->getOperation());
1251 }
1252 
1253 LogicalResult transform::InterchangeOp::verify() {
1254  ArrayRef<int64_t> permutation = getIteratorInterchange();
1255  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1256  if (!std::is_permutation(sequence.begin(), sequence.end(),
1257  permutation.begin(), permutation.end())) {
1258  return emitOpError()
1259  << "expects iterator_interchange to be a permutation, found "
1260  << getIteratorInterchange();
1261  }
1262  return success();
1263 }
1264 
1265 //===----------------------------------------------------------------------===//
1266 // LinalgCopyToMemrefOp
1267 //===----------------------------------------------------------------------===//
1268 
1269 DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1270  transform::TransformRewriter &rewriter, Operation *targetOp,
1272  transform::TransformState &state) {
1273 
1274  // Check if the target can be converted.
1275  if (!isa<linalg::CopyOp>(targetOp)) {
1277  emitSilenceableError() << "only linalg.copy target ops are supported";
1278  diag.attachNote(targetOp->getLoc()) << "target op";
1279  return diag;
1280  }
1281 
1282  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1283  if (!copyOp.hasPureBufferSemantics()) {
1285  emitSilenceableError()
1286  << "cannot transform a linalg.copy on tensors into a memref.copy";
1287  diag.attachNote(targetOp->getLoc()) << "target op";
1288  return diag;
1289  }
1290 
1291  SmallVector<Value> inputs = copyOp.getInputs();
1292  SmallVector<Value> outputs = copyOp.getOutputs();
1293  assert(inputs.size() == 1 && "expected linalg copy op with one input");
1294  assert(outputs.size() == 1 && "expected memref copy op with one output");
1295  Value input = inputs.front();
1296  Value output = outputs.front();
1297 
1298  // linalg.copy supports different element types on source/dest whereas
1299  // memref.copy does not, so we must check that the source and dest types can
1300  // be handled by memref.copy and otherwise reject the transformation.
1301  if (!isa<ShapedType>(input.getType())) {
1303  emitSilenceableError()
1304  << "cannot transform a linalg.copy which input has no shape";
1305  diag.attachNote(targetOp->getLoc()) << "target op";
1306  return diag;
1307  }
1308 
1309  // linalg.copy destination must be a shaped type.
1310  assert(isa<ShapedType>(output.getType()));
1311 
1312  if (cast<ShapedType>(input.getType()).getElementType() !=
1313  cast<ShapedType>(output.getType()).getElementType()) {
1315  emitSilenceableError()
1316  << "cannot transform a linalg.copy with different source and "
1317  "destination element types ";
1318  diag.attachNote(targetOp->getLoc()) << "target op";
1319  return diag;
1320  }
1321 
1322  // Target can be converted, do it.
1323  auto memrefCopyOp =
1324  rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1325 
1326  results.push_back(memrefCopyOp);
1328 }
1329 
1330 //===----------------------------------------------------------------------===//
1331 // LowerPackOp
1332 //===----------------------------------------------------------------------===//
1333 
1334 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1335  transform::TransformRewriter &rewriter, linalg::PackOp target,
1336  transform::ApplyToEachResultList &transformResults,
1337  transform::TransformState &state) {
1338  rewriter.setInsertionPoint(target);
1339  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1340  FailureOr<LowerPackResult> res =
1341  lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1342  if (failed(res)) {
1343  return mlir::emitSilenceableFailure(target->getLoc())
1344  << "cannot lower to pad + expand + transpose";
1345  }
1346  transformResults.push_back(res->padOp);
1347  transformResults.push_back(res->expandShapeOp);
1348  transformResults.push_back(res->transposeOp);
1350 }
1351 
1352 //===----------------------------------------------------------------------===//
1353 // LowerUnPackOp
1354 //===----------------------------------------------------------------------===//
1355 
1356 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1357  transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1358  transform::ApplyToEachResultList &transformResults,
1359  transform::TransformState &state) {
1360  rewriter.setInsertionPoint(target);
1361  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1362  FailureOr<LowerUnPackOpResult> res =
1363  lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1364  if (failed(res)) {
1366  emitSilenceableError()
1367  << "cannot lower to transpose + collapse + extract";
1368  diag.attachNote(target->getLoc()) << "target payload op";
1369  return diag;
1370  }
1371  transformResults.push_back(res->emptyOp);
1372  transformResults.push_back(res->transposeOp);
1373  transformResults.push_back(res->collapseShapeOp);
1374  transformResults.push_back(res->extractSliceOp);
1376 }
1377 
1378 //===---------------------------------------------------------------------===//
1379 // MatchOp
1380 //===---------------------------------------------------------------------===//
1381 
1382 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1383  Value target, ArrayRef<StringRef> opNames) {
1384  result.addOperands(target);
1385  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1386  builder.getStrArrayAttr(opNames));
1387  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1388 }
1389 
1390 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1391  TypeRange resultTypes, Value target,
1392  ArrayRef<StringRef> opNames) {
1393  result.addOperands(target);
1394  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1395  builder.getStrArrayAttr(opNames));
1396  result.addTypes(resultTypes);
1397 }
1398 
1400 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1401  transform::TransformResults &results,
1402  transform::TransformState &state) {
1403  llvm::StringSet<> strs;
1404  if (getOps().has_value())
1405  strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1406 
1407  auto payloadOps = state.getPayloadOps(getTarget());
1408  if (!llvm::hasSingleElement(payloadOps)) {
1409  return emitDefiniteFailure("requires exactly one target handle");
1410  }
1411 
1413  bool incorrectNumOperandTypes = false;
1414  auto matchFun = [&](Operation *op) {
1415  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1416  return;
1417 
1418  // Interfaces cannot be matched by name, just by ID.
1419  // So we specifically encode the interfaces we care about for this op.
1420  if (getInterface().has_value()) {
1421  auto iface = getInterface().value();
1422  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1423  !isa<LinalgOp>(op))
1424  return;
1425  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1426  !isa<TilingInterface>(op))
1427  return;
1428  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1429  !isa<LoopLikeOpInterface>(op))
1430  return;
1431  }
1432 
1433  // Check if all specified attributes match.
1434  if (getOpAttrs().has_value()) {
1435  DictionaryAttr opAttrs = getOpAttrs().value();
1436  for (NamedAttribute attr : opAttrs) {
1437  if (attr.getName() == getInterfaceAttrName() ||
1438  attr.getName() == getOpsAttrName())
1439  continue;
1440  if (!op->hasAttr(attr.getName()))
1441  return;
1442  if (op->getAttr(attr.getName()) != attr.getValue())
1443  return;
1444  }
1445  }
1446 
1447  if (getFilterResultType().has_value()) {
1448  Type t = getFilterResultType().value();
1449  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1450  return;
1451  }
1452 
1453  if (getFilterOperandTypes().has_value()) {
1454  mlir::ArrayAttr types = getFilterOperandTypes().value();
1455  auto operandTypes = op->getOperandTypes();
1456 
1457  if (types.size() == 1) {
1458  // All the operands must must be equal to the specified type
1459  auto typeattr =
1460  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1461  Type t = cast<::mlir::Type>(typeattr.getValue());
1462  if (!llvm::all_of(op->getOperandTypes(),
1463  [&](Type operandType) { return operandType == t; }))
1464  return;
1465  } else {
1466  // The operand types must match all the types in the list (in the same
1467  // order in with they are specified)
1468  if (types.size() != operandTypes.size()) {
1469  incorrectNumOperandTypes = true;
1470  return;
1471  }
1472 
1473  for (auto [attr, operandType] :
1474  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1475  auto typeattr = cast<mlir::TypeAttr>(attr);
1476  Type type = cast<::mlir::Type>(typeattr.getValue());
1477 
1478  if (type != operandType)
1479  return;
1480  }
1481  }
1482  }
1483 
1484  // All constraints are satisfied.
1485  res.push_back(op);
1486  return;
1487  };
1488 
1489  (*payloadOps.begin())->walk(matchFun);
1490  if (incorrectNumOperandTypes)
1491  return emitDefiniteFailure("If filter_operand_types contains more than a "
1492  "type, then it must contain as much types as "
1493  "the number of operands in the target ops");
1494  results.set(cast<OpResult>(getResult()), res);
1496 }
1497 
1498 //===---------------------------------------------------------------------===//
1499 // MultiTileSizesOp
1500 //===---------------------------------------------------------------------===//
1501 
1503  Type targetType, Type lowSizeType, Type,
1504  Type) {
1505  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1506 }
1507 
1508 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1509  Type &targetType, Type &lowSizeType,
1510  Type &highSizeType,
1511  Type &splitPointType) {
1512  FunctionType funcType;
1513  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1514  if (failed(parser.parseType<FunctionType>(funcType)))
1515  return failure();
1516 
1517  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1518  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1519  "argument and one result";
1520  }
1521  targetType = funcType.getInput(0);
1522  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1523 
1524  return success();
1525 }
1526 
1527 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1528  transform::TransformRewriter &rewriter, LinalgOp target,
1530  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1531  if (target.hasDynamicShape()) {
1532  auto diag = emitSilenceableError()
1533  << "cannot compute parametric tile sizes for dynamically "
1534  "shaped payload op";
1535  diag.attachNote(target->getLoc()) << "payload op";
1536  return diag;
1537  }
1538 
1539  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1540  target, getDimension(), getTargetSize(), getDivisor());
1541  if (failed(spec)) {
1542  return emitSilenceableError()
1543  << "failed to compute multi-size tiling sizes";
1544  }
1545 
1546  Builder builder(target.getContext());
1547  results.assign(llvm::map_range(
1548  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1549  spec->lowTileSize * spec->lowTripCount}),
1550  [&builder, this](int64_t value) {
1551  return builder.getIntegerAttr(
1552  cast<ParamType>(getLowSize().getType()).getType(), value);
1553  }));
1555  }
1556 
1557  OpBuilder builder(target.getContext());
1558  builder.setInsertionPoint(target);
1559  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1560  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1561  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1562  builder, target, getDimension(), targetSize, divisor);
1563  if (failed(spec)) {
1564  return emitSilenceableError() << "could not generate tile size computation";
1565  }
1566 
1567  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1568  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1569  Operation *splitPoint =
1570  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1571  {spec->lowTileSize, spec->lowTripCount});
1572  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1573  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1574  assert(lowTileSize && highTileSize && splitPoint &&
1575  "tile sizes are not produced by operations");
1576  results.reserve(results.size() + 3);
1577  results.push_back(lowTileSize);
1578  results.push_back(highTileSize);
1579  results.push_back(splitPoint);
1581 }
1582 
1583 void transform::MultiTileSizesOp::getEffects(
1585  onlyReadsHandle(getTargetMutable(), effects);
1586  producesHandle(getOperation()->getOpResults(), effects);
1587  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1588  onlyReadsPayload(effects);
1589  else
1590  modifiesPayload(effects);
1591 }
1592 
1593 LogicalResult transform::MultiTileSizesOp::verify() {
1594  if (getLowSize().getType() != getHighSize().getType() ||
1595  getLowSize().getType() != getSplitPoint().getType()) {
1596  return emitOpError() << "expects all results type to be the same";
1597  }
1598  return success();
1599 }
1600 
1601 //===---------------------------------------------------------------------===//
1602 // PackOp
1603 //===---------------------------------------------------------------------===//
1604 
1605 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1606  Value target,
1607  ArrayRef<OpFoldResult> mixedPackedSizes) {
1608  SmallVector<int64_t> staticPackedSizes;
1609  SmallVector<Value> dynamicPackedSizes;
1610  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1611  staticPackedSizes);
1612  // Call the default builder which sets up the proper operands segment sizes
1613  // attributes for multiple variadic operands. In the absence of this, horrible
1614  // bugs ensue.
1615  Type linalgOpHType = transform::OperationType::get(
1616  builder.getContext(), GenericOp::getOperationName());
1617  build(builder, result,
1618  /*resultType=*/linalgOpHType,
1619  /*target=*/target,
1620  /*dynamic_sizes=*/dynamicPackedSizes,
1621  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1622 }
1623 
1624 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1625  Builder b(getContext());
1626  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1627 }
1628 
1630 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1631  transform::TransformResults &transformResults,
1632  transform::TransformState &state) {
1633  auto targetOps = state.getPayloadOps(getTarget());
1634  // If nothing to pack, propagate success.
1635  if (std::empty(targetOps)) {
1636  transformResults.set(cast<OpResult>(getPackedOp()),
1637  ArrayRef<Operation *>({}));
1639  }
1640  // Fail on multi-op handles.
1641  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1642  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1643  return emitSilenceableError()
1644  << "requires target to map to exactly 1 LinalgOp (got "
1645  << llvm::range_size(targetOps) << ")";
1646  }
1647  // Fail on mismatched number of pack sizes.
1648  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1649  return emitSilenceableError()
1650  << "requires number of packed sizes match the number of loops ("
1651  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1652  << ")";
1653  }
1654 
1655  // Unpack handles to constants or actual SSA index values.
1656  SmallVector<OpFoldResult> packedSizes;
1658  state, *this, packedSizes, getMixedPackedSizes());
1659 
1660  rewriter.setInsertionPoint(linalgOp);
1661  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1662  if (failed(maybeResult))
1663  return emitDefiniteFailure("data tiling failed");
1664 
1665  transformResults.set(cast<OpResult>(getPackedOp()),
1666  {maybeResult->packedLinalgOp.getOperation()});
1668 }
1669 
1670 void transform::PackOp::getEffects(
1672  transform::consumesHandle(getTargetMutable(), effects);
1673  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1674  transform::producesHandle(getOperation()->getOpResults(), effects);
1675  transform::modifiesPayload(effects);
1676 }
1677 
1678 //===---------------------------------------------------------------------===//
1679 // PackGreedilyOp.
1680 //===---------------------------------------------------------------------===//
1681 
1682 LogicalResult transform::PackGreedilyOp::verify() {
1683  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1684  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1685  << " is not a valid permutation";
1686  }
1687  // TODO: relax to allow empty once we have another strategy than just matmul.
1688  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1689  for (auto [s, nmo] :
1690  llvm::zip_equal(getMixedMatmulPackedSizes(),
1691  getMatmulPaddedSizesNextMultipleOf())) {
1692  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1693  if (nmo != 0 &&
1694  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1695  return emitOpError() << "at most one of the packed_size and the "
1696  "padded_sizes_next_multiple_of can be nonzero "
1697  "for the matmul strategy";
1698  }
1699  }
1700  }
1701  return success();
1702 }
1703 
1705 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1706  transform::TransformResults &transformResults,
1707  transform::TransformState &state) {
1708  SmallVector<Operation *> results;
1709  for (Operation *op : state.getPayloadOps(getTarget())) {
1710  auto linalgOp = dyn_cast<LinalgOp>(op);
1711  if (!linalgOp)
1712  continue;
1713  // linalgOp will be replaced and the insertion point may be invalidated if
1714  // we set it before -> set it after.
1715  rewriter.setInsertionPointAfter(linalgOp);
1716  // Failing to pack greedily is perfectly fine.
1717  // In the future we will want to order packings according to some metric.
1718  FailureOr<PackResult> packResult = packMatmulGreedily(
1719  /*rewriter=*/rewriter,
1720  /*linalgOp=*/linalgOp,
1721  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1722  /*mnkPaddedSizesNextMultipleOf=*/
1723  getMatmulPaddedSizesNextMultipleOf(),
1724  /*mnkOrder=*/getMatmulInnerDimsOrder());
1725  if (succeeded(packResult)) {
1726  results.push_back(packResult->packedLinalgOp);
1727  continue;
1728  }
1729  results.push_back(linalgOp);
1730  }
1731  transformResults.set(cast<OpResult>(getPackedOp()), results);
1733 }
1734 
1735 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1736  Builder b(getContext());
1737  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1738  b);
1739 }
1740 
1741 void transform::PackGreedilyOp::getEffects(
1743  transform::consumesHandle(getTargetMutable(), effects);
1744  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1745  transform::producesHandle(getOperation()->getOpResults(), effects);
1746  transform::modifiesPayload(effects);
1747 }
1748 
1749 //===---------------------------------------------------------------------===//
1750 // PackTransposeOp
1751 //===---------------------------------------------------------------------===//
1752 
1753 LogicalResult transform::PackTransposeOp::verify() {
1754  if (!isPermutationVector(getInnerPerm())) {
1755  return emitOpError() << getInnerPermAttrName()
1756  << " is not a valid permutation";
1757  }
1758  if (!isPermutationVector(getOuterPerm())) {
1759  return emitOpError() << getOuterPermAttrName()
1760  << " is not a valid permutation";
1761  }
1762  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1763  return emitOpError() << " at least one of " << getInnerPermAttrName()
1764  << " or " << getOuterPermAttrName()
1765  << " must be specified";
1766  }
1767  return success();
1768 }
1769 
namespace {
/// Selects which permutation of a relayout (pack/unpack) op is being
/// considered: the outer dims permutation (Outer) or the inner dims
/// permutation (Inner).
enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
} // namespace
1773 
1774 /// Return true if `permutation` is a valid permutation of the
1775 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1776 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op.
1777 /// This is the case when the `permutation` rank matches the rank expected by
1778 /// `op` and `permutation` is itself a permutation vector.
1779 /// Return true if either `op` or `permutation` are empty to allow a simpler
1780 /// polymorphic implementation.
1781 template <typename RelayoutOpTy>
1783  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1784  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1785  static_assert(
1786  llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1787  "applies to only pack or unpack operations");
1788  if (!op || permutation.empty())
1789  return true;
1790  size_t innerRank = op.getInnerDimsPos().size();
1791  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1792  return permutation.size() == innerRank && isPermutationVector(permutation);
1793  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1794  // Don't rely on it.
1795  if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1796  return permutation.size() == op.getSourceRank() &&
1797  isPermutationVector(permutation);
1798  }
1799  return permutation.size() == op.getDestRank() &&
1800  isPermutationVector(permutation);
1801 }
1802 
1804 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1805  transform::TransformResults &transformResults,
1806  transform::TransformState &state) {
1807  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1808  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1809  // Step 1. If nothing to pack, propagate success.
1810  if (std::empty(packOrUnpackOps)) {
1811  transformResults.set(cast<OpResult>(getPackedOp()), {});
1812  transformResults.set(cast<OpResult>(getPackOp()), {});
1813  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1815  }
1816 
1817  // Step 2. Bunch of runtime sanity check and error messages.
1818  // Step 2.1. Fail on multi-op handles.
1819  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1820  !llvm::hasSingleElement(linalgOps)) {
1821  return emitSilenceableError()
1822  << "requires target to map to exactly 1 "
1823  "packing op and 1 packed op ("
1824  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1825  << llvm::range_size(linalgOps) << ")";
1826  }
1827 
1828  // Step 2.2. Fail on wrong type.
1829  auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1830  auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1831  if ((!packOp && !unPackOp)) {
1832  return emitSilenceableError() << "requires target to map to a "
1833  "linalg.pack or linalg.unpack";
1834  }
1835  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1836  if (!linalgOpTarget)
1837  return emitSilenceableError() << "requires a LinalgOp target";
1838 
1839  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1840  LinalgOp linalgOp;
1841  if (packOp && packOp.getResult().hasOneUse())
1842  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1843  else if (unPackOp)
1844  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1845  if (linalgOp != linalgOpTarget) {
1846  auto errorMsg =
1847  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1848  : StringLiteral{"not produced by the LinalgOp target"};
1849  return emitSilenceableError() << errorMsg;
1850  }
1851 
1852  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1853  // PackOp.
1854  if (unPackOp) {
1855  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1856  OpOperand *packUse = linalgOp.getDpsInitOperand(
1857  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1858  packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
1859  if (!packOp || !packOp.getResult().hasOneUse())
1860  return emitSilenceableError() << "could not find matching pack op";
1861  }
1862 
1863  // Step 2.5. Fail if any permutation does not validate.
1864  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1865  ArrayRef<int64_t> perm =
1866  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1867  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1868  ? StringLiteral{"invalid outer_perm"}
1869  : StringLiteral{"invalid inner_perm"};
1870  if (!isValidPackingPermutation(packOp, perm, permType) ||
1871  !isValidPackingPermutation(unPackOp, perm, permType)) {
1872  Operation *packOrUnpackOp =
1873  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1874  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1875  }
1876  }
1877 
1878  // From here on, packOp and linalgOp are always present, unPackOp may or may
1879  // not be present.
1880  assert(packOp && linalgOp && "unexpected null op");
1881 
1882  // Step 3. Actually transpose the ops.
1883  FailureOr<PackTransposeResult> res = packTranspose(
1884  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1885  // Preconditions have been checked, it is an error to fail here.
1886  assert(succeeded(res) && "unexpected packTranspose failure");
1887 
1888  // Step 4. Return results.
1889  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1890  transformResults.set(cast<OpResult>(getPackedOp()),
1891  {res->transposedLinalgOp});
1892  if (unPackOp) {
1893  transformResults.set(cast<OpResult>(getUnPackOp()),
1894  {res->transposedUnPackOp});
1895  } else {
1896  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1897  }
1898 
1900 }
1901 
1902 //===---------------------------------------------------------------------===//
1903 // PadOp
1904 //===---------------------------------------------------------------------===//
1905 
1906 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1907  ArrayRef<int64_t> paddingDimensions,
1908  ArrayRef<int64_t> padToMultipleOf,
1909  ArrayRef<int64_t> nofoldFlags,
1910  ArrayRef<Attribute> transposePaddings,
1911  StringRef copyBackOp,
1912  bool usePrescribedTensorShapes) {
1913  auto resultType = transform::AnyOpType::get(b.getContext());
1914  return build(/*builder=*/b,
1915  /*result=*/result,
1916  /*types=*/TypeRange{resultType, resultType},
1917  /*target=*/target,
1918  /*paddingValues=*/ArrayAttr(), // let inference handle this
1919  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1920  /*padToMultipleOf=*/ValueRange{},
1921  /*padToMultipleOf=*/
1922  (padToMultipleOf.empty()
1923  ? DenseI64ArrayAttr()
1924  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1925  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1926  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1927  /*copyBackOp=*/b.getStringAttr(copyBackOp),
1928  /*usePrescribedTensorShapes=*/
1929  usePrescribedTensorShapes ? b.getUnitAttr() : nullptr);
1930 }
1931 
1932 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1933  ArrayRef<int64_t> paddingDimensions,
1934  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1935  ArrayRef<int64_t> nofoldFlags,
1936  ArrayRef<Attribute> transposePaddings,
1937  StringRef copyBackOp,
1938  bool usePrescribedTensorShapes) {
1939  auto resultType = transform::AnyOpType::get(b.getContext());
1940  SmallVector<int64_t> staticPadToMultipleOf;
1941  SmallVector<Value> dynamicPadToMultipleOf;
1942  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1943  staticPadToMultipleOf);
1944  return build(/*builder=*/b,
1945  /*result=*/result,
1946  /*types=*/TypeRange{resultType, resultType},
1947  /*target=*/target,
1948  /*paddingValues=*/ArrayAttr(), // let inference handle this
1949  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1950  /*padToMultipleOf=*/dynamicPadToMultipleOf,
1951  /*padToMultipleOf=*/staticPadToMultipleOf,
1952  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1953  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1954  /*copyBackOp=*/copyBackOp,
1955  /*usePrescribedTensorShapes=*/usePrescribedTensorShapes);
1956 }
1957 
1958 void PadOp::getEffects(
1960  consumesHandle(getTargetMutable(), effects);
1961  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1962  producesHandle(getOperation()->getOpResults(), effects);
1963  modifiesPayload(effects);
1964 }
1965 
1966 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1967  Builder b(getContext());
1968  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1969 }
1970 
1972 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1973  transform::TransformResults &results,
1974  transform::TransformState &state) {
1975  auto transformOp = cast<TransformOpInterface>(getOperation());
1976  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1977 
1978  for (Operation *target : state.getPayloadOps(getTarget())) {
1979  auto linalgTarget = dyn_cast<LinalgOp>(target);
1980  if (!linalgTarget) {
1981  auto diag = emitSilenceableError() << "expected LinalgOp target";
1982  diag.attachNote(target->getLoc()) << "target op";
1983  return diag;
1984  }
1985 
1986  // Convert the integer packing flags to booleans.
1987  SmallVector<bool> nofoldFlags;
1988  for (int64_t packPadding :
1989  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1990  nofoldFlags.push_back(static_cast<bool>(packPadding));
1991 
1992  // Convert the padding values to attributes.
1993  SmallVector<Attribute> paddingValues;
1994  for (auto const &it :
1995  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1996  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1997  if (!attr) {
1998  emitOpError("expects padding values to be typed attributes");
2000  }
2001  Type elementType = getElementTypeOrSelf(std::get<1>(it));
2002  // Try to parse string attributes to obtain an attribute of element type.
2003  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2004  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2005  stringAttr, getContext(), elementType,
2006  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2007  if (!parsedAttr || parsedAttr.getType() != elementType) {
2008  auto diag = this->emitOpError("expects a padding that parses to ")
2009  << elementType << ", got " << std::get<0>(it);
2010  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2012  }
2013  paddingValues.push_back(parsedAttr);
2014  continue;
2015  }
2016  // Otherwise, add the attribute directly.
2017  if (attr.getType() != elementType) {
2018  auto diag = this->emitOpError("expects a padding value of type ")
2019  << elementType << ", got " << attr;
2020  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2022  }
2023  paddingValues.push_back(attr);
2024  }
2025 
2026  // Extract the transpose vectors.
2027  SmallVector<SmallVector<int64_t>> transposePaddings;
2028  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2029  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2030  cast<ArrayAttr>(transposeVector)));
2031 
2032  LinalgOp paddedOp;
2034  options.paddingDimensions =
2035  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2036 
2037  SmallVector<int64_t> padToMultipleOf;
2039  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2040  if (!status.succeeded())
2041  return status;
2042  if (padToMultipleOf.empty())
2043  padToMultipleOf =
2044  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2045 
2046  options.padToMultipleOf = padToMultipleOf;
2047  options.paddingValues = paddingValues;
2048  options.nofoldFlags = nofoldFlags;
2049  if (getCopyBackOp() ==
2050  bufferization::MaterializeInDestinationOp::getOperationName()) {
2053  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2055  } else if (getCopyBackOp() == kCopyOpNone) {
2057  } else {
2058  llvm_unreachable("unsupported copy_back op");
2059  }
2060  // Populate `sizeToPadTo` with the dynamic tensor sizes for each operand.
2061  bool irChanged = false;
2062  if (getUsePrescribedTensorShapes() &&
2063  linalgTarget.hasPureTensorSemantics()) {
2064  OpBuilder::InsertionGuard g(rewriter);
2065  rewriter.setInsertionPoint(linalgTarget);
2066  for (OpOperand &operand : linalgTarget->getOpOperands()) {
2067  for (auto [i, dim] : llvm::enumerate(linalgTarget.getShape(&operand))) {
2068  if (!ShapedType::isDynamic(dim))
2069  continue;
2070  options.setSizeToPadTo(operand.getOperandNumber(), i,
2071  tensor::getMixedSize(rewriter,
2072  operand.get().getLoc(),
2073  operand.get(), i));
2074  irChanged = true;
2075  }
2076  }
2077  }
2078 
2079  SmallVector<Value> replacements;
2080  SmallVector<tensor::PadOp> newPadOps;
2081  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2082  replacements, newPadOps))) {
2083  if (irChanged) {
2084  auto diag = emitDefiniteFailure() << "failed to pad op";
2085  diag.attachNote(target->getLoc()) << "target op";
2086  return diag;
2087  }
2088  auto diag = emitSilenceableError() << "failed to pad op";
2089  diag.attachNote(target->getLoc()) << "target op";
2090  return diag;
2091  }
2092 
2093  // We need to perform our own replacement here because this API is still
2094  // used in patterns that "pad and hoist", for which the replacement values
2095  // need to be different.
2096  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2097  // that we have more composable abstractions.
2098  rewriter.replaceOp(linalgTarget, replacements);
2099  paddedOps.push_back(paddedOp);
2100  padOps.append(newPadOps.begin(), newPadOps.end());
2101  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2102  for (Value v : replacements) {
2103  Operation *copyBackOp = v.getDefiningOp();
2104  if (!llvm::is_contained(copyBackOps, copyBackOp))
2105  copyBackOps.push_back(copyBackOp);
2106  }
2107  }
2108  }
2109 
2110  results.set(cast<OpResult>(getPadded()), paddedOps);
2111  results.set(cast<OpResult>(getPad()), padOps);
2112  results.set(cast<OpResult>(getCopy()), copyBackOps);
2114 }
2115 
2116 LogicalResult transform::PadOp::verify() {
2117  SmallVector<int64_t> nofoldFlags =
2118  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2119  if (any_of(nofoldFlags, [](int64_t packPadding) {
2120  return packPadding != 0 && packPadding != 1;
2121  })) {
2122  return emitOpError()
2123  << "expects nofold_flags to contain booleans (0/1), found "
2124  << getNofoldFlags();
2125  }
2126 
2127  SmallVector<int64_t> paddingDimensions =
2128  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2129  if (any_of(paddingDimensions,
2130  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2131  return emitOpError() << "expects padding_dimensions to contain positive "
2132  "integers, found "
2133  << getPaddingDimensions();
2134  }
2135  if (!getMixedPadToMultipleOf().empty()) {
2136  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2137  return emitOpError() << "expects as many multiples as padding_dimensions";
2138  }
2139  }
2140  ArrayAttr transposes = getTransposePaddings();
2141  for (Attribute attr : transposes) {
2142  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2143  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2144  if (!std::is_permutation(sequence.begin(), sequence.end(),
2145  transpose.begin(), transpose.end())) {
2146  return emitOpError()
2147  << "expects transpose_paddings to be a permutation, found "
2148  << attr;
2149  }
2150  }
2151  if (getCopyBackOp() !=
2152  bufferization::MaterializeInDestinationOp::getOperationName() &&
2153  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2154  getCopyBackOp() != kCopyOpNone)
2155  return emitOpError() << "invalid copy_back_op";
2156  return success();
2157 }
2158 
2159 //===---------------------------------------------------------------------===//
2160 // PadTilingInterfaceOp
2161 //===---------------------------------------------------------------------===//
2162 
2163 void transform::PadTilingInterfaceOp::build(OpBuilder &b,
2164  OperationState &result,
2165  Value target,
2166  ArrayRef<int64_t> paddingDimensions,
2167  ArrayRef<int64_t> paddingSizes,
2168  bool padToMultipleOf) {
2169  auto resultType = transform::AnyOpType::get(b.getContext());
2170  return build(/*builder=*/b,
2171  /*result=*/result,
2172  /*types=*/TypeRange{resultType, resultType},
2173  /*target=*/target,
2174  /*paddingValues=*/ArrayAttr(), // let inference handle this
2175  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
2176  /*paddingSizes=*/ValueRange{},
2177  /*paddingSizes=*/
2178  (paddingSizes.empty() ? DenseI64ArrayAttr()
2179  : b.getDenseI64ArrayAttr(paddingSizes)),
2180  /*padToMultipleOf=*/
2181  padToMultipleOf ? b.getUnitAttr() : nullptr);
2182 }
2183 
2184 void transform::PadTilingInterfaceOp::build(
2185  OpBuilder &b, OperationState &result, Value target,
2186  ArrayRef<int64_t> paddingDimensions,
2187  ArrayRef<OpFoldResult> mixedPaddingSizes, bool padToMultipleOf) {
2188  auto resultType = transform::AnyOpType::get(b.getContext());
2189  SmallVector<int64_t> staticPaddingSizes;
2190  SmallVector<Value> dynamicPaddingSizes;
2191  dispatchIndexOpFoldResults(mixedPaddingSizes, dynamicPaddingSizes,
2192  staticPaddingSizes);
2193  return build(/*builder=*/b,
2194  /*result=*/result,
2195  /*types=*/TypeRange{resultType, resultType},
2196  /*target=*/target,
2197  /*paddingValues=*/ArrayAttr(), // let inference handle this
2198  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
2199  /*paddingSizes=*/dynamicPaddingSizes,
2200  /*paddingSizes=*/staticPaddingSizes,
2201  /*usePrescribedTensorShapes=*/padToMultipleOf);
2202 }
2203 
2204 void transform::PadTilingInterfaceOp::getEffects(
2206  consumesHandle(getTargetMutable(), effects);
2207  onlyReadsHandle(getPaddingSizesMutable(), effects);
2208  producesHandle(getOperation()->getOpResults(), effects);
2209  modifiesPayload(effects);
2210 }
2211 
2213 transform::PadTilingInterfaceOp::getMixedPaddingSizes() {
2214  Builder b(getContext());
2215  return getMixedValues(getStaticPaddingSizes(), getPaddingSizes(), b);
2216 }
2217 
2219 transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
2220  transform::TransformResults &results,
2221  transform::TransformState &state) {
2222  SmallVector<Operation *> paddedOps, padOps;
2223 
2224  for (Operation *target : state.getPayloadOps(getTarget())) {
2225  auto targetOp = dyn_cast<TilingInterface>(target);
2226  if (!targetOp) {
2227  auto diag = emitSilenceableError() << "expected TilingInterface target";
2228  diag.attachNote(target->getLoc()) << "target op";
2229  return diag;
2230  }
2231 
2232  // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
2233  // loopsToOperand map / C++ APIs to compute the effect of padding on
2234  // operands.
2235  if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
2236  auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
2237  "supported atm";
2238  diag.attachNote(target->getLoc()) << "target op";
2239  return diag;
2240  }
2241 
2242  // Convert the padding values to attributes.
2243  SmallVector<Attribute> paddingValues;
2244  for (auto const &[untypedAttr, elementOrTensorType] :
2245  llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
2246  auto attr = dyn_cast<TypedAttr>(untypedAttr);
2247  Type elementType = getElementTypeOrSelf(elementOrTensorType);
2248  if (!attr) {
2249  emitOpError("expects padding values to be typed attributes");
2251  }
2252  // Try to parse string attributes to obtain an attribute of element type.
2253  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
2254  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
2255  stringAttr, getContext(), elementType,
2256  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2257  if (!parsedAttr || parsedAttr.getType() != elementType) {
2258  auto diag = this->emitOpError("expects a padding that parses to ")
2259  << elementType << ", got " << attr;
2260  diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2262  }
2263  paddingValues.push_back(parsedAttr);
2264  continue;
2265  }
2266  // Otherwise, add the attribute directly.
2267  if (attr.getType() != elementType) {
2268  auto diag = this->emitOpError("expects a padding value of type ")
2269  << elementType << ", got " << attr;
2270  diag.attachNote(targetOp.getLoc()) << "when applied to this op";
2272  }
2273  paddingValues.push_back(attr);
2274  }
2275 
2276  // Set options.
2277  TilingInterface paddedOp;
2279  options.setPaddingValues(paddingValues)
2280  .setPaddingDimensions(
2281  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions()))
2282  .setPaddingSizes(getMixedPaddingSizes())
2283  .setPadToMultipleOf(getPadToMultipleOf());
2284 
2285  // Apply padding.
2286  SmallVector<tensor::PadOp> newPadOps;
2287  FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
2288  rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
2289  newPadOps);
2290  if (failed(maybePaddedOp)) {
2291  auto diag = emitSilenceableError() << "failed to pad op";
2292  diag.attachNote(target->getLoc()) << "target op";
2293  return diag;
2294  }
2295 
2296  // Set transform results.
2297  paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
2298  padOps.append(newPadOps.begin(), newPadOps.end());
2299  }
2300 
2301  results.set(cast<OpResult>(getPadded()), paddedOps);
2302  results.set(cast<OpResult>(getPad()), padOps);
2304 }
2305 
2307  SmallVector<int64_t> paddingDimensions =
2308  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2309  if (any_of(paddingDimensions,
2310  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2311  return emitOpError() << "expects padding_dimensions to contain positive "
2312  "integers, found "
2313  << getPaddingDimensions();
2314  }
2315  if (getMixedPaddingSizes().size() != paddingDimensions.size()) {
2316  return emitOpError() << "expects as many multiples as padding_dimensions";
2317  }
2318  return success();
2319 }
2320 
2321 //===---------------------------------------------------------------------===//
2322 // HoistPadOp
2323 //===---------------------------------------------------------------------===//
2324 
2325 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2326  transform::TransformRewriter &rewriter,
2327  transform::TransformResults &transformResults,
2328  transform::TransformState &state) {
2329  auto targetOps = state.getPayloadOps(getTarget());
2330  auto loopOps = state.getPayloadOps(getLoop());
2331  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2332  return emitDefiniteFailure()
2333  << "requires exactly one target and one loop handle (got "
2334  << llvm::range_size(targetOps) << " and "
2335  << llvm::range_size(loopOps) << ")";
2336  }
2337 
2338  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2339  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2340  if (!padOp || !loopOp)
2341  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2342 
2343  FailureOr<linalg::detail::PackingResult> result =
2344  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2345  getTranspose());
2346  if (failed(result))
2347  return emitDefiniteFailure() << "could not build packing loop nest";
2348 
2349  if (result->clonedLoopIvs.empty()) {
2350  transformResults.set(cast<OpResult>(getPackingLoop()),
2351  {result->hoistedPadOp.getOperation()});
2353  }
2354  auto outerPackedLoop =
2355  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2356  transformResults.set(cast<OpResult>(getPackingLoop()),
2357  {outerPackedLoop.getOperation()});
2359 }
2360 
2362  ArrayRef<int64_t> transpose = getTranspose();
2363  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2364  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2365  transpose.end())) {
2366  return emitOpError() << "expects transpose to be a permutation, found "
2367  << getTranspose();
2368  }
2369  return success();
2370 }
2371 
2372 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2374  transform::onlyReadsHandle(getTargetMutable(), effects);
2375  transform::onlyReadsHandle(getLoopMutable(), effects);
2376  transform::producesHandle(getOperation()->getOpResults(), effects);
2377  transform::modifiesPayload(effects);
2378 }
2379 
2381 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2382  tensor::PadOp target,
2384  transform::TransformState &state) {
2385  tensor::PadOp hoistedPadOp;
2386  SmallVector<TransposeOp> transposeOps;
2387  FailureOr<Value> result =
2388  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2389  hoistedPadOp, transposeOps);
2390  if (succeeded(result)) {
2391  // We need to perform our own replacement here because this API is still
2392  // used in patterns that "pad and hoist", for which the replacement values
2393  // need to be different.
2394  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2395  // that we have more composable abstractions.
2396  rewriter.replaceOp(target, *result);
2397  results.push_back(hoistedPadOp);
2399  }
2400  return emitDefaultSilenceableFailure(target);
2401 }
2402 
2403 LogicalResult transform::HoistPadOp::verify() {
2404  ArrayRef<int64_t> transpose = getTranspose();
2405  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2406  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2407  transpose.end())) {
2408  return emitOpError() << "expects transpose to be a permutation, found "
2409  << getTranspose();
2410  }
2411  return success();
2412 }
2413 
2414 //===----------------------------------------------------------------------===//
2415 // PromoteOp
2416 //===----------------------------------------------------------------------===//
2417 
2419 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2420  LinalgOp target,
2422  transform::TransformState &state) {
2423  LinalgPromotionOptions promotionOptions;
2424  if (!getOperandsToPromote().empty())
2425  promotionOptions = promotionOptions.setOperandsToPromote(
2426  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2427  if (getUseFullTilesByDefault())
2428  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2429  getUseFullTilesByDefault());
2430  if (getUseAlloca())
2431  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2432  if (!getUseFullTileBuffers().empty())
2433  promotionOptions = promotionOptions.setUseFullTileBuffers(
2434  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2435  if (getAlignment().has_value())
2436  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2437  if (getMemorySpace().has_value())
2438  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2439 
2440  if (getMapping().has_value()) {
2441  // The mapping should only contain an element
2442  auto mapping = *getMapping();
2443  if (mapping.size() > 1)
2444  return emitDefaultDefiniteFailure(target);
2445 
2446  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2447 
2448  if (addressSpace.getAddressSpace() ==
2449  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2450  promotionOptions =
2451  promotionOptions
2455  .setUseFullTileBuffers({false, false});
2456  } else if (addressSpace.getAddressSpace() ==
2457  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2458  promotionOptions =
2459  promotionOptions
2463  .setUseFullTileBuffers({false, false});
2464  } else {
2465  return emitDefaultDefiniteFailure(target);
2466  }
2467  }
2468 
2469  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2470  return emitDefaultDefiniteFailure(target);
2471 
2472  rewriter.setInsertionPoint(target);
2473  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2474  if (failed(res))
2475  return emitDefaultDefiniteFailure(target);
2476  results.push_back(target);
2478 }
2479 
2480 //===----------------------------------------------------------------------===//
2481 // ReplaceOp
2482 //===----------------------------------------------------------------------===//
2483 
2485 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2486  TransformResults &transformResults,
2487  TransformState &state) {
2488  auto payload = state.getPayloadOps(getTarget());
2489 
2490  // Check for invalid targets.
2491  for (Operation *target : payload) {
2492  if (target->getNumOperands() > 0)
2493  return emitDefiniteFailure() << "expected target without operands";
2494  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2495  target->getNumRegions() > 0)
2496  return emitDefiniteFailure()
2497  << "expected target that is isolated from above";
2498  }
2499 
2500  // Clone and replace.
2501  Operation *pattern = &getBodyRegion().front().front();
2502  SmallVector<Operation *> replacements;
2503  for (Operation *target : payload) {
2504  if (getOperation()->isAncestor(target))
2505  continue;
2506  rewriter.setInsertionPoint(target);
2507  Operation *replacement = rewriter.clone(*pattern);
2508  rewriter.replaceOp(target, replacement->getResults());
2509  replacements.push_back(replacement);
2510  }
2511  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2513 }
2514 
2515 void transform::ReplaceOp::getEffects(
2517  consumesHandle(getTargetMutable(), effects);
2518  producesHandle(getOperation()->getOpResults(), effects);
2519  modifiesPayload(effects);
2520 }
2521 
2522 LogicalResult transform::ReplaceOp::verify() {
2523  if (!getBodyRegion().hasOneBlock())
2524  return emitOpError() << "expected one block";
2525  if (std::distance(getBodyRegion().front().begin(),
2526  getBodyRegion().front().end()) != 1)
2527  return emitOpError() << "expected one operation in block";
2528  Operation *replacement = &getBodyRegion().front().front();
2529  if (replacement->getNumOperands() > 0)
2530  return replacement->emitOpError()
2531  << "expected replacement without operands";
2532  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2533  replacement->getNumRegions() > 0)
2534  return replacement->emitOpError()
2535  << "expect op that is isolated from above";
2536  return success();
2537 }
2538 
2539 //===----------------------------------------------------------------------===//
2540 // ScalarizeOp
2541 //===----------------------------------------------------------------------===//
2542 
2544 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2545  LinalgOp target,
2547  transform::TransformState &state) {
2548  scf::SCFTilingOptions tilingOptions;
2549  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2550  SmallVector<OpFoldResult> tileSizes;
2551  Location loc = target.getLoc();
2552  SmallVector<OpFoldResult> allShapeSizes =
2553  target.createFlatListOfOperandDims(b, loc);
2554  AffineMap map = target.getShapesToLoopsMap();
2555  if (!map)
2556  return tileSizes;
2557  SmallVector<OpFoldResult> shapeSizes =
2559  allShapeSizes);
2560  // If the shape size is dynamic, tile by 1.
2561  // Otherwise, do not tile (i.e. tile size 0).
2562  for (OpFoldResult shapeSize : shapeSizes) {
2563  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2564  : b.getIndexAttr(1));
2565  }
2566  return tileSizes;
2567  });
2568  rewriter.setInsertionPoint(target);
2569  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2570  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2571  if (failed(maybeTilingResult))
2572  return emitDefaultDefiniteFailure(target);
2573 
2574  if (target->getNumResults())
2575  rewriter.replaceOp(target, maybeTilingResult->replacements);
2576  else
2577  rewriter.eraseOp(target);
2578 
2579  results.reserve(maybeTilingResult->tiledOps.size());
2580  for (Operation *tiled : maybeTilingResult->tiledOps)
2581  results.push_back(tiled);
2583 }
2584 
2585 //===----------------------------------------------------------------------===//
2586 // ConvertToLoopsOp
2587 //===----------------------------------------------------------------------===//
2588 
2590 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2591  transform::TransformResults &results,
2592  transform::TransformState &state) {
2594  for (Operation *target : state.getPayloadOps(getTarget())) {
2595  auto tilingOp = dyn_cast<TilingInterface>(*target);
2596  if (!tilingOp) {
2598  emitSilenceableError()
2599  << "expected the payload to implement TilingInterface";
2600  diag.attachNote(target->getLoc()) << "payload op";
2601  return diag;
2602  }
2603  rewriter.setInsertionPoint(target);
2604  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2605  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2606  if (failed(generatedLoops))
2607  return emitDefaultDefiniteFailure(target);
2608  for (scf::ForOp &loop : *generatedLoops) {
2609  loops.push_back(loop.getOperation());
2610  }
2611  rewriter.eraseOp(target);
2612  }
2613  results.set(cast<OpResult>(getResult()), loops);
2615 }
2616 
2617 //===----------------------------------------------------------------------===//
2618 // RewriteInDestinationPassingStyleOp
2619 //===----------------------------------------------------------------------===//
2620 
2622 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2623  transform::TransformRewriter &rewriter, Operation *target,
2625  transform::TransformState &state) {
2626  rewriter.setInsertionPoint(target);
2627  FailureOr<Operation *> maybeResult =
2629  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2630  [&rewriter](auto op) {
2631  return rewriteInDestinationPassingStyle(rewriter, op);
2632  });
2633  if (failed(maybeResult))
2634  return emitDefaultSilenceableFailure(target);
2635  results.push_back(*maybeResult);
2637 }
2638 
2639 //===----------------------------------------------------------------------===//
2640 // SplitOp
2641 //===----------------------------------------------------------------------===//
2642 
2644 SplitOp::apply(transform::TransformRewriter &rewriter,
2645  TransformResults &results, TransformState &state) {
2646  // Collect the dynamic split points if provided.
2647  SmallVector<Operation *> payload =
2648  llvm::to_vector(state.getPayloadOps(getTarget()));
2649 
2650  bool isMultiwaySplit = getMultiway();
2651 
2652  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2653  return mlir::emitSilenceableFailure(getLoc())
2654  << "requires exactly one target when "
2655  "multiway split is enabled (got "
2656  << llvm::range_size(payload) << ")";
2657  }
2658 
2659  SmallVector<OpFoldResult> chunkSizes;
2660 
2661  if (!isMultiwaySplit)
2662  chunkSizes.reserve(payload.size());
2663 
2664  if (getDynamicChunkSizes()) {
2666  if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2667  chunkSizes = llvm::to_vector(llvm::map_range(
2668  state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2669  if (op->getNumResults() != 1 ||
2670  !op->getResult(0).getType().isIndex()) {
2671  diag = emitSilenceableError()
2672  << "expected dynamic split point handle to point to a "
2673  "single-result index-typed op";
2674  diag.attachNote(op->getLoc()) << "dynamic split point";
2675  }
2676  return OpFoldResult(op->getResult(0));
2677  }));
2678  } else {
2679  chunkSizes = llvm::to_vector(
2680  llvm::map_range(state.getParams(getDynamicChunkSizes()),
2681  [](Attribute attr) { return OpFoldResult(attr); }));
2682  }
2683  if (diag.isSilenceableFailure())
2684  return diag;
2685 
2686  // For multiway split, a single payload is expected to have multiple
2687  // split points.
2688  if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2689  return emitDefiniteFailure()
2690  << "expected the dynamic split point handle to point to as "
2691  "many operations ("
2692  << chunkSizes.size() << ") as the target handle ("
2693  << payload.size() << ")";
2694  }
2695  } else {
2696  chunkSizes.resize(payload.size(),
2697  rewriter.getIndexAttr(getStaticChunkSizes()));
2698  }
2699 
2700  auto checkStructuredOpAndDimensions =
2701  [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2702  if (!linalgOp) {
2703  auto diag = emitSilenceableError() << "only applies to structured ops";
2704  diag.attachNote(loc) << "target op";
2705  return diag;
2706  }
2707 
2708  if (getDimension() >= linalgOp.getNumLoops()) {
2709  auto diag = emitSilenceableError() << "dimension " << getDimension()
2710  << " does not exist in target op";
2711  diag.attachNote(loc) << "target op";
2712  return diag;
2713  }
2715  };
2716 
2717  auto checkFailureInSplitting =
2718  [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2719  if (hasFailed) {
2720  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2721  diag.attachNote(loc) << "target op";
2722  return diag;
2723  }
2725  };
2726 
2727  SmallVector<Operation *> opList;
2728  if (isMultiwaySplit) {
2729 
2730  // Split a single target operation at multiple points.
2731  TilingInterface head, tail;
2732  Operation *target = payload.front();
2733 
2734  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2735 
2736  // Check that the target is a valid LinalgOp with correct dimensions.
2738  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2739  if (diag.isSilenceableFailure())
2740  return diag;
2741 
2742  for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2743 
2744  if (idx > 0)
2745  target = tail.getOperation();
2746 
2747  if (!target)
2748  break;
2749 
2750  linalgOp = cast<LinalgOp>(target);
2751  Location loc = target->getLoc();
2752 
2753  rewriter.setInsertionPoint(linalgOp);
2754  std::tie(head, tail) = linalg::splitOp(
2755  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2756  getDimension(), chunkSize);
2757 
2758  // Propagate errors.
2760  checkFailureInSplitting(!head && !tail, loc);
2761  if (diag.isDefiniteFailure())
2762  return diag;
2763 
2764  opList.push_back(head.getOperation());
2765  }
2766 
2767  // Append any leftover parts to the end of the result list.
2768  if (tail)
2769  opList.push_back(tail.getOperation());
2770 
2771  } else {
2772  // Split each target operation.
2773  SmallVector<Operation *> first, second;
2774  Operation *noSecondPart = nullptr;
2775  for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2776  Operation *target = std::get<0>(pair);
2777  Location loc = target->getLoc();
2778  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2780  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2781 
2782  if (diag.isSilenceableFailure())
2783  return diag;
2784 
2785  rewriter.setInsertionPoint(linalgOp);
2786  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2787  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2788  getDimension(), std::get<1>(pair));
2789 
2790  // Propagate errors.
2791  DiagnosedSilenceableFailure diagSplit =
2792  checkFailureInSplitting(!first.back() && !second.back(), loc);
2793  if (diagSplit.isDefiniteFailure())
2794  return diag;
2795 
2796  // Do not add null second parts.
2797  if (!second.back()) {
2798  noSecondPart = target;
2799  second.pop_back();
2800  }
2801  }
2802 
2803  if (second.size() != first.size() && !second.empty()) {
2804  auto diag = emitSilenceableError()
2805  << "splitting does not produce the second part for a subset "
2806  "of targets";
2807  diag.attachNote()
2808  << "expected splitting to produce the second part of all "
2809  "or none of the targets";
2810  diag.attachNote(noSecondPart->getLoc())
2811  << "first target with no second part";
2812  return diag;
2813  }
2814 
2815  opList.append(first);
2816  if (second.size())
2817  opList.append(second);
2818  }
2819  results.set(cast<OpResult>(getSplitList()), opList);
2821 }
2822 
2823 void SplitOp::getEffects(
2825  consumesHandle(getTargetMutable(), effects);
2826  if (getDynamicChunkSizes())
2827  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2828  producesHandle(getOperation()->getOpResults(), effects);
2829  modifiesPayload(effects);
2830 }
2831 
2832 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2833  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2834  IntegerAttr staticChunkSizes;
2835  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2836  return failure();
2837 
2838  OptionalParseResult dynamicPointParseResult =
2839  parser.parseOptionalOperand(dynamicChunkSizes);
2840  if (!dynamicPointParseResult.has_value()) {
2841  int64_t staticChunkSizesValue;
2842  if (failed(parser.parseInteger(staticChunkSizesValue)))
2843  return failure();
2844 
2845  staticChunkSizes =
2846  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2847  }
2848 
2849  Type targetType;
2850  if (parser.parseOptionalAttrDict(result.attributes) ||
2851  parser.parseColonType(targetType) ||
2852  parser.resolveOperand(target, targetType, result.operands)) {
2853  return failure();
2854  }
2855  if (dynamicPointParseResult.has_value()) {
2856  Type ChunkSizesType;
2857  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2858  parser.parseType(ChunkSizesType) ||
2859  parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
2860  result.operands)) {
2861  return failure();
2862  }
2863 
2864  staticChunkSizes =
2865  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2866  }
2867 
2868  result.addAttribute(
2869  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2870  staticChunkSizes);
2871  result.addTypes(targetType);
2872  return success();
2873 }
2874 
2875 void SplitOp::print(OpAsmPrinter &printer) {
2876  printer << " " << getTarget() << " after ";
2877  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2878  if (staticChunkSize != ShapedType::kDynamic)
2879  printer << staticChunkSize;
2880  else
2881  printer << getDynamicChunkSizes();
2882  printer << " ";
2883  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2884  {getStaticChunkSizesAttrName()});
2885  printer << " : " << getTarget().getType();
2886  if (staticChunkSize == ShapedType::kDynamic)
2887  printer << ", " << getDynamicChunkSizes().getType();
2888 }
2889 
2890 LogicalResult SplitOp::verify() {
2891  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2892  (getDynamicChunkSizes() == nullptr)) {
2893  return emitOpError() << "expects either a dynamic or a static split "
2894  "point to be provided";
2895  }
2896  return success();
2897 }
2898 
2899 //===----------------------------------------------------------------------===//
2900 // SplitReductionOp
2901 //===----------------------------------------------------------------------===//
2902 
2903 void transform::SplitReductionOp::build(
2904  OpBuilder &builder, OperationState &result, Value target,
2905  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2906  bool useScalingAlgorithm, bool useAlloc) {
2907  MLIRContext *ctx = builder.getContext();
2908  result.addOperands(target);
2909  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2910  builder.getI64IntegerAttr(splitFactor));
2911  result.addAttribute(
2912  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2913  builder.getI64IntegerAttr(insertSplitDimension));
2914  if (innerParallel) {
2915  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2916  builder.getUnitAttr());
2917  }
2918  if (useScalingAlgorithm) {
2919  result.addAttribute(
2920  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2921  builder.getUnitAttr());
2922  }
2923  if (useAlloc) {
2924  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2925  builder.getUnitAttr());
2926  }
2927  auto resultType = transform::AnyOpType::get(ctx);
2928  result.addTypes({resultType, resultType, resultType, resultType});
2929 }
2930 
2931 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2932  transform::TransformRewriter &rewriter, LinalgOp target,
2934  transform::TransformState &state) {
2935  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2936  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2937  unsigned(getInsertSplitDimension()),
2938  bool(getInnerParallel())};
2939  };
2940  rewriter.setInsertionPoint(target);
2941  FailureOr<SplitReductionResult> splitResult =
2942  (getUseScalingAlgorithm())
2943  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2944  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2945  if (failed(splitResult))
2946  return emitDefaultDefiniteFailure(target);
2947 
2948  results.push_back(splitResult->initOrAlloc);
2949  results.push_back(splitResult->fillOp);
2950  results.push_back(splitResult->splitLinalgOp);
2951  results.push_back(splitResult->resultCombiningLinalgOp);
2953 }
2954 
2955 //===----------------------------------------------------------------------===//
2956 // TileReductionUsingForOp
2957 //===----------------------------------------------------------------------===//
2958 
2959 void transform::TileReductionUsingForOp::build(
2960  OpBuilder &builder, OperationState &result, Value target,
2961  ArrayRef<int64_t> staticTileSizes) {
2962  // Call the default builder.
2963  // This is future-proof re mixed static-dynamic and setting up the proper
2964  // operands segment sizes attributes for multiple variadic operands.
2965  // In the absence of this, horrible bugs ensue.
2966  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2967  MLIRContext *ctx = builder.getContext();
2968  auto opTy = transform::AnyOpType::get(ctx);
2969  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2970  build(builder, result,
2971  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2972  /*target=*/target,
2973  /*tile_sizes=*/staticTileSizesAttr);
2974 }
2975 
2976 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2977  transform::TransformRewriter &rewriter, Operation *target,
2979  transform::TransformState &state) {
2980  rewriter.setInsertionPoint(target);
2981 
2982  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2983  if (!partialReductionOp) {
2984  return emitSilenceableFailure(
2985  target->getLoc(),
2986  "Operation should implement PartialReductionOpInterface");
2987  }
2988  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2989  rewriter, partialReductionOp,
2991 
2992  if (failed(result))
2993  return emitDefaultSilenceableFailure(target);
2994  rewriter.replaceOp(target, result->replacements);
2995  for (Value initValue : result->initialValues)
2996  results.push_back(initValue.getDefiningOp());
2997  for (auto parallelTiledOp : result->tiledOps)
2998  results.push_back(parallelTiledOp);
2999  for (auto mergeOp : result->mergeOps)
3000  results.push_back(mergeOp);
3001  results.push_back(result->loops.front());
3003 }
3004 
3005 //===----------------------------------------------------------------------===//
3006 // TileReductionUsingForallOp
3007 //===----------------------------------------------------------------------===//
3008 
3009 void transform::TileReductionUsingForallOp::build(
3010  OpBuilder &builder, OperationState &result, Value target,
3011  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
3012  ArrayAttr mapping) {
3013  // Call the default builder.
3014  // This is future-proof re mixed static-dynamic and setting up the proper
3015  // operands segment sizes attributes for multiple variadic operands.
3016  // In the absence of this, horrible bugs ensue.
3017  // TODO: support mixed static-dynamic (see TileUsingForallOp).
3018  MLIRContext *ctx = builder.getContext();
3019  auto opTy = transform::AnyOpType::get(ctx);
3020  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3021  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3022  build(builder, result,
3023  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
3024  /*target=*/target,
3025  /*num_threads=*/staticNumThreadsAttr,
3026  /*tile_sizes=*/staticTileSizesAttr,
3027  /*mapping=*/mapping);
3028 }
3029 
3030 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3031  transform::TransformRewriter &rewriter, LinalgOp target,
3033  transform::TransformState &state) {
3034  rewriter.setInsertionPoint(target);
3035  SmallVector<OpFoldResult> numThreads =
3036  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
3037  SmallVector<OpFoldResult> tileSizes =
3039  FailureOr<linalg::ForallReductionTilingResult> result =
3041  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
3042  numThreads, tileSizes, getMapping());
3043 
3044  if (failed(result)) {
3045  auto diag = emitSilenceableError() << "could not tile reduction";
3046  diag.attachNote(target.getLoc()) << "target operation";
3047  return diag;
3048  }
3049  for (Value initValue : result->initialValues)
3050  results.push_back(initValue.getDefiningOp());
3051  for (auto parallelTiledOp : result->parallelTiledOps)
3052  results.push_back(parallelTiledOp);
3053  for (auto mergeOp : result->mergeOps)
3054  results.push_back(mergeOp);
3055  results.push_back(result->loops);
3057 }
3058 
3059 //===----------------------------------------------------------------------===//
3060 // ContinuousTileSizesOp
3061 //===----------------------------------------------------------------------===//
3062 
3064 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
3065  TransformResults &transformResults,
3066  TransformState &state) {
3067 
3068  SmallVector<Operation *> targetOps =
3069  llvm::to_vector(state.getPayloadOps(getTarget()));
3070 
3071  if (!llvm::hasSingleElement(targetOps)) {
3072  return mlir::emitSilenceableFailure(getLoc())
3073  << "requires exactly one target (got " << llvm::range_size(targetOps)
3074  << ")";
3075  }
3076 
3077  Operation *target = *targetOps.begin();
3078  auto linalgOp = dyn_cast<LinalgOp>(target);
3079  auto tileableOp = dyn_cast<TilingInterface>(target);
3080 
3081  if (!linalgOp)
3082  return emitDefiniteFailure() << "expected Linalg Op";
3083 
3084  OpBuilder builder(linalgOp.getContext());
3085 
3086  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
3087  if (linalgOp.hasDynamicShape()) {
3088  auto diag = emitSilenceableError()
3089  << "cannot compute parametric tile sizes for dynamically "
3090  "shaped payload op";
3091  diag.attachNote(linalgOp->getLoc()) << "payload op";
3092  return diag;
3093  }
3094 
3095  FailureOr<StaticContinuousTileSizeSpecification> spec =
3096  computeStaticContinuousTileSizes(linalgOp, getDimension(),
3097  getTargetSize());
3098  if (failed(spec)) {
3099  return emitSilenceableError()
3100  << "failed to compute multi-size tiling sizes";
3101  }
3102 
3103  SmallVector<int64_t> chunkSizes;
3104 
3105  for (auto &&[tileSize, tripCount] :
3106  llvm::zip_equal(spec->tileSizes, spec->tripCounts))
3107  chunkSizes.push_back(tileSize * tripCount);
3108 
3109  auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
3110  return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
3111  return builder.getI64IntegerAttr(value);
3112  });
3113  };
3114  transformResults.setParams(cast<OpResult>(getTileSizes()),
3115  getI64AttrsFromI64(spec->tileSizes));
3116  transformResults.setParams(cast<OpResult>(getChunkSizes()),
3117  getI64AttrsFromI64(chunkSizes));
3118 
3120  }
3121 
3122  builder.setInsertionPoint(linalgOp);
3123 
3124  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
3125  unsigned dimension = getDimension();
3126 
3127  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
3128  builder, tileableOp, dimension, targetSize, true);
3129  if (failed(spec)) {
3130  return emitSilenceableError() << "could not generate tile size computation";
3131  }
3132 
3133  AffineExpr s0 = builder.getAffineSymbolExpr(0);
3134  AffineExpr s1 = builder.getAffineSymbolExpr(1);
3135  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
3136  return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
3137  ofrs);
3138  };
3139 
3140  SmallVector<Value> chunkSizes;
3141  Value splitPoint;
3142  for (auto &&[tileSize, tripCount] :
3143  llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
3144  splitPoint = apply(s0 * s1, {tileSize, tripCount});
3145  chunkSizes.push_back(splitPoint);
3146  }
3147 
3148  auto getDefiningOps = [&](ArrayRef<Value> values) {
3149  return llvm::map_to_vector(values, [&](Value value) -> Operation * {
3150  return value.getDefiningOp();
3151  });
3152  };
3153 
3154  transformResults.set(cast<OpResult>(getTileSizes()),
3155  getDefiningOps(spec->tileSizes));
3156  transformResults.set(cast<OpResult>(getChunkSizes()),
3157  getDefiningOps(chunkSizes));
3158 
3160 }
3161 
3163 
3164  if (getTileSizes().getType() != getChunkSizes().getType()) {
3165  return emitOpError() << "expects all results type to be the same";
3166  }
3167 
3168  return success();
3169 }
3170 
3171 void transform::ContinuousTileSizesOp::getEffects(
3173  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
3174  onlyReadsPayload(effects);
3175  else
3176  modifiesPayload(effects);
3177  onlyReadsHandle(getTargetMutable(), effects);
3178  producesHandle(getOperation()->getOpResults(), effects);
3179 }
3180 
3182  Type targetType, Type tile_sizes,
3183  Type) {
3184  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
3185 }
3186 
3187 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
3188  Type &targetType,
3189  Type &tileSizesType,
3190  Type &chunkSizesType) {
3191  FunctionType funcType;
3192  llvm::SMLoc typeLoc = parser.getCurrentLocation();
3193  if (failed(parser.parseType<FunctionType>(funcType)))
3194  return failure();
3195 
3196  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3197  parser.emitError(typeLoc) << "expects a trailing functional type with one "
3198  "argument and one result";
3199  }
3200  targetType = funcType.getInput(0);
3201  tileSizesType = chunkSizesType = funcType.getResult(0);
3202 
3203  return success();
3204 }
3205 
3206 //===----------------------------------------------------------------------===//
3207 // TileUsingForOp
3208 //===----------------------------------------------------------------------===//
3209 
3210 void transform::TileUsingForOp::build(
3211  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3212  Value target, ArrayRef<int64_t> staticTileSizes,
3213  ArrayRef<int64_t> interchange,
3214  std::optional<ArrayRef<bool>> scalableSizes) {
3215  return build(builder, result, loopTypes,
3216  /*target=*/target,
3217  /*mixedTileSizes=*/
3218  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3219  interchange, scalableSizes);
3220 }
3221 
3222 void transform::TileUsingForOp::build(
3223  OpBuilder &builder, OperationState &result, Value target,
3224  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3225  std::optional<ArrayRef<bool>> scalableSizes) {
3226  build(builder, result, target,
3227  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3228  interchange, scalableSizes);
3229 }
3230 
3231 void transform::TileUsingForOp::build(
3232  OpBuilder &builder, OperationState &result, Value target,
3233  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3234  std::optional<ArrayRef<bool>> scalableSizes) {
3235  // Loop types are automaticaly splat by the callee, setting up one is
3236  // enough.
3237  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3238  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3239  scalableSizes);
3240 }
3241 
3242 void transform::TileUsingForOp::build(
3243  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3244  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3245  ArrayRef<int64_t> interchange,
3246  std::optional<ArrayRef<bool>> scalableSizes) {
3247  SmallVector<int64_t> staticTileSizes;
3248  SmallVector<Value> dynamicTileSizes;
3249  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3250  // Call the default builder which sets up the proper operands segment sizes
3251  // attributes for multiple variadic operands. In the absence of this,
3252  // horrible bugs ensue.
3253  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3254  unsigned numExpectedLoops =
3255  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3256  SmallVector<Type> resultTypes;
3257  resultTypes.reserve(numExpectedLoops);
3258  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3259  "expected one loop type or as many as loops");
3260  if (loopTypes.size() == 1)
3261  resultTypes.append(numExpectedLoops, loopTypes[0]);
3262  else
3263  llvm::append_range(resultTypes, loopTypes);
3264  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3265  if (scalableSizes.has_value())
3266  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3267  build(builder, result, /*tiled_linalg_op=*/target.getType(),
3268  /*loops=*/resultTypes,
3269  /*target=*/target,
3270  /*dynamic_sizes=*/dynamicTileSizes,
3271  /*static_sizes=*/staticTileSizesAttr,
3272  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3273  /*scalable_sizes=*/expandedScalableSizes);
3274 }
3275 
3276 LogicalResult transform::TileUsingForOp::verify() {
3277  if (getMixedSizes().size() != getScalableSizes().size())
3278  return emitOpError("expected same number of sizes (")
3279  << getMixedSizes().size() << ") and scalable sizes ("
3280  << getScalableSizes().size() << ")";
3281  ArrayRef<int64_t> staticSizes = getStaticSizes();
3282  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3283  if (getLoops().size() != numExpectedLoops)
3284  return emitOpError("expected number of loops to tile (")
3285  << numExpectedLoops << ") to match number of `loops` results ("
3286  << getLoops().size() << ")";
3287  return success();
3288 }
3289 
3291 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3292  TransformResults &transformResults,
3293  TransformState &state) {
3294  ArrayRef<int64_t> tileSizes = getStaticSizes();
3295 
3296  SmallVector<Operation *> targets =
3297  llvm::to_vector(state.getPayloadOps(getTarget()));
3298  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3300  dynamicSizeProducers.reserve(getDynamicSizes().size());
3301  paramSizes.reserve(getDynamicSizes().size());
3302  for (Value transformValue : getDynamicSizes()) {
3303  if (isa<ParamType>(transformValue.getType())) {
3304  dynamicSizeProducers.push_back({});
3305  ArrayRef<Attribute> params = state.getParams(transformValue);
3306  paramSizes.push_back(
3307  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3308  return cast<IntegerAttr>(attr).getValue().getSExtValue();
3309  })));
3310 
3311  if (paramSizes.back().size() != targets.size()) {
3313  emitSilenceableError()
3314  << "expected as many parameter values ("
3315  << dynamicSizeProducers.back().size() << ") as target ops ("
3316  << targets.size() << ")";
3317  diag.attachNote(transformValue.getLoc()) << "for this parameter";
3318  return diag;
3319  }
3320 
3321  continue;
3322  }
3323  paramSizes.push_back({});
3324  dynamicSizeProducers.push_back(
3325  llvm::to_vector(state.getPayloadOps(transformValue)));
3326 
3327  if (dynamicSizeProducers.back().size() != targets.size()) {
3329  emitSilenceableError()
3330  << "expected as many dynamic size-producing operations ("
3331  << dynamicSizeProducers.back().size() << ") as target ops ("
3332  << targets.size() << ")";
3333  diag.attachNote(transformValue.getLoc()) << "for this handle";
3334  return diag;
3335  }
3336 
3337  for (Operation *op : dynamicSizeProducers.back()) {
3338  if (op->getNumResults() == 1 &&
3339  isa<IndexType>(op->getResult(0).getType())) {
3340  continue;
3341  }
3342 
3344  emitSilenceableError() << "expected sizes to be produced by ops "
3345  "with a single index-type result";
3346  diag.attachNote(op->getLoc()) << "size producer op";
3347  diag.attachNote(transformValue.getLoc()) << "for this handle";
3348  return diag;
3349  }
3350  }
3351 
3354  loops.resize(getLoops().size());
3355  auto scalableSizes = getScalableSizes();
3356  for (auto [i, op] : llvm::enumerate(targets)) {
3357  auto tilingInterface = dyn_cast<TilingInterface>(op);
3358  if (!tilingInterface) {
3360  emitSilenceableError()
3361  << "only ops implementing TilingInterface are supported";
3362  diag.attachNote(op->getLoc()) << "target op";
3363  return diag;
3364  }
3365  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3367  emitSilenceableError()
3368  << "too many tiles provided, expected at most "
3369  << tilingInterface.getLoopIteratorTypes().size() << " found "
3370  << tileSizes.size();
3371  diag.attachNote(op->getLoc()) << "target op";
3372  return diag;
3373  }
3374 
3375  scf::SCFTilingOptions tilingOptions;
3376  if (tileSizes.empty()) {
3377  tilingOptions.setTileSizeComputationFunction(
3379  return {};
3380  });
3381  } else {
3382  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3383  Operation *) {
3385  sizes.reserve(tileSizes.size());
3386  unsigned dynamicIdx = 0;
3387 
3388  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3389  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3390  if (scalableSizes[ofrIdx]) {
3391  auto val = b.create<arith::ConstantIndexOp>(
3392  getLoc(), cast<IntegerAttr>(attr).getInt());
3393  Value vscale =
3394  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
3395  sizes.push_back(
3396  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
3397  } else {
3398  sizes.push_back(attr);
3399  }
3400  continue;
3401  }
3402  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3403  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3404  ++dynamicIdx;
3405  assert((dynamicSizes.empty() ^ params.empty()) &&
3406  "expected either dynamic sizes or parameters");
3407  if (!params.empty()) {
3408  sizes.push_back(b.getIndexAttr(params[index]));
3409  } else {
3410  sizes.push_back(dynamicSizes[index]->getResult(0));
3411  }
3412  }
3413  return sizes;
3414  });
3415  }
3416 
3417  tilingOptions.setInterchange(getInterchange());
3418  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3419  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3420  if (failed(maybeTilingResult))
3422 
3423  rewriter.replaceOp(op, maybeTilingResult->replacements);
3424 
3425  tiled.append(maybeTilingResult->tiledOps);
3426  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3427  loops[en2.index()].push_back(en2.value());
3428  }
3429 
3430  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3431  for (const auto &en : llvm::enumerate(loops))
3432  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3433 
3435 }
3436 
3438  ValueRange dynamic = getDynamicSizes();
3439  ArrayRef<int64_t> tileSizes = getStaticSizes();
3440  SmallVector<OpFoldResult> results;
3441  results.reserve(tileSizes.size());
3442  unsigned dynamicPos = 0;
3443  Builder builder(getContext());
3444  for (int64_t size : tileSizes) {
3445  if (size == ShapedType::kDynamic) {
3446  results.push_back(dynamic[dynamicPos++]);
3447  } else {
3448  results.push_back(builder.getIndexAttr(size));
3449  }
3450  }
3451  return results;
3452 }
3453 
3454 void transform::TileUsingForOp::getEffects(
3456  consumesHandle(getTargetMutable(), effects);
3457  onlyReadsHandle(getDynamicSizesMutable(), effects);
3458  producesHandle(getOperation()->getOpResults(), effects);
3459  modifiesPayload(effects);
3460 }
3461 
3462 //===----------------------------------------------------------------------===//
3463 // TileUsingForallOp
3464 //===----------------------------------------------------------------------===//
3465 
3466 void transform::TileUsingForallOp::build(OpBuilder &builder,
3467  OperationState &result, Value target,
3468  ArrayRef<int64_t> staticTileSizes,
3470  ArrayAttr mapping) {
3471  return build(builder, result,
3472  /*target=*/target,
3473  /*mixedTileSizes=*/
3474  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3475  /*_=*/TileSizesSpec(),
3476  /*mapping=*/mapping);
3477 }
3478 
3479 void transform::TileUsingForallOp::build(OpBuilder &builder,
3480  OperationState &result, Value target,
3481  ArrayRef<OpFoldResult> mixedTileSizes,
3483  ArrayAttr mapping) {
3484  SmallVector<int64_t> staticTileSizes;
3485  SmallVector<Value> dynamicTileSizes;
3486  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3487  // Call the default builder which sets up the proper operands segment sizes
3488  // attributes for multiple variadic operands. In the absence of this,
3489  // horrible bugs ensue.
3490  MLIRContext *ctx = builder.getContext();
3491  auto operationType = transform::AnyOpType::get(ctx);
3492  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3493  build(builder, result,
3494  /*resultTypes=*/TypeRange{operationType, operationType},
3495  /*target=*/target,
3496  /*num_threads=*/ValueRange{},
3497  /*tile_sizes=*/dynamicTileSizes,
3498  /*packed_num_threads=*/Value(),
3499  /*packed_tile_sizes=*/Value(),
3500  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3501  /*static_tile_sizes=*/staticTileSizesAttr,
3502  /*mapping=*/mapping);
3503 }
3504 
3505 void transform::TileUsingForallOp::build(OpBuilder &builder,
3506  OperationState &result, Value target,
3507  ArrayRef<int64_t> staticNumThreads,
3509  ArrayAttr mapping) {
3510  return build(builder, result, target,
3511  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3512  NumThreadsSpec(), mapping);
3513 }
3514 
3515 void transform::TileUsingForallOp::build(OpBuilder &builder,
3516  OperationState &result, Value target,
3517  ArrayRef<OpFoldResult> mixedNumThreads,
3519  ArrayAttr mapping) {
3520  SmallVector<int64_t> staticNumThreads;
3521  SmallVector<Value> dynamicNumThreads;
3522  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3523  staticNumThreads);
3524  // Call the default builder which sets up the proper operands segment sizes
3525  // attributes for multiple variadic operands. In the absence of this,
3526  // horrible bugs ensue.
3527  MLIRContext *ctx = builder.getContext();
3528  auto operationType = transform::AnyOpType::get(ctx);
3529  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3530  build(builder, result,
3531  /*resultTypes=*/TypeRange{operationType, operationType},
3532  /*target=*/target,
3533  /*num_threads=*/dynamicNumThreads,
3534  /*tile_sizes=*/ValueRange{},
3535  /*packed_num_threads=*/Value(),
3536  /*packed_tile_sizes=*/Value(),
3537  /*static_num_threads=*/staticNumThreadsAttr,
3538  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3539  /*mapping=*/mapping);
3540 }
3541 
3542 /// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3543 /// normalized upper bound.
3547  ArrayRef<OpFoldResult> steps) {
3548  AffineExpr s0, s1, s2;
3549  bindSymbols(rewriter.getContext(), s0, s1, s2);
3550  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3551  SmallVector<OpFoldResult> normalizedUbs;
3552  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3554  rewriter, loc, normalizedUbExpr, {lb, ub, step});
3555  normalizedUbs.push_back(normalizedUb);
3556  }
3557  return normalizedUbs;
3558 }
3559 
3560 /// When a loop is normalized, the uses of the induction variable within the
3561 /// loop need to replaced with `original_lb + old_iv * original_step`.
3563  Location loc, ValueRange ivs,
3565  ArrayRef<OpFoldResult> steps) {
3566  AffineExpr s0, s1;
3567  AffineExpr d0;
3568  bindSymbols(rewriter.getContext(), s0, s1);
3569  bindDims(rewriter.getContext(), d0);
3570  AffineExpr denormExpr = s0 + d0 * s1;
3571  SmallVector<Value> denormalizedIvs;
3572 
3573  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3575  rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3576  denormalizedIvs.push_back(
3577  getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3578  }
3579  return denormalizedIvs;
3580 }
3581 
3582 /// Given a `scf.forall` loop return a loop op with the loop bounds
3583 /// normalized.
3584 /// TODO: Replace this with a general utility to normalize `scf.forall`.
3585 /// At the time of writing, this wasnt done since adding this to `scf`
3586 /// dialect would disallow using of `affine.apply` operations due
3587 /// to cyclic dependencies. To avoid churn in lit tests
3588 /// with the change this was added with, defer that to a follow up.
3589 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3590  scf::ForallOp loop) {
3591  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3592  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3593  SmallVector<OpFoldResult> steps = loop.getMixedStep();
3594 
3595  if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
3596  return loop;
3597  }
3598 
3599  Location loc = loop.getLoc();
3600  SmallVector<OpFoldResult> normalizedUbs =
3601  normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3602  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3603  rewriter.getIndexAttr(0));
3604  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3605  rewriter.getIndexAttr(1));
3606 
3607  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
3608  loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3609  loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
3610 
3611  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3612  OpBuilder::InsertionGuard g(rewriter);
3613  Block *normalizedLoopBlock = normalizedForallOp.getBody();
3614  rewriter.setInsertionPointToStart(normalizedLoopBlock);
3615 
3616  SmallVector<Value> argValues =
3617  denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3618  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3619  normalizedForallOp.getRegionIterArgs().end());
3620  Block *origLoopBlock = loop.getBody();
3621  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3622 
3623  rewriter.replaceOp(loop, normalizedForallOp);
3624  return normalizedForallOp;
3625 }
3626 
3628  RewriterBase &rewriter, transform::TransformState &state,
3629  TransformOpInterface transformOp, Operation *target,
3630  ArrayRef<OpFoldResult> mixedNumThreads,
3631  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3632  scf::SCFTilingResult &tilingResult) {
3633  // Transform all targets one by one.
3634  auto tileableOp = dyn_cast<TilingInterface>(target);
3635  if (!tileableOp) {
3637  transformOp.emitSilenceableError()
3638  << "only TilingInterface ops are supported";
3639  diag.attachNote(target->getLoc()) << "target op";
3640  return diag;
3641  }
3642  rewriter.setInsertionPoint(tileableOp);
3645  if (!mixedNumThreads.empty()) {
3646  options.setNumThreads(mixedNumThreads);
3647  } else {
3648  options.setTileSizes(mixedTileSizes);
3649  }
3650  if (mapping) {
3651  options.setMapping(mapping.value().getValue());
3652  }
3653  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3654  scf::tileUsingSCF(rewriter, tileableOp, options);
3655 
3656  if (failed(maybeTilingResult))
3657  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3658 
3659  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3660 
3661  tilingResult = *maybeTilingResult;
3662 
3663  if (mixedNumThreads.empty()) {
3664  auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3665  OpBuilder::InsertionGuard g(rewriter);
3666  rewriter.setInsertionPoint(generatedForallOp);
3667  scf::ForallOp normalizedForallOp =
3668  normalizeForallLoopOp(rewriter, generatedForallOp);
3669  tilingResult.loops.front() = normalizedForallOp;
3670  }
3671 
3673 }
3674 
3675 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3676  transform::TransformRewriter &rewriter,
3677  transform::TransformResults &transformResults,
3678  transform::TransformState &state) {
3679  auto transformOp = cast<TransformOpInterface>(getOperation());
3680 
3681  // Result payload ops.
3682  SmallVector<Operation *> tileOps;
3683  SmallVector<Operation *> tiledOps;
3684 
3685  // Unpack handles.
3686  SmallVector<OpFoldResult> mixedNumThreads;
3688  getPackedNumThreads()
3690  state, transformOp, mixedNumThreads, getPackedNumThreads())
3692  state, transformOp, mixedNumThreads, getMixedNumThreads());
3693  if (!status.succeeded())
3694  return status;
3695  SmallVector<OpFoldResult> mixedTileSizes;
3696  status = getPackedTileSizes()
3698  state, transformOp, mixedTileSizes, getPackedTileSizes())
3700  state, transformOp, mixedTileSizes, getMixedTileSizes());
3701  if (!status.succeeded())
3702  return status;
3703 
3704  for (Operation *target : state.getPayloadOps(getTarget())) {
3705  scf::SCFTilingResult tilingResult;
3707  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3708  getMapping(), tilingResult);
3709  if (!diag.succeeded())
3710  return diag;
3711  tileOps.push_back(tilingResult.loops.front());
3712  tiledOps.append(tilingResult.tiledOps);
3713  }
3714 
3715  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3716  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3717 
3719 }
3720 
3721 void transform::TileUsingForallOp::getEffects(
3723  consumesHandle(getTargetMutable(), effects);
3724  onlyReadsHandle(getTileSizesMutable(), effects);
3725  onlyReadsHandle(getNumThreadsMutable(), effects);
3726  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3727  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3728  producesHandle(getOperation()->getOpResults(), effects);
3729  modifiesPayload(effects);
3730 }
3731 
3732 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3733  Builder b(getContext());
3734  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3735 }
3736 
3737 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3738  Builder b(getContext());
3739  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3740 }
3741 
3742 LogicalResult TileUsingForallOp::verify() {
3743  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3744  static_cast<int>(getPackedNumThreads() != Value());
3745  if (numThreadsSpec > 1)
3746  return emitOpError(
3747  "num_threads and packed_num_threads are mutually exclusive");
3748  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3749  static_cast<int>(getPackedTileSizes() != Value());
3750  if (tileSizesSpec > 1)
3751  return emitOpError(
3752  "tile_sizes and packed_tile_sizes are mutually exclusive");
3753  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3754  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3755  "must be specified");
3756  return success();
3757 }
3758 
3759 //===----------------------------------------------------------------------===//
3760 // VectorizeChildrenAndApplyPatternsOp
3761 //===----------------------------------------------------------------------===//
3762 
3763 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3764  OpBuilder &builder, OperationState &result, Value target,
3765  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3766  result.addOperands(target);
3767  if (vectorizePadding) {
3768  result.addAttribute(
3769  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3770  result.name),
3771  builder.getUnitAttr());
3772  }
3773  if (vectorizeExtract) {
3774  result.addAttribute(
3775  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3776  result.name),
3777  builder.getUnitAttr());
3778  }
3779  if (flatten1DDepthwiseConv) {
3780  result.addAttribute(
3781  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3782  result.name),
3783  builder.getUnitAttr());
3784  }
3785  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3786 }
3787 
3788 namespace {
3789 /// This is an helper only to call vectorize via a pattern inside of
3790 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3791 struct VectorizationPattern : public RewritePattern {
3792  explicit VectorizationPattern(MLIRContext *context,
3793  bool vectorizeExtract = false,
3794  bool flattenConv = false)
3795  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3796  vectorizeNDExtract(vectorizeExtract),
3797  flatten1DDepthwiseConv(flattenConv) {}
3798  LogicalResult matchAndRewrite(Operation *op,
3799  PatternRewriter &rewriter) const override {
3801  return rewriter.notifyMatchFailure(op,
3802  "Unsupported Op, cannot vectorize");
3803  return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3804  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3805  flatten1DDepthwiseConv);
3806  }
3807 
3808 private:
3809  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3810  /// rank >= 2.
3811  bool vectorizeNDExtract = false;
3812  /// Controls whether to "flatten" the channel dimension when vectorising 1D
3813  /// depthwise convolutions. This should lead to bette vectorization for
3814  /// tensors with a low number of channel dimensions.
3815  bool flatten1DDepthwiseConv = false;
3816 };
3817 } // namespace
3818 
3820 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3821  transform::TransformRewriter &rewriter, Operation *target,
3823  transform::TransformState &state) {
3824  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3825  auto diag = this->emitOpError("requires isolated-from-above targets");
3826  diag.attachNote(target->getLoc()) << "non-isolated target";
3828  }
3829 
3830  MLIRContext *ctx = getContext();
3832  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3833  getFlatten_1dDepthwiseConv());
3834 
3835  if (!getDisableTransferPermutationMapLoweringPatterns())
3837 
3838  if (!getDisableMultiReductionToContractPatterns())
3840 
3842 
3845  /*benefit=*/2);
3846  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3847  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3849 
3851 
3852  if (getVectorizePadding()) {
3854  // This creates an alternative path for lowering tensor.pad - by
3855  // decomposing it into e.g. linalg.fill.
3857  }
3859 
3860  TrackingListener listener(state, *this);
3861  if (failed(
3862  applyPatternsGreedily(target, std::move(patterns),
3863  GreedyRewriteConfig().setListener(&listener))))
3864  return emitDefaultDefiniteFailure(target);
3865 
3866  results.push_back(target);
3868 }
3869 
3870 //===----------------------------------------------------------------------===//
3871 // VectorizeOp
3872 //===----------------------------------------------------------------------===//
3873 
3874 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3875  transform::TransformRewriter &rewriter,
3876  mlir::transform::TransformResults &transformResults,
3878  auto targets = state.getPayloadOps(getTarget());
3879  if (std::empty(targets))
3881  auto transformOp = cast<TransformOpInterface>(getOperation());
3882  SmallVector<int64_t> vectorSizes;
3884  state, transformOp, getMixedVectorSizes(), vectorSizes);
3885  if (!status.succeeded())
3886  return status;
3887 
3888  // TODO: Check that the correct number of vectorSizes was provided.
3889  for (Operation *target : targets) {
3890  if (!linalg::hasVectorizationImpl(target)) {
3891  return mlir::emitSilenceableFailure(target->getLoc())
3892  << "Unsupported Op, cannot vectorize";
3893  }
3894 
3895  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3896  getScalableSizes(),
3897  getVectorizeNdExtract().value_or(false)))) {
3898  return mlir::emitSilenceableFailure(target->getLoc())
3899  << "Attempted to vectorize, but failed";
3900  }
3901  }
3902 
3904 }
3905 
3906 void transform::VectorizeOp::getEffects(
3908  consumesHandle(getTargetMutable(), effects);
3909  onlyReadsHandle(getVectorSizesMutable(), effects);
3910  modifiesPayload(effects);
3911 }
3912 
3913 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3914  OpBuilder b(getContext());
3915  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3916 }
3917 
3918 LogicalResult transform::VectorizeOp::verify() {
3919  if (getStaticVectorSizes().size() != getScalableSizes().size())
3920  return emitOpError("expected same number of vector sizes (")
3921  << getStaticVectorSizes().size() << ") and scalable sizes ("
3922  << getScalableSizes().size() << ")";
3923  return success();
3924 }
3925 
3926 //===----------------------------------------------------------------------===//
3927 // HoistRedundantVectorTransfersOp
3928 //===----------------------------------------------------------------------===//
3929 
3931 transform::HoistRedundantVectorTransfersOp::applyToOne(
3932  transform::TransformRewriter &rewriter, func::FuncOp target,
3934  transform::TransformState &state) {
3935  // WARNING: This hoisting does not model parallelism and is generally
3936  // incorrect when used on distributed loops with memref semantics!
3937  // TODO: obsolete and should be retired.
3938  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3939  results.push_back(target);
3941 }
3942 
3943 //===----------------------------------------------------------------------===//
3944 // HoistRedundantVectorBroadcastsOp
3945 //===----------------------------------------------------------------------===//
3946 
3948 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3949  transform::TransformRewriter &rewriter, mlir::Operation *target,
3951  transform::TransformState &state) {
3952  rewriter.setInsertionPoint(target);
3953  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3954  results.push_back(target);
3956 }
3957 
3958 //===----------------------------------------------------------------------===//
3959 // ConvertConv2DToImg2ColOp.
3960 //===----------------------------------------------------------------------===//
3961 
3962 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3963  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3965  transform::TransformState &state) {
3966  rewriter.setInsertionPoint(target);
3967  auto maybeTransformed =
3969  target)
3970  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3971  return rewriteInIm2Col(rewriter, op);
3972  })
3973  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3974  return rewriteInIm2Col(rewriter, op);
3975  })
3976  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3977  return rewriteInIm2Col(rewriter, op);
3978  })
3979  .Case([&](linalg::Conv2DNchwFchwOp op) {
3980  return rewriteInIm2Col(rewriter, op);
3981  })
3982  .Default([&](Operation *op) {
3983  return rewriter.notifyMatchFailure(op, "not supported");
3984  });
3985  if (failed(maybeTransformed))
3986  return emitDefaultSilenceableFailure(target);
3987  // Handle to the operation producing the img2col tensor.
3988  results.push_back(maybeTransformed->first);
3989  // Handle to the operation that replaces the original convolution.
3990  results.push_back(maybeTransformed->second);
3992 }
3993 
3994 //===----------------------------------------------------------------------===//
3995 // FlattenElementwiseLinalgOp.
3996 //===----------------------------------------------------------------------===//
3997 
3998 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3999  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4001  transform::TransformState &state) {
4002  rewriter.setInsertionPoint(target);
4003  if (!isElementwise(target))
4004  return mlir::emitSilenceableFailure(target->getLoc())
4005  << "only elementwise flattening is supported";
4006 
4007  // If rank <= 1, do nothing
4008  if (target.getNumLoops() <= 1) {
4009  results.push_back(target);
4011  }
4012 
4013  // Attempt to flatten all dims to one.
4014  ReassociationIndices reassociation(target.getNumLoops());
4015  std::iota(reassociation.begin(), reassociation.end(), 0);
4016  auto maybeFlattened =
4017  collapseOpIterationDims(target, reassociation, rewriter);
4018  if (failed(maybeFlattened))
4019  return mlir::emitSilenceableFailure(target->getLoc())
4020  << "attempted to flatten, but failed";
4021  results.push_back(maybeFlattened->collapsedOp);
4022  rewriter.replaceOp(target, maybeFlattened->results);
4024 }
4025 
4026 //===----------------------------------------------------------------------===//
4027 // TransposeConv2DOp
4028 //===----------------------------------------------------------------------===//
4029 
4030 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
4031  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4033  transform::TransformState &state) {
4034  rewriter.setInsertionPoint(target);
4035  auto maybeTransformed =
4037  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4038  return transposeConv2D(rewriter, op);
4039  })
4040  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
4041  return transposeConv2D(rewriter, op);
4042  })
4043  .Default([&](Operation *op) {
4044  return rewriter.notifyMatchFailure(op, "not supported");
4045  });
4046  if (failed(maybeTransformed))
4047  return emitDefaultSilenceableFailure(target);
4048  // Handle to the new Conv2D operation with transposed filters
4049  results.push_back(*maybeTransformed);
4051 }
4052 
4053 //===----------------------------------------------------------------------===//
4054 // TransposeMatmulOp
4055 //===----------------------------------------------------------------------===//
4056 
4057 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
4058  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4060  transform::TransformState &state) {
4061  rewriter.setInsertionPoint(target);
4062  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
4063  auto maybeTransformed =
4065  .Case([&](linalg::MatmulOp op) {
4066  return transposeMatmul(rewriter, op, transposeLHS);
4067  })
4068  .Case([&](linalg::BatchMatmulOp op) {
4069  return transposeBatchMatmul(rewriter, op, transposeLHS);
4070  })
4071  .Default([&](Operation *op) { return failure(); });
4072  if (failed(maybeTransformed))
4073  return emitSilenceableFailure(target->getLoc()) << "not supported";
4074  // Handle to the new Matmul operation with transposed filters
4075  results.push_back(*maybeTransformed);
4077 }
4078 
4079 //===----------------------------------------------------------------------===//
4080 // InsertSliceToCopyOp
4081 //===----------------------------------------------------------------------===//
4082 template <typename OpTy>
4085  transform::TransformState &state) {
4086  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
4087  tensor::ParallelInsertSliceOp>() &&
4088  "wrong op type");
4089 
4090  if (auto copySource =
4091  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
4092  results.push_back(copySource);
4094  }
4095 
4096  // If we are inside an InParallel region, temporarily set the insertion point
4097  // outside: only tensor.parallel_insert_slice ops are allowed in there.
4098  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
4099  rewriter.setInsertionPoint(
4100  target->template getParentOfType<scf::InParallelOp>());
4101  }
4102 
4103  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
4104  target.getLoc(), target.getDest(), target.getMixedOffsets(),
4105  target.getMixedSizes(), target.getMixedStrides());
4106  Value copied = rewriter
4107  .create<linalg::CopyOp>(target.getLoc(),
4108  target.getSource(), extracted)
4109  .getResult(0);
4110  // Reset the insertion point.
4111  rewriter.setInsertionPoint(target);
4112  rewriter.replaceOpWithNewOp<OpTy>(
4113  target, copied, target.getDest(), target.getMixedOffsets(),
4114  target.getMixedSizes(), target.getMixedStrides());
4115 
4116  results.push_back(copied.getDefiningOp());
4118 }
4119 
4120 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
4121  transform::TransformRewriter &rewriter, Operation *targetOp,
4123  transform::TransformState &state) {
4124 
4125  rewriter.setInsertionPoint(targetOp);
4126  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
4127  return doit(rewriter, target, results, state);
4128  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
4129  return doit(rewriter, target, results, state);
4130 
4132  emitSilenceableError()
4133  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
4134  diag.attachNote(targetOp->getLoc()) << "target op";
4135  return diag;
4136 }
4137 
4138 //===----------------------------------------------------------------------===//
4139 // MapCopyToThreadsOp
4140 //===----------------------------------------------------------------------===//
4141 
4142 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
4143  transform::TransformRewriter &rewriter, Operation *target,
4145  transform::TransformState &state) {
4146  // Check if the op is supported.
4147  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
4149  emitSilenceableError()
4150  << "only linalg.copy and tensor.pad target ops are supported";
4151  diag.attachNote(target->getLoc()) << "target op";
4152  return diag;
4153  }
4154  assert(target->getNumResults() == 1 && "expected single result");
4155  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
4156  if (!resultShapedType.hasStaticShape()) {
4158  emitSilenceableError()
4159  << "only statically sized ops of rank <= 3 are supported";
4160  diag.attachNote(target->getLoc()) << "target op";
4161  return diag;
4162  }
4163 
4164  // Conservatively set the minimum viable desired bitwidth alignment.
4165  int64_t desiredBitAlignment = getDesiredBitAlignment();
4166  int64_t eltBitwidth =
4167  resultShapedType.getElementType().getIntOrFloatBitWidth();
4168  if (desiredBitAlignment % eltBitwidth != 0) {
4169  desiredBitAlignment = eltBitwidth;
4170  }
4171 
4172  gpu::CopyMappingInfo mapping(
4173  /*ctx=*/getContext(),
4174  /*totalNumThreads=*/getTotalNumThreads(),
4175  /*alignment=*/desiredBitAlignment,
4176  /*sizes=*/resultShapedType.getShape(),
4177  /*favorPredication=*/false,
4178  /*elementalBitwidth=*/
4179  resultShapedType.getElementType().getIntOrFloatBitWidth());
4180  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
4182  emitSilenceableError()
4183  << "too few threads to map copy op to threads on the most minor "
4184  "dimension, given alignment and vector size constraints, try "
4185  "smaller tile size of mapping to more threads";
4186  diag.attachNote(target->getLoc()) << "target op";
4187  return diag;
4188  }
4189 
4190  // OpBuilder only used to compute attributes.
4191  OpBuilder b(getContext());
4192  scf::SCFTilingResult tilingResult;
4194  /*rewriter=*/rewriter,
4195  /*state=*/state,
4196  /*transformOp=*/*this,
4197  /*target=*/target,
4198  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4199  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4200  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4201  /*tilingResult=*/tilingResult);
4202  if (!diag.succeeded())
4203  return diag;
4204 
4205  results.push_back(tilingResult.loops.front());
4206  for (auto op : tilingResult.tiledOps)
4207  results.push_back(op);
4209 }
4210 
4211 //===----------------------------------------------------------------------===//
4212 // WinogradConv2DOp
4213 //===----------------------------------------------------------------------===//
4214 
4215 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4216  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4218  transform::TransformState &state) {
4219  rewriter.setInsertionPoint(target);
4220  FailureOr<Operation *> maybeTransformed = failure();
4221  bool supported = TypeSwitch<Operation *, bool>(target)
4222  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4223  maybeTransformed =
4224  winogradConv2D(rewriter, op, getM(), getR());
4225  return true;
4226  })
4227  .Default([&](Operation *op) { return false; });
4228 
4229  if (!supported) {
4230  return emitSilenceableError()
4231  << "this operation is not supported to convert to Winograd Conv2D";
4232  }
4233 
4234  if (failed(maybeTransformed)) {
4235  return emitSilenceableError() << "apply Winograd Conv2D failed";
4236  }
4237 
4238  results.push_back(*maybeTransformed);
4240 }
4241 
4242 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4243  transform::TransformRewriter &rewriter, Operation *target,
4245  transform::TransformState &state) {
4246  rewriter.setInsertionPoint(target);
4247  FailureOr<Operation *> maybeTransformed = failure();
4248  bool supported =
4250  .Case([&](linalg::WinogradFilterTransformOp op) {
4251  maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4252  return true;
4253  })
4254  .Case([&](linalg::WinogradInputTransformOp op) {
4255  maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4256  return true;
4257  })
4258  .Case([&](linalg::WinogradOutputTransformOp op) {
4259  maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4260  return true;
4261  })
4262  .Default([&](Operation *op) { return false; });
4263 
4264  if (!supported) {
4266  emitSilenceableError()
4267  << "this operation is not supported to decompose into other operations";
4268  diag.attachNote(target->getLoc()) << "target op";
4269  return diag;
4270  }
4271 
4272  if (failed(maybeTransformed)) {
4274  emitSilenceableError() << "decompose Winograd operations failed";
4275  diag.attachNote(target->getLoc()) << "target op";
4276  return diag;
4277  }
4278 
4279  results.push_back(*maybeTransformed);
4281 }
4282 
4283 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4284 
4285 #define GET_OP_CLASSES
4286 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
Definition: AMXDialect.cpp:70
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
#define DOWNSCALE(trans)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to be replaced with o...
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static bool sameOrEquivalentIterArg(Value src, Value dst)
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
#define DBGS()
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:106
UnitAttr getUnitAttr()
Definition: Builders.cpp:96
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:165
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:366
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:110
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:279
IndexType getIndexType()
Definition: Builders.cpp:53
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:304
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
This class represents a saved insertion point.
Definition: Builders.h:325
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:551
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:601
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1271
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1175
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1225
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions options.paddingDimensions of all opToPad operands to a static bounding bo...
Definition: Padding.cpp:244
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:469
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:678
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:260
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:359
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:861
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:510
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:494
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:485
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:398
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
Definition: Promotion.cpp:502
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:223
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:50
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:243
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:163
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:112
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:97
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition: Tiling.cpp:595
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:769
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:481
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:420
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:489
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
Definition: Promotion.cpp:478
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:202
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:269
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:224
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:604
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:64
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:117
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:23
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:421
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:422
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:520
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1595
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1527
Match and rewrite for the pattern:
Definition: Transforms.h:1717
Match and rewrite for the pattern:
Definition: Transforms.h:1745
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:426
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:432
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:445
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:465
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:415
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:439
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:455
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:404
Split Reduction options.
Definition: Transforms.h:474
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.