1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
41 #include "mlir/Support/LLVM.h"
42 #include "mlir/Support/TypeID.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
49 
50 using namespace mlir;
51 using namespace mlir::linalg;
52 using namespace mlir::transform;
53 
54 #define DEBUG_TYPE "linalg-transforms"
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56 #define DBGSNL() (llvm::dbgs() << "\n")
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
58 
59 /// Attempts to apply the pattern specified as template argument to the given
60 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
61 /// function that returns the "main" result or failure. Returns failure if the
62 /// pattern failed to apply. Extra arguments are forwarded to the pattern
63 /// constructor.
64 template <typename PatternTy, typename... Args>
65 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
66  // Check if the given operation has the type expected by the pattern.
67  using OpTy = typename llvm::function_traits<
68  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69  auto op = dyn_cast<OpTy>(operation);
70  if (!op)
71  return failure();
72 
73  // Apply the pattern directly to the op.
74  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
75  // We want to discourage direct use of PatternRewriter in APIs but in this
76  // very specific case, an IRRewriter is not enough.
77  struct TrivialPatternRewriter : public PatternRewriter {
78  public:
79  explicit TrivialPatternRewriter(MLIRContext *context)
80  : PatternRewriter(context) {}
81  };
82  TrivialPatternRewriter rewriter(operation->getContext());
83  rewriter.setInsertionPoint(operation);
84  auto result = pattern.returningMatchAndRewrite(op, rewriter);
85  if (failed(result))
86  return failure();
87  return cast<LinalgOp>(result->getOperation());
88 }
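
// Illustrative sketch (not part of the upstream file): for any pattern class
// `SomePattern` (hypothetical name) that exposes
// `FailureOr<LinalgOp> returningMatchAndRewrite(OpTy, PatternRewriter &)`,
// a caller holding a payload `Operation *op` could write:
//
//   FailureOr<LinalgOp> res = tryApply<SomePattern>(op /*, extra ctor args*/);
//   if (failed(res)) {
//     // `op` did not have the type expected by the pattern, or the pattern
//     // did not apply.
//   }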
89 
90 /// Assuming that `ofr` is an index attr or a param of index type
91 /// or a transform dialect handle mapped to exactly one op
92 /// with one index result, return that value.
93 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
94  transform::TransformState &state, TransformOpInterface transformOp,
95  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
96  for (OpFoldResult ofr : ofrs) {
97  if (auto attr = dyn_cast<Attribute>(ofr)) {
98  if (!isa<IntegerAttr>(attr))
99  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
100  result.push_back(ofr);
101  continue;
102  }
103 
104  Value transformValue = cast<Value>(ofr);
105  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
106  ArrayRef<Attribute> params = state.getParams(transformValue);
107  if (params.size() != 1)
108  return transformOp.emitDefiniteFailure()
109  << "requires exactly one parameter associated";
110  result.push_back(params[0]);
111  continue;
112  }
113 
114  auto payloadOps = state.getPayloadOps(transformValue);
115  if (!llvm::hasSingleElement(payloadOps)) {
116  DiagnosedSilenceableFailure diag =
117  transformOp.emitSilenceableError()
118  << "handle must be mapped to exactly one payload op";
119  diag.attachNote(transformValue.getLoc())
120  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
121  return diag;
122  }
123 
124  Operation *op = *payloadOps.begin();
125  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
126  DiagnosedSilenceableFailure diag =
127  transformOp.emitSilenceableError()
128  << "payload op must have exactly 1 index result";
129  diag.attachNote(op->getLoc())
130  << "has " << op->getNumResults() << " results";
131  return diag;
132  }
133  result.push_back(op->getResult(0));
134  }
135 
136  return DiagnosedSilenceableFailure::success();
137 }
138 
139 // Given a list of params that are index attrs or a list of OpFoldResults
140 // that are either index attrs or op handles, return a list of OpFoldResults
141 // of index attrs or a list of OpFoldResults where all op handles are
142 // replaced with the first (and only) OpResult of that payload op.
143 // (There must be exactly one parameter associated with the AnyParamType or
144 // one mapped payload op which must have exactly one index result.)
145 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
146  transform::TransformState &state, TransformOpInterface transformOp,
147  SmallVector<OpFoldResult> &result, Value packedHandle) {
148  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
149  ArrayRef<Attribute> params = state.getParams(packedHandle);
150  for (auto param : params) {
151  if (!isa<IntegerAttr>(param))
152  return transformOp.emitDefiniteFailure()
153  << "expected the parameter to be associated with an integer "
154  "attribute";
155  result.push_back(param);
156  }
157  return DiagnosedSilenceableFailure::success();
158  }
159 
160  for (Operation *op : state.getPayloadOps(packedHandle)) {
161  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
162  DiagnosedSilenceableFailure diag =
163  transformOp.emitSilenceableError()
164  << "payload op must have exactly 1 index result";
165  diag.attachNote(op->getLoc())
166  << "has " << op->getNumResults() << " results";
167  return diag;
168  }
169  result.push_back(op->getResult(0));
170  }
171 
172  return DiagnosedSilenceableFailure::success();
173 }
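
// Illustrative sketch (hypothetical caller, not part of the upstream file):
// a transform op with a mixed static/dynamic `sizes` operand could expand it
// as follows, where `getMixedSizes()` is an assumed accessor returning
// OpFoldResults:
//
//   SmallVector<OpFoldResult> unpacked;
//   DiagnosedSilenceableFailure status =
//       unpackSingleIndexResultPayloadOperations(state, *this, unpacked,
//                                                getMixedSizes());
//   if (!status.succeeded())
//     return status;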
174 
175 /// When possible, converts each `OpFoldResult` in `mixedResults` to
176 /// an integer if the value can be statically inferred. If a result
177 /// is a `Value` then it must be either a `ParamType` or a handle
178 /// to a constant-like op.
179 static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
180  TransformState &state, TransformOpInterface &transformOp,
181  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
182  for (OpFoldResult paramOrHandle : mixedResults) {
183  if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
184  reified.push_back(cast<IntegerAttr>(attr).getInt());
185  continue;
186  } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
187  ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
188  if (params.size() != 1)
189  return transformOp.emitSilenceableError() << "expected a single param";
190  reified.push_back(
191  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
192  continue;
193  }
194 
195  Value handle = cast<Value>(paramOrHandle);
196  if (!isa<TransformHandleTypeInterface>(handle.getType()))
197  return transformOp.emitSilenceableError() << "unexpected value handle";
198  auto payload = state.getPayloadOps(handle);
199  if (!llvm::hasSingleElement(payload))
200  return transformOp.emitSilenceableError()
201  << "requires param or handle that is mapped to 1 payload op";
202 
203  Operation *paramOrHandlePayloadOp = *payload.begin();
204  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
205  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
206  return transformOp.emitSilenceableError()
207  << "requires param or handle to be result of op with 1 index "
208  "result";
209  }
210 
211  IntegerAttr attr;
212  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
213  return transformOp.emitSilenceableError()
214  << "requires param or handle to be the result of a constant like "
215  "op";
216 
217  reified.push_back(attr.getInt());
218  }
219  return DiagnosedSilenceableFailure::success();
220 }
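
// Illustrative sketch (hypothetical caller, not part of the upstream file):
// reifying a mixed size list down to plain integers, assuming `mixedSizes`
// holds only index attributes, single-value params, or handles to
// constant-like ops:
//
//   SmallVector<int64_t> staticSizes;
//   DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
//       state, transformOp, mixedSizes, staticSizes);
//   if (!status.succeeded())
//     return status; // Silenceable: a size was not statically known.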
221 
222 //===----------------------------------------------------------------------===//
223 // Apply...PatternsOp
224 //===----------------------------------------------------------------------===//
225 
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
229 }
230 
231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
234 }
235 
236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
239 }
240 
241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
245 }
246 
247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
250  options.rankReductionStrategy =
253 }
254 
255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
258 }
259 
260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
263 }
264 
265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
268 }
269 
270 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
273 }
274 
275 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // BufferizeToAllocationOp
282 //===----------------------------------------------------------------------===//
283 
284 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
285  OperationState &result,
286  Value target,
287  Attribute memorySpace) {
288  SmallVector<Type> resultTypes;
289  resultTypes.push_back(b.getType<transform::AnyValueType>());
290  resultTypes.push_back(b.getType<transform::AnyOpType>());
291  return build(b, result,
292  /*resultTypes=*/resultTypes,
293  /*target=*/target,
294  /*memorySpace=*/memorySpace);
295 }
296 
297 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
298  OperationState &result,
299  Value target,
300  int64_t memorySpace) {
301  SmallVector<Type> resultTypes;
302  resultTypes.push_back(b.getType<transform::AnyValueType>());
303  resultTypes.push_back(b.getType<transform::AnyOpType>());
304  return build(b, result,
305  /*resultTypes=*/resultTypes,
306  /*target=*/target,
307  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
308 }
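
// Illustrative sketch (not part of the upstream file): constructing the op
// via the integer-memory-space builder above, assuming a valid `OpBuilder b`,
// a location `loc`, and a transform op handle `targetHandle`:
//
//   auto bufferize = b.create<transform::BufferizeToAllocationOp>(
//       loc, targetHandle, /*memorySpace=*/3);
//   Value allocated = bufferize.getAllocatedBuffer();
//   Value newOps = bufferize.getNewOps();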
309 
310 namespace {
311 class NewOpsListener : public RewriterBase::ForwardingListener {
312 public:
313  using RewriterBase::ForwardingListener::ForwardingListener;
314 
315  SmallVector<Operation *> getNewOps() const {
316  return SmallVector<Operation *>(newOps.begin(), newOps.end());
317  }
318 
319 private:
320  void notifyOperationInserted(Operation *op,
321  OpBuilder::InsertPoint previous) override {
322  ForwardingListener::notifyOperationInserted(op, previous);
323  // We only care about newly created ops.
324  if (previous.isSet())
325  return;
326  auto inserted = newOps.insert(op);
327  (void)inserted;
328  assert(inserted.second && "expected newly created op");
329  }
330 
331  void notifyOperationErased(Operation *op) override {
332  ForwardingListener::notifyOperationErased(op);
333  op->walk([&](Operation *op) { newOps.erase(op); });
334  }
335 
336  DenseSet<Operation *> newOps;
337 };
338 } // namespace
339 
340 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
341  transform::TransformRewriter &rewriter, transform::TransformResults &results,
342  transform::TransformState &state) {
343  // Attach listener to keep track of newly created ops.
344  OpBuilder::Listener *previousListener = rewriter.getListener();
345  auto resetListener =
346  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
347  NewOpsListener newOpsListener(previousListener);
348  rewriter.setListener(&newOpsListener);
349 
350  linalg::BufferizeToAllocationOptions options;
351  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
352  options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
353  MaterializeInDestination;
354  } else if (getMemcpyOp() == "memref.copy") {
355  options.memcpyOp =
356  linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
357  } else if (getMemcpyOp() == "linalg.copy") {
358  options.memcpyOp =
359  linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
360  } else {
361  llvm_unreachable("invalid memcpy op");
362  }
363  if (getAllocOp() == "memref.alloc") {
364  options.allocOp =
365  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
366  } else if (getAllocOp() == "memref.alloca") {
367  options.allocOp =
368  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
369  } else {
370  llvm_unreachable("invalid alloc op");
371  }
372  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
373  options.emitDealloc = getEmitDealloc();
374 
375  // Bufferize ops.
376  Attribute memorySpace =
377  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
378  SmallVector<Value> allocatedBuffers;
379  for (Operation *op : state.getPayloadOps(getTarget())) {
380  Value buffer =
381  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
382  if (!buffer) {
383  DiagnosedSilenceableFailure diag = emitSilenceableError()
384  << "failed to bufferize operation";
385  diag.attachNote(op->getLoc()) << "target payload op";
386  return diag;
387  }
388  allocatedBuffers.push_back(buffer);
389  }
390 
391  // Set results.
392  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
393  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
394  return DiagnosedSilenceableFailure::success();
395 }
396 
397 void transform::BufferizeToAllocationOp::getEffects(
398  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
399  if (getBufferizeDestinationOnly()) {
400  // The destination is replaced with a newly allocated buffer, but the op
401  // itself remains in place.
402  onlyReadsHandle(getTargetMutable(), effects);
403  } else {
404  consumesHandle(getTargetMutable(), effects);
405  }
406  producesHandle(getOperation()->getOpResults(), effects);
407  modifiesPayload(effects);
408 }
409 
410 LogicalResult transform::BufferizeToAllocationOp::verify() {
411  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
412  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
413  return emitOpError() << "unsupported memcpy op";
414  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
415  return emitOpError() << "unsupported alloc op";
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // DecomposeOp
421 //===----------------------------------------------------------------------===//
422 
423 DiagnosedSilenceableFailure
424 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
425  LinalgOp target,
426  transform::ApplyToEachResultList &results,
427  transform::TransformState &state) {
428 #define DOWNSCALE(trans) \
429  { \
430  FailureOr<LinalgOp> res = tryApply<trans>(target); \
431  if (succeeded(res)) { \
432  results.push_back(*res); \
433  return DiagnosedSilenceableFailure::success(); \
434  } \
435  }
436 
437 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
438 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
439 
440  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
441  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
442  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
443  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
444  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
445  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
446  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
447  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
448  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
451 #undef DOWNSCALE_NORMAL
452 #undef DOWNSCALE_CALL
453 #undef DOWNSCALE
454  return emitDefaultSilenceableFailure(target);
455 }
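
// For reference (expansion spelled out by hand, not part of the upstream
// file): `DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)` above expands to
// roughly the following, so the first pattern that applies short-circuits the
// function:
//
//   {
//     FailureOr<LinalgOp> res =
//         tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                        Conv1DNwcWcfOp>>(
//             target);
//     if (succeeded(res)) {
//       results.push_back(*res);
//       return DiagnosedSilenceableFailure::success();
//     }
//   }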
456 
457 //===----------------------------------------------------------------------===//
458 // DecomposeInterfaceOp
459 //===----------------------------------------------------------------------===//
460 
461 // Decompose the target operation if it implements the AggregatedOpInterface.
462 // Push the decomposed operations (the ones that replace the values produced
463 // by \p target) into `results`.
464 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
465  transform::TransformRewriter &rewriter, Operation *target,
466  transform::ApplyToEachResultList &results,
467  transform::TransformState &state) {
468  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
469  if (!decomposableOp) {
470  failed(rewriter.notifyMatchFailure(target,
471  "payload is not a decomposable op"));
472  return emitDefaultSilenceableFailure(target);
473  }
474 
475  FailureOr<SmallVector<Value>> maybeNewResults =
476  decomposableOp.decomposeOperation(rewriter);
477  if (failed(maybeNewResults))
478  return emitDefaultSilenceableFailure(target);
479 
480  rewriter.replaceOp(decomposableOp, *maybeNewResults);
481  for (Value val : *maybeNewResults) {
482  Operation *definition = val.getDefiningOp();
483  if (definition)
484  results.push_back(definition);
485  }
486  return DiagnosedSilenceableFailure::success();
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // EliminateLinalgOpAnchoredEmptyTensorsOp
491 //===----------------------------------------------------------------------===//
492 
493 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
494  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
495  onlyReadsHandle(getTargetMutable(), effects);
496  modifiesPayload(effects);
497 }
498 
499 DiagnosedSilenceableFailure
500 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
501  transform::TransformRewriter &rewriter, TransformResults &transformResults,
502  TransformState &state) {
503  bufferization::OneShotBufferizationOptions options;
504  options.allowReturnAllocsFromLoops = true;
505 
506  for (Operation *target : state.getPayloadOps(getTarget())) {
507  bufferization::OneShotAnalysisState state(target, options);
508  if (failed(analyzeOp(target, state)))
509  return mlir::emitSilenceableFailure(target->getLoc())
510  << "failed to analyze op";
511  if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
512  rewriter, target, state)))
513  return mlir::emitSilenceableFailure(target->getLoc())
514  << "failed to eliminate LinalgOp anchored tensor.empty ops";
515  }
516  return DiagnosedSilenceableFailure::success();
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // FuseOp
521 //===----------------------------------------------------------------------===//
522 
523 /// Apply a tiling transformation to all payload ops and store both the
524 /// tiled operation as well as the created tile loops.
525 template <typename Range>
526 static LogicalResult applyTilingToAll(
527  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
528  unsigned numLoops, transform::TransformResults &transformResults,
529  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
530  applyFn) {
531  SmallVector<Operation *> tiledLinalgOps;
532  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
533 
534  for (Operation *target : payloadOps) {
535  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
536  if (!tilingInterfaceOp)
537  return transformOp->emitError("only TilingInterface ops are supported");
538 
539  rewriter.setInsertionPoint(target);
540  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
541  applyFn(tilingInterfaceOp);
542  if (failed(tiledResults))
543  return failure();
544 
545  // Perform the replacement of tiled and fused values.
546  SmallVector<Operation *> opsToReplace{target};
547  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
548  for (Operation *toReplace : opsToReplace) {
549  for (OpResult res : toReplace->getResults())
550  if (auto replacement = tiledResults->replacements.lookup(res))
551  rewriter.replaceAllUsesWith(res, replacement);
552  if (toReplace->use_empty()) {
553  rewriter.eraseOp(toReplace);
554  }
555  }
556 
557  // Report back the relevant handles to the transform op.
558  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
559  assert(tiledResults->loops.size() == numLoops &&
560  "Mismatched number of loops, tile and fuse transform should have "
561  "failed");
562  for (unsigned int i = 0; i < numLoops; ++i)
563  loopOps[i].push_back(tiledResults->loops[i]);
564  }
565 
566  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
567  for (unsigned int i = 0; i < numLoops; ++i)
568  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
569 
570  return success();
571 }
572 
573 DiagnosedSilenceableFailure
574 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
575  mlir::transform::TransformResults &transformResults,
576  mlir::transform::TransformState &state) {
577  SmallVector<int64_t> tileSizes =
578  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
579  SmallVector<int64_t> tileInterchange =
580  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
581 
582  scf::SCFTilingOptions tilingOptions;
583  tilingOptions.interchangeVector = tileInterchange;
584  SmallVector<OpFoldResult> tileSizesOfr =
585  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
586  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
587  scf::SCFTileAndFuseOptions tileAndFuseOptions;
588  tileAndFuseOptions.tilingOptions = tilingOptions;
589 
590  if (getApplyCleanup()) {
591  MLIRContext *context = rewriter.getContext();
592  RewritePatternSet patterns(context);
593  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
596  tileAndFuseOptions.cleanupPatterns = std::move(patterns);
597  }
598 
599  LogicalResult result = applyTilingToAll(
600  rewriter, getOperation(), state.getPayloadOps(getTarget()),
601  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
602  [&](TilingInterface tilingInterfaceOp)
603  -> FailureOr<scf::SCFTileAndFuseResult> {
604  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
605  tileAndFuseOptions);
606  });
607  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
608  : DiagnosedSilenceableFailure::success();
609 }
610 
611 LogicalResult transform::FuseOp::verify() {
612  SmallVector<int64_t> permutation =
613  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
614  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
615  if (!std::is_permutation(sequence.begin(), sequence.end(),
616  permutation.begin(), permutation.end())) {
617  return emitOpError() << "expects interchange to be a permutation, found "
618  << getTileInterchange();
619  }
620 
621  SmallVector<int64_t> sizes =
622  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
623  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
624  if (numExpectedLoops != getNumResults() - 1)
625  return emitOpError() << "expects " << numExpectedLoops << " loop results";
626 
627  return success();
628 }
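
// Worked example (not part of the upstream file): with tile_sizes = [4, 0, 8]
// the zero entry requests no tiling along that dimension, so
// `numExpectedLoops` is 3 - 1 = 2 and the op must declare exactly three
// results: the fused-op handle plus one handle per generated loop.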
629 
630 //===----------------------------------------------------------------------===//
631 // FuseIntoContainingOp
632 //===----------------------------------------------------------------------===//
633 
634 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
635  OperationState &result,
636  Value producerOp,
637  Value containingOp) {
638  result.addOperands({producerOp, containingOp});
639  auto resultType = transform::AnyOpType::get(builder.getContext());
640  result.addTypes({resultType, resultType});
641 }
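
// Illustrative sketch (not part of the upstream file): constructing the op
// from C++, assuming an `OpBuilder b`, a location `loc`, and transform op
// handles `producerHandle` and `forallHandle`:
//
//   auto fuse = b.create<transform::FuseIntoContainingOp>(loc, producerHandle,
//                                                         forallHandle);
//   Value fused = fuse.getFusedOp();
//   Value newContaining = fuse.getNewContainingOp();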
642 
643 /// Add new operands to the forall op for users of the producerOp
644 /// that are dominated by the containing scf.forall op.
645 static Operation *replaceForAllWithNewSignature(
646  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
647  Operation *containingOp, TilingResult &tileAndFuseResult,
648  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
649  SmallVector<OpFoldResult> &sizes) {
650 
651  // Count number of users not including the containing op
652  SetVector<Operation *> dominatedUsers;
653  DominanceInfo domInfo(containingOp);
654  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
655  if (!containingOp->isAncestor(user) &&
656  (domInfo.dominates(containingOp, user))) {
657  dominatedUsers.insert(user);
658  }
659  }
660  if (dominatedUsers.empty())
661  return nullptr;
662 
663  // Create new scf.forall op
664  auto forallOp = cast<scf::ForallOp>(containingOp);
665  OpBuilder::InsertionGuard g(rewriter);
666  rewriter.setInsertionPoint(forallOp);
667 
668  // Get new output
669  Location loc = forallOp.getLoc();
670  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
671  if (!genericOp)
672  return nullptr;
673  SmallVector<Value> outputs = genericOp.getOutputs();
674  SmallVector<Value> newOuts(forallOp.getOutputs());
675  newOuts.push_back(outputs[resultNumber]);
676 
677  // Create new scf.forall op
678  auto newforallOp = rewriter.create<scf::ForallOp>(
679  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
680  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
681  rewriter.eraseBlock(newforallOp.getBody());
682  newforallOp.getRegion().takeBody(forallOp.getRegion());
683 
684  // Add additional block argument for new value being returned
685  // and replaces all uses of the new output with corresponding bbArg
686  // inside the scf.forall to enable fusion into this new scf.forall.
687  newforallOp.getBody()->addArgument(newOuts.back().getType(),
688  newOuts.back().getLoc());
689  auto bbArgs = newforallOp.getBody()->getArguments();
690  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
691  [&](OpOperand &use) {
692  Operation *op = use.getOwner();
693  return newforallOp->isProperAncestor(op);
694  });
695 
696  // Fix terminator
697  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
698  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
699  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
700  Operation *firstYieldOp = yieldingOps.front();
701  rewriter.setInsertionPoint(firstYieldOp);
702  Value src = tileAndFuseResult.tiledValues[0];
703  Value dst = newforallOp.getRegionIterArgs().back();
704  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
705  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
706  dst, offsets, sizes, strides);
707 
708  for (auto result : llvm::enumerate(forallOp.getResults())) {
709  rewriter.replaceAllUsesWith(result.value(),
710  newforallOp->getResult(result.index()));
711  }
712  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
713  newforallOp->getResults().back(),
714  [&](OpOperand &use) {
715  Operation *user = use.getOwner();
716  return dominatedUsers.contains(user);
717  });
718  return newforallOp;
719 }
720 
721 /// Given two operands coming from a loop iter arg, 'src' and 'dst', return true
722 /// if the operand 'src' is equal to 'dst' or equal to an iter arg present in
723 /// an outer loop. To determine the second condition, this function iterates
724 /// using a worklist over the enclosing loops, trying to find 'src' in any of
725 /// the parent loop's iter args.
726 static bool sameOrEquivalentIterArg(Value src, Value dst) {
727  // Stack-like vector containing possible iterArg candidates. The first one
728  // is dst, and we will traverse the IR from there.
729  SmallVector<Value> destWorklist;
730  destWorklist.push_back(dst);
731 
732  while (!destWorklist.empty()) {
733  Value currentDst = destWorklist.pop_back_val();
734 
735  // We have found the same operand in some iter arg in the loop structure,
736  // so src and dst are equivalent.
737  if (src == currentDst)
738  return true;
739 
740  // The operands are not equivalent, look for enclosing loops over
741  // currentDst.
742  auto bbArg = dyn_cast<BlockArgument>(currentDst);
743  if (!bbArg)
744  continue;
745 
746  Block *parentBlock = bbArg.getOwner();
747  assert(parentBlock && "unlinked block argument");
748 
749  Operation *parentOp = parentBlock->getParentOp();
750  assert(parentOp && "expected block argument with parent operation");
751 
752  // Check if parent is loop-like. If it's not, do not add it to the worklist.
753  auto parentLoop = dyn_cast<LoopLikeOpInterface>(parentOp);
754  if (!parentLoop)
755  continue;
756 
757  for (auto innerIterArg : parentLoop.getRegionIterArgs()) {
758  // No need to check for null as innerIterArg is tied to parentLoop.
759  OpOperand *operand = parentLoop.getTiedLoopInit(innerIterArg);
760  Value loopBlockArgument =
761  parentLoop->getOperand(operand->getOperandNumber());
762  destWorklist.push_back(loopBlockArgument);
763  }
764  }
765 
766  return false;
767 }
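
// Illustrative example (schematic IR, not part of the upstream file): with
//
//   %r0 = scf.for ... iter_args(%arg0 = %init) -> (tensor<...>) {
//     %r1 = scf.forall ... shared_outs(%arg1 = %arg0) -> (tensor<...>) {
//       ...
//     }
//     ...
//   }
//
// `sameOrEquivalentIterArg(%init, %arg1)` returns true: starting from %arg1,
// the worklist reaches the forall's tied init %arg0 and then the outer for's
// tied init %init.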
768 
769 /// Find the first "extract" user of `producerOp` and tile it right before its
770 /// use. The tiled op is fused under the `containingOp`.
771 /// Return this fused op on success or nullptr if anything fails.
772 /// If tiled op has uses that are dominated by `containingOp`, return
773 /// a new `containingOp` with results of the fused op appended to
774 /// results of the `containingOp` or nullptr if there are no dominated uses.
775 static std::tuple<SmallVector<Operation *>, Operation *>
776 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
777  Operation *producerOp, Operation *containingOp) {
778  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
779  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
780  if (!tileableProducer) {
781  diag.attachNote(producerOp->getLoc())
782  << "producer is not a TileableInterface: " << *producerOp;
783  return {};
784  }
785 
786  // Search the producer slices accessed within the containing operation.
787  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
788  // evolve into an interface.
789  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
790  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
791  return sliceOp && containingOp->isProperAncestor(sliceOp);
792  });
793 
794  // Find a fusion opportunity.
795  if (it == tileableProducer->getUsers().end()) {
796  diag.attachNote(tileableProducer->getLoc())
797  << "could not find fusion opportunity for: " << *tileableProducer;
798  return {};
799  }
800  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
801 
802  // Try to fuse the producer in-place.
803  OpBuilder::InsertionGuard guard(rewriter);
804  rewriter.setInsertionPoint(sliceOpToTile);
805 
806  // Clone the producer inside the consumer and try to update the producer init
807  // operands using the loop bbArgs if applicable. More precisely, if the bbArg
808  // of the container loop points to a value that is used by the consumer op,
809  // then, instead of using that value in the consumer, use the value coming
810  // from the bbArg instead. This allows reusing the output tensor (instead of
811  // creating a new one) of the container when both producer and container write
812  // to the same output.
813  if (LoopLikeOpInterface containerLoop =
814  dyn_cast<LoopLikeOpInterface>(sliceOpToTile->getParentOp())) {
815  Operation *clone = rewriter.clone(*producerOp);
816  rewriter.modifyOpInPlace(clone, [&]() {
817  // Iterate over the outputs of the producer and over the loop bbArgs and
818  // check if any bbArg points to the same value as the producer output. In
819  // such case, make the producer output point to the bbArg directly.
820  for (OpOperand &initOperandPtr :
821  cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
822  Value producerOperand =
823  clone->getOperand(initOperandPtr.getOperandNumber());
824  for (BlockArgument containerIterArg :
825  containerLoop.getRegionIterArgs()) {
826  OpOperand *bbArg = containerLoop.getTiedLoopInit(containerIterArg);
827  Value consumerOperand =
828  containerLoop->getOperand(bbArg->getOperandNumber());
829  // The producer has the same init as the loop bbArg, use it.
830  if (sameOrEquivalentIterArg(producerOperand, consumerOperand)) {
831  initOperandPtr.set(containerIterArg);
832  }
833  }
834  }
835  });
836 
837  tileableProducer = dyn_cast<TilingInterface>(clone);
838  }
839 
840  // Tile the producer.
841  int64_t resultNumber =
842  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
843  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
844 
845  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
846  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
847 
848  FailureOr<TilingResult> tileAndFuseResult =
849  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
850  sizes);
851 
852  if (failed(tileAndFuseResult)) {
853  diag.attachNote(tileableProducer->getLoc())
854  << "failed to tile producer op: " << *tileableProducer;
855  return {};
856  }
857 
858 #ifndef NDEBUG
859  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
860  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
861  }
862 #endif
863 
864  // Replace the extract op.
865  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
866  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
867  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
868  if (failed(maybeRankReduced)) {
869  diag.attachNote(producerOp->getLoc())
870  << "shape types don't match (missing canonicalization?):\nTiledOp: "
871  << tileAndFuseResult->tiledValues[0]
872  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
873  return {};
874  }
875  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
876 
877  // Add new outputs to containing op, if required
878  Operation *newContainingOp = replaceForAllWithNewSignature(
879  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
880  resultNumber, offsets, sizes);
881 
882  // Cleanup clone.
883  if (dyn_cast<LoopLikeOpInterface>(containingOp))
884  rewriter.eraseOp(tileableProducer);
885 
886  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
887 }
888 
889 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
890 /// it is exactly the `containingOp`, otherwise bail.
891 /// Then, find the first "extract" user of the tied block argument and tile it
892 /// right before its "extract" use. The tiled op is fused under the
893 /// `containingOp`.
894 /// Return this fused op on success or nullptr if anything fails.
895 static SmallVector<Operation *>
896 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
897  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
898  Operation *containingOp) {
899  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
900 
901  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
902  if (!tileableProducer) {
903  diag.attachNote(producerOp->getLoc())
904  << "producer is not a TileableInterface: " << *producerOp;
905  return {};
906  }
907 
908  // Search the first use by a "scf::ForallOp" user.
909  scf::ForallOp forallOp;
910  auto itProducerUses =
911  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
912  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
913  return forallOp;
914  });
915  // If it's not from the containing op, return.
916  if (!forallOp || forallOp != containingOp) {
917  diag.attachNote(tileableProducer->getLoc())
918  << "could not find a use by the containing op: " << *tileableProducer;
919  return {};
920  }
921 
922  // Search the producer slices accessed within the containing
923  // operation.
924  // TODO: Generalize to more extract/insert/parallel_insert triples.
925  // Maybe evolve into an interface.
926  OpOperand *pUse = &(*itProducerUses);
927  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
928 
929  // Search the producer slices accessed within the containing operation.
930  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
931  // evolve into an interface.
932  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
933  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
934  return sliceOp && containingOp->isProperAncestor(sliceOp);
935  });
936 
937  // Find a fusion opportunity.
938  if (itBBArgUsers == bbArg.getUsers().end()) {
939  diag.attachNote(containingOp->getLoc())
940  << "could not find fusion opportunity for bbArg: " << bbArg;
941  return {};
942  }
943  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
944 
945  // Try to fuse the producer in-place.
946  OpBuilder::InsertionGuard guard(rewriter);
947  rewriter.setInsertionPoint(sliceOpToTile);
948 
949  // Replace the use in the tileableProducer before tiling: clone, replace and
950  // then tile.
951  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
952  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
953 
954  // Gather destination tensors.
955  SmallVector<Value> destinationTensors;
956  if (failed(tensor::getOrCreateDestinations(
957  rewriter, tileableProducer->getLoc(), tileableProducer,
958  destinationTensors))) {
959  diag.attachNote(tileableProducer->getLoc())
960  << "failed to get destination tensors for: " << *tileableProducer;
961  return {};
962  }
963 
964  IRMapping bvm;
965  bvm.map(destinationTensors[resultNumber], bbArg);
966  auto tileableProducerClone =
967  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
968  auto scopeGuard =
969  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
970 
971  // Tile the producer.
972  FailureOr<TilingResult> tileAndFuseResult =
973  tileableProducerClone.generateResultTileValue(
974  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
975  sliceOpToTile.getMixedSizes());
976  if (failed(tileAndFuseResult)) {
977  diag.attachNote(tileableProducer->getLoc())
978  << "failed to tile producer op: " << *tileableProducer;
979  return {};
980  }
981 
982  // Replace the extract op.
983  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
984  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
985  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
986  assert(succeeded(maybeRankReduced) && "unexpected shape");
987  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
988 
989  // Replace the use in containingOp.
990  rewriter.modifyOpInPlace(containingOp, [&]() {
991  containingOp->setOperand(pUse->getOperandNumber(),
992  destinationTensors.front());
993  });
994 
995  return tileAndFuseResult->tiledOps;
996 }
997 
998 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
999  Operation *producerOp,
1000  Operation *containingOp) {
1001  LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
1002 
1003  // Gather all uses inside the containing op.
1004  SmallVector<OpOperand *> uses;
1005  for (OpResult result : producerOp->getOpResults()) {
1006  for (OpOperand &use : result.getUses()) {
1007  if (containingOp->isProperAncestor(use.getOwner())) {
1008  uses.push_back(&use);
1009  continue;
1010  }
1011  // Cannot clone and fuse if the use is by the containing op itself: fail
1012  // immediately.
1013  if (containingOp == use.getOwner()) {
1014  diag.attachNote(producerOp->getLoc())
1015  << "producer op use by containing op cannot be fused by cloning";
1016  return nullptr;
1017  }
1018  }
1019  }
1020 
1021  // Check for a non-empty list of fusion opportunities.
1022  if (uses.empty()) {
1023  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
1024  return nullptr;
1025  }
1026 
1027  // Clone and fuse inside the containing op.
1028  Operation *fusedOp = nullptr;
1029  OpOperand *use = uses.front();
1030  // Parallel insert slice is not a valid clone destination.
1031  // TODO: Generalize to other type of ops.
1032  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
1033  "Parallel insert slice is not a valid clone destination");
1034  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
1035  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
1036 
1037  OpBuilder::InsertionGuard guard(rewriter);
1038  rewriter.setInsertionPoint(use->getOwner());
1039  fusedOp = rewriter.clone(*producerOp);
1040  rewriter.modifyOpInPlace(
1041  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
1042 
1043  return fusedOp;
1044 }
1045 
1046 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
1047  // Allow repeated handles since we are fusing everything anyway.
1048  return true;
1049 }
1050 
1051 DiagnosedSilenceableFailure
1052 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
1053  transform::TransformResults &results,
1054  transform::TransformState &state) {
1055  SmallVector<Operation *> fusedOps;
1056  auto producerOps = state.getPayloadOps(getProducerOp());
1057  auto containingOps = state.getPayloadOps(getContainingOp());
1058  if (!llvm::hasSingleElement(containingOps)) {
1059  return emitDefiniteFailure()
1060  << "requires exactly one containing_op handle (got "
1061  << llvm::range_size(containingOps) << ")";
1062  }
1063  Operation *containingOp = *containingOps.begin();
1064 
1065  // If nothing to fuse, propagate success.
1066  if (std::empty(producerOps)) {
1067  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
1068  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1069  return DiagnosedSilenceableFailure::success();
1070  }
1071 
1072  // Helper function to find the next producer that should be fused. Take any
1073  // producer that has a use inside the containing op.
1074  SetVector<Operation *> remainingProducers(llvm::from_range, producerOps);
1075  auto getNextProducer = [&]() -> FailureOr<Operation *> {
1076  for (const auto &it : enumerate(remainingProducers)) {
1077  Operation *producerOp = it.value();
1078  // The containing op may be a user of producerOp: use isAncestor.
1079  int64_t numUsesInContainingOp =
1080  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
1081  return containingOp->isAncestor(op);
1082  });
1083  // TODO: When resolving the TODO below (no duplicate ops), take an op
1084  // that has no use among the remaining producers. This is a topological
1085  // sorting.
1086  if (numUsesInContainingOp > 0) {
1087  if (numUsesInContainingOp == 1)
1088  remainingProducers.erase(remainingProducers.begin() + it.index());
1089  return producerOp;
1090  }
1091  }
1092  return failure();
1093  };
1094 
1095  while (!remainingProducers.empty()) {
1096  auto nextProducer = getNextProducer();
1097  if (failed(nextProducer)) {
1098  auto diag = mlir::emitSilenceableFailure(getLoc())
1099  << "could not find next producer to fuse into container";
1100  diag.attachNote(containingOp->getLoc()) << "containing op";
1101  return diag;
1102  }
1103 
1104  Operation *producerOp = *nextProducer;
1105 
1106  // Default diagnostic, to be complemented with more failure information.
1107  Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
1108  diag << "could not fuse " << *producerOp << " into " << *containingOp;
1109 
1110  // TODO: If there are multiple uses of the producer in the containing op,
1111  // we currently tile/clone the op multiple times (once per use). In some
1112  // cases, we can tile/clone once and reuse the value for each use.
1113  // Futhermore, producers should then be traversed according to a
1114  // topological sorting.
1115  auto [tiledOps, newContainingOp] =
1116  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1117  if (!tiledOps.empty()) {
1118  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
1119  fusedOps.append(tiledOps);
1120  if (newContainingOp) {
1121  // Update handles associated with the containing op so we don't need to
1122  // invalidate them. This is a hack to support better composability
1123  // between tiling and fusion while a proper mechanism is being
1124  // investigated.
1125  //
1126  // DO NOT replicate this elsewhere unless you understand what you are
1127  // doing.
1128  LogicalResult replacementStatus =
1129  rewriter.notifyPayloadOperationReplaced(containingOp,
1130  newContainingOp);
1131  (void)replacementStatus;
1132  assert(succeeded(replacementStatus) &&
1133  "unable to update transform state mapping");
1134  rewriter.eraseOp(containingOp);
1135  containingOp = newContainingOp;
1136  }
1137  continue;
1138  }
1139 
1140  SmallVector<Operation *> tiledContainingOpOperand =
1141  tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1142  rewriter, diag, producerOp, containingOp);
1143  if (!tiledContainingOpOperand.empty()) {
1144  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
1145  << *containingOp);
1146  fusedOps.append(tiledContainingOpOperand);
1147  continue;
1148  }
1149 
1150  Operation *cloned =
1151  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1152  if (cloned) {
1153  LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
1154  fusedOps.push_back(cloned);
1155  continue;
1156  }
1157  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1158  }
1159 
1160  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1161  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1162  return DiagnosedSilenceableFailure::success();
1163 }
1164 
1165 void transform::FuseIntoContainingOp::getEffects(
1166  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1167  consumesHandle(getProducerOpMutable(), effects);
1168  onlyReadsHandle(getContainingOpMutable(), effects);
1169  producesHandle(getOperation()->getOpResults(), effects);
1170  modifiesPayload(effects);
1171 }
1172 
1173 //===----------------------------------------------------------------------===//
1174 // GeneralizeOp
1175 //===----------------------------------------------------------------------===//
1176 
1177 DiagnosedSilenceableFailure
1178 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1179  LinalgOp target,
1180  transform::ApplyToEachResultList &results,
1181  transform::TransformState &state) {
1182  // Exit early if no transformation is needed.
1183  if (isa<GenericOp>(target)) {
1184  results.push_back(target);
1185  return DiagnosedSilenceableFailure::success();
1186  }
1187  rewriter.setInsertionPoint(target);
1188  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1189  if (succeeded(generic)) {
1190  results.push_back(generic->getOperation());
1191  return DiagnosedSilenceableFailure::success();
1192  }
1193  return emitDefaultSilenceableFailure(target);
1194 }
1195 
1196 //===----------------------------------------------------------------------===//
1197 // SpecializeOp
1198 //===----------------------------------------------------------------------===//
1199 
1200 DiagnosedSilenceableFailure
1201 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1202  LinalgOp target,
1203  transform::ApplyToEachResultList &results,
1204  transform::TransformState &state) {
1205  // Exit early if the operation is not a generic.
1206  if (!isa<GenericOp>(target)) {
1207  results.push_back(target);
1208  return DiagnosedSilenceableFailure::success();
1209  }
1210  rewriter.setInsertionPoint(target);
1211  FailureOr<LinalgOp> named =
1212  specializeGenericOp(rewriter, cast<GenericOp>(target));
1213  if (succeeded(named)) {
1214  results.push_back(named->getOperation());
1215  return DiagnosedSilenceableFailure::success();
1216  }
1217  return emitDefaultSilenceableFailure(target);
1218 }
1219 
1220 //===----------------------------------------------------------------------===//
1221 // InterchangeOp
1222 //===----------------------------------------------------------------------===//
1223 
1224 DiagnosedSilenceableFailure
1225 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1226  GenericOp target,
1227  transform::ApplyToEachResultList &results,
1228  transform::TransformState &state) {
1229  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1230  // Exit early if no transformation is needed.
1231  if (interchangeVector.empty()) {
1232  results.push_back(target);
1233  return DiagnosedSilenceableFailure::success();
1234  }
1235 
1236  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1237  if (interchangeVector.size() != numLoops) {
1238  return emitSilenceableError()
1239  << getIteratorInterchangeAttrName() << " has length ("
1240  << interchangeVector.size()
1241  << ") different from the number of loops in the target operation ("
1242  << numLoops << ")";
1243  }
1244  FailureOr<GenericOp> res = interchangeGenericOp(
1245  rewriter, target, SmallVector<unsigned>(interchangeVector));
1246  if (failed(res))
1247  return emitDefiniteFailure() << "failed to apply";
1248  results.push_back(res->getOperation());
1249  return DiagnosedSilenceableFailure::success();
1250 }
1251 
1252 LogicalResult transform::InterchangeOp::verify() {
1253  ArrayRef<int64_t> permutation = getIteratorInterchange();
1254  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1255  if (!std::is_permutation(sequence.begin(), sequence.end(),
1256  permutation.begin(), permutation.end())) {
1257  return emitOpError()
1258  << "expects iterator_interchange to be a permutation, found "
1259  << getIteratorInterchange();
1260  }
1261  return success();
1262 }
1263 
1264 //===----------------------------------------------------------------------===//
1265 // LinalgCopyToMemrefOp
1266 //===----------------------------------------------------------------------===//
1267 
1268 DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1269  transform::TransformRewriter &rewriter, Operation *targetOp,
1270  transform::ApplyToEachResultList &results,
1271  transform::TransformState &state) {
1272 
1273  // Check if the target can be converted.
1274  if (!isa<linalg::CopyOp>(targetOp)) {
1275  DiagnosedSilenceableFailure diag =
1276  emitSilenceableError() << "only linalg.copy target ops are supported";
1277  diag.attachNote(targetOp->getLoc()) << "target op";
1278  return diag;
1279  }
1280 
1281  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1282  if (!copyOp.hasPureBufferSemantics()) {
1283  DiagnosedSilenceableFailure diag =
1284  emitSilenceableError()
1285  << "cannot transform a linalg.copy on tensors into a memref.copy";
1286  diag.attachNote(targetOp->getLoc()) << "target op";
1287  return diag;
1288  }
1289 
1290  SmallVector<Value> inputs = copyOp.getInputs();
1291  SmallVector<Value> outputs = copyOp.getOutputs();
1292  assert(inputs.size() == 1 && "expected linalg copy op with one input");
1293  assert(outputs.size() == 1 && "expected memref copy op with one output");
1294  Value input = inputs.front();
1295  Value output = outputs.front();
1296 
1297  // linalg.copy supports different element types on source/dest whereas
1298  // memref.copy does not, so we must check that the source and dest types can
1299  // be handled by memref.copy and otherwise reject the transformation.
1300  if (!isa<ShapedType>(input.getType())) {
1301  DiagnosedSilenceableFailure diag =
1302  emitSilenceableError()
1303  << "cannot transform a linalg.copy which input has no shape";
1304  diag.attachNote(targetOp->getLoc()) << "target op";
1305  return diag;
1306  }
1307 
1308  // linalg.copy destination must be a shaped type.
1309  assert(isa<ShapedType>(output.getType()));
1310 
1311  if (cast<ShapedType>(input.getType()).getElementType() !=
1312  cast<ShapedType>(output.getType()).getElementType()) {
1313  DiagnosedSilenceableFailure diag =
1314  emitSilenceableError()
1315  << "cannot transform a linalg.copy with different source and "
1316  "destination element types ";
1317  diag.attachNote(targetOp->getLoc()) << "target op";
1318  return diag;
1319  }
1320 
1321  // Target can be converted, do it.
1322  auto memrefCopyOp =
1323  rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1324 
1325  results.push_back(memrefCopyOp);
1326  return DiagnosedSilenceableFailure::success();
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // LowerPackOp
1331 //===----------------------------------------------------------------------===//
1332 
1333 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1334  transform::TransformRewriter &rewriter, linalg::PackOp target,
1335  transform::ApplyToEachResultList &transformResults,
1336  transform::TransformState &state) {
1337  rewriter.setInsertionPoint(target);
1338  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1339  FailureOr<LowerPackResult> res =
1340  lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1341  if (failed(res)) {
1342  return mlir::emitSilenceableFailure(target->getLoc())
1343  << "cannot lower to pad + expand + transpose";
1344  }
1345  transformResults.push_back(res->padOp);
1346  transformResults.push_back(res->expandShapeOp);
1347  transformResults.push_back(res->transposeOp);
1348  return DiagnosedSilenceableFailure::success();
1349 }
1350 
1351 //===----------------------------------------------------------------------===//
1352 // LowerUnPackOp
1353 //===----------------------------------------------------------------------===//
1354 
1355 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1356  transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1357  transform::ApplyToEachResultList &transformResults,
1358  transform::TransformState &state) {
1359  rewriter.setInsertionPoint(target);
1360  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1361  FailureOr<LowerUnPackOpResult> res =
1362  lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1363  if (failed(res)) {
1364  DiagnosedSilenceableFailure diag =
1365  emitSilenceableError()
1366  << "cannot lower to transpose + collapse + extract";
1367  diag.attachNote(target->getLoc()) << "target payload op";
1368  return diag;
1369  }
1370  transformResults.push_back(res->emptyOp);
1371  transformResults.push_back(res->transposeOp);
1372  transformResults.push_back(res->collapseShapeOp);
1373  transformResults.push_back(res->extractSliceOp);
1374  return DiagnosedSilenceableFailure::success();
1375 }
1376 
1377 //===---------------------------------------------------------------------===//
1378 // MatchOp
1379 //===---------------------------------------------------------------------===//
1380 
1381 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1382  Value target, ArrayRef<StringRef> opNames) {
1383  result.addOperands(target);
1384  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1385  builder.getStrArrayAttr(opNames));
1386  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1387 }
1388 
1389 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1390  TypeRange resultTypes, Value target,
1391  ArrayRef<StringRef> opNames) {
1392  result.addOperands(target);
1393  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1394  builder.getStrArrayAttr(opNames));
1395  result.addTypes(resultTypes);
1396 }
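
// Illustrative sketch (not part of the upstream file): matching all
// `linalg.matmul` payload ops nested under a handle `root`, assuming an
// `OpBuilder b` and a location `loc`:
//
//   SmallVector<StringRef> names = {linalg::MatmulOp::getOperationName()};
//   auto match = b.create<transform::MatchOp>(loc, root, names);
//   Value matched = match.getResult();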
1397 
1398 DiagnosedSilenceableFailure
1399 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1400  transform::TransformResults &results,
1401  transform::TransformState &state) {
1402  llvm::StringSet<> strs;
1403  if (getOps().has_value())
1404  strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1405 
1406  auto payloadOps = state.getPayloadOps(getTarget());
1407  if (!llvm::hasSingleElement(payloadOps)) {
1408  return emitDefiniteFailure("requires exactly one target handle");
1409  }
1410 
1411  SmallVector<Operation *> res;
1412  bool incorrectNumOperandTypes = false;
1413  auto matchFun = [&](Operation *op) {
1414  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1415  return;
1416 
1417  // Interfaces cannot be matched by name, just by ID.
1418  // So we specifically encode the interfaces we care about for this op.
1419  if (getInterface().has_value()) {
1420  auto iface = getInterface().value();
1421  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1422  !isa<LinalgOp>(op))
1423  return;
1424  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1425  !isa<TilingInterface>(op))
1426  return;
1427  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1428  !isa<LoopLikeOpInterface>(op))
1429  return;
1430  }
1431 
1432  // Check if all specified attributes match.
1433  if (getOpAttrs().has_value()) {
1434  DictionaryAttr opAttrs = getOpAttrs().value();
1435  for (NamedAttribute attr : opAttrs) {
1436  if (attr.getName() == getInterfaceAttrName() ||
1437  attr.getName() == getOpsAttrName())
1438  continue;
1439  if (!op->hasAttr(attr.getName()))
1440  return;
1441  if (op->getAttr(attr.getName()) != attr.getValue())
1442  return;
1443  }
1444  }
1445 
1446  if (getFilterResultType().has_value()) {
1447  Type t = getFilterResultType().value();
1448  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1449  return;
1450  }
1451 
1452  if (getFilterOperandTypes().has_value()) {
1453  mlir::ArrayAttr types = getFilterOperandTypes().value();
1454  auto operandTypes = op->getOperandTypes();
1455 
1456  if (types.size() == 1) {
1457  // All the operands must be equal to the specified type.
1458  auto typeattr =
1459  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1460  Type t = cast<::mlir::Type>(typeattr.getValue());
1461  if (!llvm::all_of(op->getOperandTypes(),
1462  [&](Type operandType) { return operandType == t; }))
1463  return;
1464  } else {
1465  // The operand types must match all the types in the list (in the same
1466  // order in which they are specified).
1467  if (types.size() != operandTypes.size()) {
1468  incorrectNumOperandTypes = true;
1469  return;
1470  }
1471 
1472  for (auto [attr, operandType] :
1473  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1474  auto typeattr = cast<mlir::TypeAttr>(attr);
1475  Type type = cast<::mlir::Type>(typeattr.getValue());
1476 
1477  if (type != operandType)
1478  return;
1479  }
1480  }
1481  }
1482 
1483  // All constraints are satisfied.
1484  res.push_back(op);
1485  return;
1486  };
1487 
1488  (*payloadOps.begin())->walk(matchFun);
1489  if (incorrectNumOperandTypes)
1490  return emitDefiniteFailure("If filter_operand_types contains more than a "
1491  "type, then it must contain as much types as "
1492  "the number of operands in the target ops");
1493  results.set(cast<OpResult>(getResult()), res);
1494  return DiagnosedSilenceableFailure::success();
1495 }
1496 
1497 //===---------------------------------------------------------------------===//
1498 // MultiTileSizesOp
1499 //===---------------------------------------------------------------------===//
1500 
1501 static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1502  Type targetType, Type lowSizeType, Type,
1503  Type) {
1504  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1505 }
1506 
1507 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1508  Type &targetType, Type &lowSizeType,
1509  Type &highSizeType,
1510  Type &splitPointType) {
1511  FunctionType funcType;
1512  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1513  if (failed(parser.parseType<FunctionType>(funcType)))
1514  return failure();
1515 
1516  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1517  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1518  "argument and one result";
1519  }
1520  targetType = funcType.getInput(0);
1521  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1522 
1523  return success();
1524 }
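
// For reference (not part of the upstream file): these custom print/parse
// hooks render the four types as a single trailing functional type, roughly
//
//   transform.structured.multitile_sizes %target { ... }
//       : (!transform.any_op) -> !transform.param<i64>
//
// where the single result type is reused for low_size, high_size, and
// split_point.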
1525 
1526 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1527  transform::TransformRewriter &rewriter, LinalgOp target,
1528  ApplyToEachResultList &results, TransformState &state) {
1529  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1530  if (target.hasDynamicShape()) {
1531  auto diag = emitSilenceableError()
1532  << "cannot compute parametric tile sizes for dynamically "
1533  "shaped payload op";
1534  diag.attachNote(target->getLoc()) << "payload op";
1535  return diag;
1536  }
1537 
1538  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1539  target, getDimension(), getTargetSize(), getDivisor());
1540  if (failed(spec)) {
1541  return emitSilenceableError()
1542  << "failed to compute multi-size tiling sizes";
1543  }
1544 
1545  Builder builder(target.getContext());
1546  results.assign(llvm::map_range(
1547  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1548  spec->lowTileSize * spec->lowTripCount}),
1549  [&builder, this](int64_t value) {
1550  return builder.getIntegerAttr(
1551  cast<ParamType>(getLowSize().getType()).getType(), value);
1552  }));
1553  return DiagnosedSilenceableFailure::success();
1554  }
1555 
1556  OpBuilder builder(target.getContext());
1557  builder.setInsertionPoint(target);
1558  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1559  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1560  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1561  builder, target, getDimension(), targetSize, divisor);
1562  if (failed(spec)) {
1563  return emitSilenceableError() << "could not generate tile size computation";
1564  }
1565 
1566  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1567  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1568  Operation *splitPoint =
1569  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1570  {spec->lowTileSize, spec->lowTripCount});
1571  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1572  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1573  assert(lowTileSize && highTileSize && splitPoint &&
1574  "tile sizes are not produced by operations");
1575  results.reserve(results.size() + 3);
1576  results.push_back(lowTileSize);
1577  results.push_back(highTileSize);
1578  results.push_back(splitPoint);
1579  return DiagnosedSilenceableFailure::success();
1580 }
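// Editor's note (illustrative): a small worked example of the values produced
// above, assuming the computed specification is lowTileSize = 2,
// highTileSize = 3 and lowTripCount = 3 for a loop of trip count 12 (so the
// high tile covers the remaining 12 - 2 * 3 = 6 iterations in two tiles of
// size 3). The parameter results are then (2, 3, 2 * 3) = (2, 3, 6), i.e. the
// split point is lowTileSize * lowTripCount. The non-parametric branch below
// materializes the same split point as an affine apply of s0 * s1 on the SSA
// tile size and trip count values.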
1581 
1582 void transform::MultiTileSizesOp::getEffects(
1583  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1584  onlyReadsHandle(getTargetMutable(), effects);
1585  producesHandle(getOperation()->getOpResults(), effects);
1586  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1587  onlyReadsPayload(effects);
1588  else
1589  modifiesPayload(effects);
1590 }
1591 
1592 LogicalResult transform::MultiTileSizesOp::verify() {
1593  if (getLowSize().getType() != getHighSize().getType() ||
1594  getLowSize().getType() != getSplitPoint().getType()) {
1595  return emitOpError() << "expects all result types to be the same";
1596  }
1597  return success();
1598 }
1599 
1600 //===---------------------------------------------------------------------===//
1601 // PackOp
1602 //===---------------------------------------------------------------------===//
1603 
1604 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1605  Value target,
1606  ArrayRef<OpFoldResult> mixedPackedSizes) {
1607  SmallVector<int64_t> staticPackedSizes;
1608  SmallVector<Value> dynamicPackedSizes;
1609  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1610  staticPackedSizes);
1611  // Call the default builder which sets up the proper operands segment sizes
1612  // attributes for multiple variadic operands. In the absence of this, horrible
1613  // bugs ensue.
1614  Type linalgOpHType = transform::OperationType::get(
1615  builder.getContext(), GenericOp::getOperationName());
1616  build(builder, result,
1617  /*resultType=*/linalgOpHType,
1618  /*target=*/target,
1619  /*dynamic_sizes=*/dynamicPackedSizes,
1620  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1621 }
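// Editor's note: a minimal sketch (not part of the upstream file) of the
// mixed static/dynamic round-trip used by the builder above and by
// getMixedPackedSizes() below. dispatchIndexOpFoldResults splits OpFoldResults
// into the dynamic Values plus a static i64 vector (dynamic entries become
// ShapedType::kDynamic); getMixedValues re-assembles the OpFoldResults.
// The function name is invented for illustration only.
[[maybe_unused]] static SmallVector<OpFoldResult>
roundTripMixedSizesExample(Builder &b, ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  // Attribute entries land in staticSizes; Value entries are recorded as
  // ShapedType::kDynamic in staticSizes and collected in dynamicSizes.
  dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
  // Reconstruct the original mixed form from the two vectors.
  return getMixedValues(staticSizes, dynamicSizes, b);
}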
1622 
1623 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1624  Builder b(getContext());
1625  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1626 }
1627 
1628 DiagnosedSilenceableFailure
1629 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1630  transform::TransformResults &transformResults,
1631  transform::TransformState &state) {
1632  auto targetOps = state.getPayloadOps(getTarget());
1633  // If nothing to pack, propagate success.
1634  if (std::empty(targetOps)) {
1635  transformResults.set(cast<OpResult>(getPackedOp()),
1636  ArrayRef<Operation *>({}));
1637  return DiagnosedSilenceableFailure::success();
1638  }
1639  // Fail on multi-op handles.
1640  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1641  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1642  return emitSilenceableError()
1643  << "requires target to map to exactly 1 LinalgOp (got "
1644  << llvm::range_size(targetOps) << ")";
1645  }
1646  // Fail on mismatched number of pack sizes.
1647  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1648  return emitSilenceableError()
1649  << "requires number of packed sizes match the number of loops ("
1650  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1651  << ")";
1652  }
1653 
1654  // Unpack handles to constants or actual SSA index values.
1655  SmallVector<OpFoldResult> packedSizes;
1657  state, *this, packedSizes, getMixedPackedSizes());
1658 
1659  rewriter.setInsertionPoint(linalgOp);
1660  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1661  if (failed(maybeResult))
1662  return emitDefiniteFailure("data tiling failed");
1663 
1664  transformResults.set(cast<OpResult>(getPackedOp()),
1665  {maybeResult->packedLinalgOp.getOperation()});
1666  return DiagnosedSilenceableFailure::success();
1667 }
1668 
1669 void transform::PackOp::getEffects(
1670  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1671  transform::consumesHandle(getTargetMutable(), effects);
1672  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1673  transform::producesHandle(getOperation()->getOpResults(), effects);
1674  transform::modifiesPayload(effects);
1675 }
1676 
1677 //===---------------------------------------------------------------------===//
1678 // PackGreedilyOp.
1679 //===---------------------------------------------------------------------===//
1680 
1681 LogicalResult transform::PackGreedilyOp::verify() {
1682  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1683  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1684  << " is not a valid permutation";
1685  }
1686  // TODO: relax to allow empty once we have a strategy other than just matmul.
1687  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1688  for (auto [s, nmo] :
1689  llvm::zip_equal(getMixedMatmulPackedSizes(),
1690  getMatmulPaddedSizesNextMultipleOf())) {
1691  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1692  if (nmo != 0 &&
1693  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1694  return emitOpError() << "at most one of the packed_size and the "
1695  "padded_sizes_next_multiple_of can be nonzero "
1696  "for the matmul strategy";
1697  }
1698  }
1699  }
1700  return success();
1701 }
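// Editor's note (illustrative): the check above enforces, per (m, n, k)
// entry, that at most one of matmul_packed_sizes and
// matmul_padded_sizes_next_multiple_of is nonzero. For instance,
// packed_sizes = [16, 16, 0] with padded_sizes_next_multiple_of = [0, 0, 32]
// is accepted, while packed_sizes = [16, 16, 32] with the same multiples is
// rejected because the k entry is nonzero in both lists.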
1702 
1703 DiagnosedSilenceableFailure
1704 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1705  transform::TransformResults &transformResults,
1706  transform::TransformState &state) {
1707  SmallVector<Operation *> results;
1708  for (Operation *op : state.getPayloadOps(getTarget())) {
1709  auto linalgOp = dyn_cast<LinalgOp>(op);
1710  if (!linalgOp)
1711  continue;
1712  // linalgOp will be replaced and the insertion point may be invalidated if
1713  // we set it before the op, so set it after instead.
1714  rewriter.setInsertionPointAfter(linalgOp);
1715  // Failing to pack greedily is perfectly fine.
1716  // In the future we will want to order packings according to some metric.
1717  FailureOr<PackResult> packResult = packMatmulGreedily(
1718  /*rewriter=*/rewriter,
1719  /*linalgOp=*/linalgOp,
1720  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1721  /*mnkPaddedSizesNextMultipleOf=*/
1722  getMatmulPaddedSizesNextMultipleOf(),
1723  /*mnkOrder=*/getMatmulInnerDimsOrder());
1724  if (succeeded(packResult)) {
1725  results.push_back(packResult->packedLinalgOp);
1726  continue;
1727  }
1728  results.push_back(linalgOp);
1729  }
1730  transformResults.set(cast<OpResult>(getPackedOp()), results);
1731  return DiagnosedSilenceableFailure::success();
1732 }
1733 
1734 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1735  Builder b(getContext());
1736  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1737  b);
1738 }
1739 
1740 void transform::PackGreedilyOp::getEffects(
1741  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1742  transform::consumesHandle(getTargetMutable(), effects);
1743  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1744  transform::producesHandle(getOperation()->getOpResults(), effects);
1745  transform::modifiesPayload(effects);
1746 }
1747 
1748 //===---------------------------------------------------------------------===//
1749 // PackTransposeOp
1750 //===---------------------------------------------------------------------===//
1751 
1752 LogicalResult transform::PackTransposeOp::verify() {
1753  if (!isPermutationVector(getInnerPerm())) {
1754  return emitOpError() << getInnerPermAttrName()
1755  << " is not a valid permutation";
1756  }
1757  if (!isPermutationVector(getOuterPerm())) {
1758  return emitOpError() << getOuterPermAttrName()
1759  << " is not a valid permutation";
1760  }
1761  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1762  return emitOpError() << " at least one of " << getInnerPermAttrName()
1763  << " or " << getOuterPermAttrName()
1764  << " must be specified";
1765  }
1766  return success();
1767 }
1768 
1769 namespace {
1770 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1771 } // namespace
1772 
1773 /// Return true if `permutation` is a valid permutation of the
1774 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1775 /// (OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack` `op`.
1776 /// This is the case when the `permutation` rank matches the rank expected by
1777 /// `op` and `permutation` is itself a permutation vector.
1778 /// Return true if either `op` or `permutation` is empty to allow a simpler
1779 /// polymorphic implementation.
1780 template <typename RelayoutOpTy>
1781 static bool isValidPackingPermutation(
1782  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1783  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1784  static_assert(
1785  llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1786  "applies to only pack or unpack operations");
1787  if (!op || permutation.empty())
1788  return true;
1789  size_t innerRank = op.getInnerDimsPos().size();
1790  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1791  return permutation.size() == innerRank && isPermutationVector(permutation);
1792  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1793  // Don't rely on it.
1794  if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1795  return permutation.size() == op.getSourceRank() &&
1796  isPermutationVector(permutation);
1797  }
1798  return permutation.size() == op.getDestRank() &&
1799  isPermutationVector(permutation);
1800 }
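// Editor's note (illustrative): for a linalg.pack whose source has rank 3 and
// that uses two inner tiles (inner_dims_pos of size 2), the helper above
// accepts an outer permutation only if it has 3 elements (the source rank)
// and an inner permutation only if it has 2 elements (the inner rank); both
// must also be permutation vectors, e.g. outer_perm = [2, 0, 1] and
// inner_perm = [1, 0]. For linalg.unpack the outer permutation is checked
// against the destination rank instead.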
1801 
1802 DiagnosedSilenceableFailure
1803 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1804  transform::TransformResults &transformResults,
1805  transform::TransformState &state) {
1806  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1807  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1808  // Step 1. If nothing to pack, propagate success.
1809  if (std::empty(packOrUnpackOps)) {
1810  transformResults.set(cast<OpResult>(getPackedOp()), {});
1811  transformResults.set(cast<OpResult>(getPackOp()), {});
1812  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1813  return DiagnosedSilenceableFailure::success();
1814  }
1815 
1816  // Step 2. A bunch of runtime sanity checks and error messages.
1817  // Step 2.1. Fail on multi-op handles.
1818  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1819  !llvm::hasSingleElement(linalgOps)) {
1820  return emitSilenceableError()
1821  << "requires target to map to exactly 1 "
1822  "packing op and 1 packed op ("
1823  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1824  << llvm::range_size(linalgOps) << ")";
1825  }
1826 
1827  // Step 2.2. Fail on wrong type.
1828  auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1829  auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1830  if ((!packOp && !unPackOp)) {
1831  return emitSilenceableError() << "requires target to map to a "
1832  "linalg.pack or linalg.unpack";
1833  }
1834  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1835  if (!linalgOpTarget)
1836  return emitSilenceableError() << "requires a LinalgOp target";
1837 
1838  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1839  LinalgOp linalgOp;
1840  if (packOp && packOp.getResult().hasOneUse())
1841  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1842  else if (unPackOp)
1843  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1844  if (linalgOp != linalgOpTarget) {
1845  auto errorMsg =
1846  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1847  : StringLiteral{"not produced by the LinalgOp target"};
1848  return emitSilenceableError() << errorMsg;
1849  }
1850 
1851  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1852  // PackOp.
1853  if (unPackOp) {
1854  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1855  OpOperand *packUse = linalgOp.getDpsInitOperand(
1856  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1857  packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
1858  if (!packOp || !packOp.getResult().hasOneUse())
1859  return emitSilenceableError() << "could not find matching pack op";
1860  }
1861 
1862  // Step 2.5. Fail if any permutation does not validate.
1863  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1864  ArrayRef<int64_t> perm =
1865  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1866  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1867  ? StringLiteral{"invalid outer_perm"}
1868  : StringLiteral{"invalid inner_perm"};
1869  if (!isValidPackingPermutation(packOp, perm, permType) ||
1870  !isValidPackingPermutation(unPackOp, perm, permType)) {
1871  Operation *packOrUnpackOp =
1872  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1873  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1874  }
1875  }
1876 
1877  // From here on, packOp and linalgOp are always present, unPackOp may or may
1878  // not be present.
1879  assert(packOp && linalgOp && "unexpected null op");
1880 
1881  // Step 3. Actually transpose the ops.
1882  FailureOr<PackTransposeResult> res = packTranspose(
1883  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1884  // Preconditions have been checked, it is an error to fail here.
1885  assert(succeeded(res) && "unexpected packTranspose failure");
1886 
1887  // Step 4. Return results.
1888  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1889  transformResults.set(cast<OpResult>(getPackedOp()),
1890  {res->transposedLinalgOp});
1891  if (unPackOp) {
1892  transformResults.set(cast<OpResult>(getUnPackOp()),
1893  {res->transposedUnPackOp});
1894  } else {
1895  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1896  }
1897 
1898  return DiagnosedSilenceableFailure::success();
1899 }
1900 
1901 //===---------------------------------------------------------------------===//
1902 // PadOp
1903 //===---------------------------------------------------------------------===//
1904 
1905 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1906  ArrayRef<int64_t> paddingDimensions,
1907  ArrayRef<int64_t> padToMultipleOf,
1908  ArrayRef<int64_t> nofoldFlags,
1909  ArrayRef<Attribute> transposePaddings,
1910  StringRef copyBackOp) {
1911  auto resultType = transform::AnyOpType::get(b.getContext());
1912  return build(/*builder=*/b,
1913  /*result=*/result,
1914  /*types=*/TypeRange{resultType, resultType},
1915  /*target=*/target,
1916  /*paddingValues=*/ArrayAttr(), // let inference handle this
1917  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1918  /*padToMultipleOf=*/ValueRange{},
1919  /*padToMultipleOf=*/
1920  (padToMultipleOf.empty()
1921  ? DenseI64ArrayAttr()
1922  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1923  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1924  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1925  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1926 }
1927 
1928 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1929  ArrayRef<int64_t> paddingDimensions,
1930  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1931  ArrayRef<int64_t> nofoldFlags,
1932  ArrayRef<Attribute> transposePaddings,
1933  StringRef copyBackOp) {
1934  auto resultType = transform::AnyOpType::get(b.getContext());
1935  SmallVector<int64_t> staticPadToMultipleOf;
1936  SmallVector<Value> dynamicPadToMultipleOf;
1937  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1938  staticPadToMultipleOf);
1939  return build(/*builder=*/b,
1940  /*result=*/result,
1941  /*types=*/TypeRange{resultType, resultType},
1942  /*target=*/target,
1943  /*paddingValues=*/ArrayAttr(), // let inference handle this
1944  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1945  /*padToMultipleOf=*/dynamicPadToMultipleOf,
1946  /*padToMultipleOf=*/staticPadToMultipleOf,
1947  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1948  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1949  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1950 }
1951 
1952 void PadOp::getEffects(
1953  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1954  consumesHandle(getTargetMutable(), effects);
1955  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1956  producesHandle(getOperation()->getOpResults(), effects);
1957  modifiesPayload(effects);
1958 }
1959 
1960 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1961  Builder b(getContext());
1962  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1963 }
1964 
1965 DiagnosedSilenceableFailure
1966 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1967  transform::TransformResults &results,
1968  transform::TransformState &state) {
1969  auto transformOp = cast<TransformOpInterface>(getOperation());
1970  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1971 
1972  for (Operation *target : state.getPayloadOps(getTarget())) {
1973  auto linalgTarget = dyn_cast<LinalgOp>(target);
1974  if (!linalgTarget) {
1975  auto diag = emitSilenceableError() << "expected LinalgOp target";
1976  diag.attachNote(target->getLoc()) << "target op";
1977  return diag;
1978  }
1979 
1980  // Convert the integer packing flags to booleans.
1981  SmallVector<bool> nofoldFlags;
1982  for (int64_t packPadding :
1983  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1984  nofoldFlags.push_back(static_cast<bool>(packPadding));
1985 
1986  // Convert the padding values to attributes.
1987  SmallVector<Attribute> paddingValues;
1988  for (auto const &it :
1989  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1990  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1991  if (!attr) {
1992  emitOpError("expects padding values to be typed attributes");
1993  return DiagnosedSilenceableFailure::definiteFailure();
1994  }
1995  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1996  // Try to parse string attributes to obtain an attribute of element type.
1997  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1998  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1999  stringAttr, getContext(), elementType,
2000  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
2001  if (!parsedAttr || parsedAttr.getType() != elementType) {
2002  auto diag = this->emitOpError("expects a padding that parses to ")
2003  << elementType << ", got " << std::get<0>(it);
2004  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2005  return DiagnosedSilenceableFailure::definiteFailure();
2006  }
2007  paddingValues.push_back(parsedAttr);
2008  continue;
2009  }
2010  // Otherwise, add the attribute directly.
2011  if (attr.getType() != elementType) {
2012  auto diag = this->emitOpError("expects a padding value of type ")
2013  << elementType << ", got " << attr;
2014  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
2015  return DiagnosedSilenceableFailure::definiteFailure();
2016  }
2017  paddingValues.push_back(attr);
2018  }
2019 
2020  // Extract the transpose vectors.
2021  SmallVector<SmallVector<int64_t>> transposePaddings;
2022  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
2023  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
2024  cast<ArrayAttr>(transposeVector)));
2025 
2026  LinalgOp paddedOp;
2027  LinalgPaddingOptions options;
2028  options.paddingDimensions =
2029  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2030 
2031  SmallVector<int64_t> padToMultipleOf;
2033  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
2034  if (!status.succeeded())
2035  return status;
2036  if (padToMultipleOf.empty())
2037  padToMultipleOf =
2038  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
2039 
2040  options.padToMultipleOf = padToMultipleOf;
2041  options.paddingValues = paddingValues;
2042  options.nofoldFlags = nofoldFlags;
2043  if (getCopyBackOp() ==
2044  bufferization::MaterializeInDestinationOp::getOperationName()) {
2045  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
2046  BufferizationMaterializeInDestination;
2047  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
2048  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
2049  } else if (getCopyBackOp() == kCopyOpNone) {
2050  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
2051  } else {
2052  llvm_unreachable("unsupported copy_back op");
2053  }
2054 
2055  SmallVector<Value> replacements;
2056  SmallVector<tensor::PadOp> newPadOps;
2057  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
2058  replacements, newPadOps))) {
2059  auto diag = emitSilenceableError() << "failed to pad op";
2060  diag.attachNote(target->getLoc()) << "target op";
2061  return diag;
2062  }
2063 
2064  // We need to perform our own replacement here because this API is still
2065  // used in patterns that "pad and hoist", for which the replacement values
2066  // need to be different.
2067  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2068  // that we have more composable abstractions.
2069  rewriter.replaceOp(linalgTarget, replacements);
2070  paddedOps.push_back(paddedOp);
2071  padOps.append(newPadOps.begin(), newPadOps.end());
2072  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
2073  for (Value v : replacements) {
2074  Operation *copyBackOp = v.getDefiningOp();
2075  if (!llvm::is_contained(copyBackOps, copyBackOp))
2076  copyBackOps.push_back(copyBackOp);
2077  }
2078  }
2079  }
2080 
2081  results.set(cast<OpResult>(getPadded()), paddedOps);
2082  results.set(cast<OpResult>(getPad()), padOps);
2083  results.set(cast<OpResult>(getCopy()), copyBackOps);
2084  return DiagnosedSilenceableFailure::success();
2085 }
2086 
2087 LogicalResult transform::PadOp::verify() {
2088  SmallVector<int64_t> nofoldFlags =
2089  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2090  if (any_of(nofoldFlags, [](int64_t packPadding) {
2091  return packPadding != 0 && packPadding != 1;
2092  })) {
2093  return emitOpError()
2094  << "expects nofold_flags to contain booleans (0/1), found "
2095  << getNofoldFlags();
2096  }
2097 
2098  SmallVector<int64_t> paddingDimensions =
2099  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2100  if (any_of(paddingDimensions,
2101  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2102  return emitOpError() << "expects padding_dimensions to contain non-negative "
2103  "integers, found "
2104  << getPaddingDimensions();
2105  }
2106  if (!getMixedPadToMultipleOf().empty()) {
2107  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2108  return emitOpError() << "expects as many multiples as padding_dimensions";
2109  }
2110  }
2111  ArrayAttr transposes = getTransposePaddings();
2112  for (Attribute attr : transposes) {
2113  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2114  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2115  if (!std::is_permutation(sequence.begin(), sequence.end(),
2116  transpose.begin(), transpose.end())) {
2117  return emitOpError()
2118  << "expects transpose_paddings to be a permutation, found "
2119  << attr;
2120  }
2121  }
2122  if (getCopyBackOp() !=
2123  bufferization::MaterializeInDestinationOp::getOperationName() &&
2124  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2125  getCopyBackOp() != kCopyOpNone)
2126  return emitOpError() << "invalid copy_back_op";
2127  return success();
2128 }
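// Editor's note (illustrative): examples of attribute combinations accepted
// and rejected by the verifier above: nofold_flags = [0, 1, 1] is valid while
// [0, 2] is not; a transpose_paddings entry such as [1, 0, 2] is accepted but
// [0, 0, 1] is rejected because it is not a permutation; and when
// pad_to_multiple_of is present it must list exactly one multiple per entry
// of padding_dimensions.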
2129 
2130 //===---------------------------------------------------------------------===//
2131 // HoistPadOp
2132 //===---------------------------------------------------------------------===//
2133 
2134 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2135  transform::TransformRewriter &rewriter,
2136  transform::TransformResults &transformResults,
2137  transform::TransformState &state) {
2138  auto targetOps = state.getPayloadOps(getTarget());
2139  auto loopOps = state.getPayloadOps(getLoop());
2140  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2141  return emitDefiniteFailure()
2142  << "requires exactly one target and one loop handle (got "
2143  << llvm::range_size(targetOps) << " and "
2144  << llvm::range_size(loopOps) << ")";
2145  }
2146 
2147  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2148  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2149  if (!padOp || !loopOp)
2150  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2151 
2152  FailureOr<linalg::detail::PackingResult> result =
2153  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2154  getTranspose());
2155  if (failed(result))
2156  return emitDefiniteFailure() << "could not build packing loop nest";
2157 
2158  if (result->clonedLoopIvs.empty()) {
2159  transformResults.set(cast<OpResult>(getPackingLoop()),
2160  {result->hoistedPadOp.getOperation()});
2161  return DiagnosedSilenceableFailure::success();
2162  }
2163  auto outerPackedLoop =
2164  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2165  transformResults.set(cast<OpResult>(getPackingLoop()),
2166  {outerPackedLoop.getOperation()});
2167  return DiagnosedSilenceableFailure::success();
2168 }
2169 
2170 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2171  ArrayRef<int64_t> transpose = getTranspose();
2172  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2173  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2174  transpose.end())) {
2175  return emitOpError() << "expects transpose to be a permutation, found "
2176  << getTranspose();
2177  }
2178  return success();
2179 }
2180 
2181 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2182  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2183  transform::onlyReadsHandle(getTargetMutable(), effects);
2184  transform::onlyReadsHandle(getLoopMutable(), effects);
2185  transform::producesHandle(getOperation()->getOpResults(), effects);
2186  transform::modifiesPayload(effects);
2187 }
2188 
2189 DiagnosedSilenceableFailure
2190 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2191  tensor::PadOp target,
2192  transform::ApplyToEachResultList &results,
2193  transform::TransformState &state) {
2194  tensor::PadOp hoistedPadOp;
2195  SmallVector<TransposeOp> transposeOps;
2196  FailureOr<Value> result =
2197  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2198  hoistedPadOp, transposeOps);
2199  if (succeeded(result)) {
2200  // We need to perform our own replacement here because this API is still
2201  // used in patterns that "pad and hoist", for which the replacement values
2202  // need to be different.
2203  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2204  // that we have more composable abstractions.
2205  rewriter.replaceOp(target, *result);
2206  results.push_back(hoistedPadOp);
2207  return DiagnosedSilenceableFailure::success();
2208  }
2209  return emitDefaultSilenceableFailure(target);
2210 }
2211 
2212 LogicalResult transform::HoistPadOp::verify() {
2213  ArrayRef<int64_t> transpose = getTranspose();
2214  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2215  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2216  transpose.end())) {
2217  return emitOpError() << "expects transpose to be a permutation, found "
2218  << getTranspose();
2219  }
2220  return success();
2221 }
2222 
2223 //===----------------------------------------------------------------------===//
2224 // PromoteOp
2225 //===----------------------------------------------------------------------===//
2226 
2227 DiagnosedSilenceableFailure
2228 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2229  LinalgOp target,
2230  transform::ApplyToEachResultList &results,
2231  transform::TransformState &state) {
2232  LinalgPromotionOptions promotionOptions;
2233  if (!getOperandsToPromote().empty())
2234  promotionOptions = promotionOptions.setOperandsToPromote(
2235  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2236  if (getUseFullTilesByDefault())
2237  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2238  getUseFullTilesByDefault());
2239  if (getUseAlloca())
2240  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2241  if (!getUseFullTileBuffers().empty())
2242  promotionOptions = promotionOptions.setUseFullTileBuffers(
2243  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2244  if (getAlignment().has_value())
2245  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2246  if (getMemorySpace().has_value())
2247  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2248 
2249  if (getMapping().has_value()) {
2250  // The mapping should contain at most one element.
2251  auto mapping = *getMapping();
2252  if (mapping.size() > 1)
2253  return emitDefaultDefiniteFailure(target);
2254 
2255  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2256 
2257  if (addressSpace.getAddressSpace() ==
2258  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2259  promotionOptions =
2260  promotionOptions
2264  .setUseFullTileBuffers({false, false});
2265  } else if (addressSpace.getAddressSpace() ==
2266  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2267  promotionOptions =
2268  promotionOptions
2272  .setUseFullTileBuffers({false, false});
2273  } else {
2274  return emitDefaultDefiniteFailure(target);
2275  }
2276  }
2277 
2278  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2279  return emitDefaultDefiniteFailure(target);
2280 
2281  rewriter.setInsertionPoint(target);
2282  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2283  if (failed(res))
2284  return emitDefaultDefiniteFailure(target);
2285  results.push_back(target);
2286  return DiagnosedSilenceableFailure::success();
2287 }
2288 
2289 //===----------------------------------------------------------------------===//
2290 // ReplaceOp
2291 //===----------------------------------------------------------------------===//
2292 
2293 DiagnosedSilenceableFailure
2294 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2295  TransformResults &transformResults,
2296  TransformState &state) {
2297  auto payload = state.getPayloadOps(getTarget());
2298 
2299  // Check for invalid targets.
2300  for (Operation *target : payload) {
2301  if (target->getNumOperands() > 0)
2302  return emitDefiniteFailure() << "expected target without operands";
2303  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2304  target->getNumRegions() > 0)
2305  return emitDefiniteFailure()
2306  << "expected target that is isolated from above";
2307  }
2308 
2309  // Clone and replace.
2310  Operation *pattern = &getBodyRegion().front().front();
2311  SmallVector<Operation *> replacements;
2312  for (Operation *target : payload) {
2313  if (getOperation()->isAncestor(target))
2314  continue;
2315  rewriter.setInsertionPoint(target);
2316  Operation *replacement = rewriter.clone(*pattern);
2317  rewriter.replaceOp(target, replacement->getResults());
2318  replacements.push_back(replacement);
2319  }
2320  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2321  return DiagnosedSilenceableFailure::success();
2322 }
2323 
2324 void transform::ReplaceOp::getEffects(
2325  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2326  consumesHandle(getTargetMutable(), effects);
2327  producesHandle(getOperation()->getOpResults(), effects);
2328  modifiesPayload(effects);
2329 }
2330 
2331 LogicalResult transform::ReplaceOp::verify() {
2332  if (!getBodyRegion().hasOneBlock())
2333  return emitOpError() << "expected one block";
2334  if (std::distance(getBodyRegion().front().begin(),
2335  getBodyRegion().front().end()) != 1)
2336  return emitOpError() << "expected one operation in block";
2337  Operation *replacement = &getBodyRegion().front().front();
2338  if (replacement->getNumOperands() > 0)
2339  return replacement->emitOpError()
2340  << "expected replacement without operands";
2341  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2342  replacement->getNumRegions() > 0)
2343  return replacement->emitOpError()
2344  << "expect op that is isolated from above";
2345  return success();
2346 }
2347 
2348 //===----------------------------------------------------------------------===//
2349 // ScalarizeOp
2350 //===----------------------------------------------------------------------===//
2351 
2352 DiagnosedSilenceableFailure
2353 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2354  LinalgOp target,
2355  transform::ApplyToEachResultList &results,
2356  transform::TransformState &state) {
2357  scf::SCFTilingOptions tilingOptions;
2358  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2359  SmallVector<OpFoldResult> tileSizes;
2360  Location loc = target.getLoc();
2361  SmallVector<OpFoldResult> allShapeSizes =
2362  target.createFlatListOfOperandDims(b, loc);
2363  AffineMap map = target.getShapesToLoopsMap();
2364  if (!map)
2365  return tileSizes;
2366  SmallVector<OpFoldResult> shapeSizes =
2367  affine::makeComposedFoldedMultiResultAffineApply(b, loc, map,
2368  allShapeSizes);
2369  // If the shape size is dynamic, tile by 1.
2370  // Otherwise, do not tile (i.e. tile size 0).
2371  for (OpFoldResult shapeSize : shapeSizes) {
2372  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2373  : b.getIndexAttr(1));
2374  }
2375  return tileSizes;
2376  });
2377  rewriter.setInsertionPoint(target);
2378  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2379  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2380  if (failed(maybeTilingResult))
2381  return emitDefaultDefiniteFailure(target);
2382 
2383  if (target->getNumResults())
2384  rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
2385  else
2386  rewriter.eraseOp(target);
2387 
2388  results.reserve(maybeTilingResult->tiledOps.size());
2389  for (Operation *tiled : maybeTilingResult->tiledOps)
2390  results.push_back(tiled);
2391  return DiagnosedSilenceableFailure::success();
2392 }
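// Editor's note (illustrative): for a LinalgOp whose loop ranges are
// (4, ?, 8), the tile-size callback above produces (0, 1, 0): statically
// known dimensions are left untiled (size 0) and every dynamic dimension is
// tiled by 1, which is what reduces the op to its "scalarized" form along the
// dynamic dimensions.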
2393 
2394 //===----------------------------------------------------------------------===//
2395 // ConvertToLoopsOp
2396 //===----------------------------------------------------------------------===//
2397 
2398 DiagnosedSilenceableFailure
2399 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2400  transform::TransformResults &results,
2401  transform::TransformState &state) {
2402  SmallVector<Operation *> loops;
2403  for (Operation *target : state.getPayloadOps(getTarget())) {
2404  auto tilingOp = dyn_cast<TilingInterface>(*target);
2405  if (!tilingOp) {
2406  DiagnosedSilenceableFailure diag =
2407  emitSilenceableError()
2408  << "expected the payload to implement TilingInterface";
2409  diag.attachNote(target->getLoc()) << "payload op";
2410  return diag;
2411  }
2412  rewriter.setInsertionPoint(target);
2413  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2414  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2415  if (failed(generatedLoops))
2416  return emitDefaultDefiniteFailure(target);
2417  for (scf::ForOp &loop : *generatedLoops) {
2418  loops.push_back(loop.getOperation());
2419  }
2420  rewriter.eraseOp(target);
2421  }
2422  results.set(cast<OpResult>(getResult()), loops);
2423  return DiagnosedSilenceableFailure::success();
2424 }
2425 
2426 //===----------------------------------------------------------------------===//
2427 // RewriteInDestinationPassingStyleOp
2428 //===----------------------------------------------------------------------===//
2429 
2430 DiagnosedSilenceableFailure
2431 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2432  transform::TransformRewriter &rewriter, Operation *target,
2433  transform::ApplyToEachResultList &results,
2434  transform::TransformState &state) {
2435  rewriter.setInsertionPoint(target);
2436  FailureOr<Operation *> maybeResult =
2437  llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2438  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2439  [&rewriter](auto op) {
2440  return rewriteInDestinationPassingStyle(rewriter, op);
2441  });
2442  if (failed(maybeResult))
2443  return emitDefaultSilenceableFailure(target);
2444  results.push_back(*maybeResult);
2445  return DiagnosedSilenceableFailure::success();
2446 }
2447 
2448 //===----------------------------------------------------------------------===//
2449 // SplitOp
2450 //===----------------------------------------------------------------------===//
2451 
2452 DiagnosedSilenceableFailure
2453 SplitOp::apply(transform::TransformRewriter &rewriter,
2454  TransformResults &results, TransformState &state) {
2455  // Collect the dynamic chunk sizes if provided.
2456  SmallVector<Operation *> payload =
2457  llvm::to_vector(state.getPayloadOps(getTarget()));
2458 
2459  bool isMultiwaySplit = getMultiway();
2460 
2461  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2462  return mlir::emitSilenceableFailure(getLoc())
2463  << "requires exactly one target when "
2464  "multiway split is enabled (got "
2465  << llvm::range_size(payload) << ")";
2466  }
2467 
2468  SmallVector<OpFoldResult> chunkSizes;
2469 
2470  if (!isMultiwaySplit)
2471  chunkSizes.reserve(payload.size());
2472 
2473  if (getDynamicChunkSizes()) {
2474  auto diag = DiagnosedSilenceableFailure::success();
2475  if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2476  chunkSizes = llvm::to_vector(llvm::map_range(
2477  state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2478  if (op->getNumResults() != 1 ||
2479  !op->getResult(0).getType().isIndex()) {
2480  diag = emitSilenceableError()
2481  << "expected dynamic split point handle to point to a "
2482  "single-result index-typed op";
2483  diag.attachNote(op->getLoc()) << "dynamic split point";
2484  }
2485  return OpFoldResult(op->getResult(0));
2486  }));
2487  } else {
2488  chunkSizes = llvm::to_vector(
2489  llvm::map_range(state.getParams(getDynamicChunkSizes()),
2490  [](Attribute attr) { return OpFoldResult(attr); }));
2491  }
2492  if (diag.isSilenceableFailure())
2493  return diag;
2494 
2495  // For multiway split, a single payload is expected to have multiple
2496  // split points.
2497  if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2498  return emitDefiniteFailure()
2499  << "expected the dynamic split point handle to point to as "
2500  "many operations ("
2501  << chunkSizes.size() << ") as the target handle ("
2502  << payload.size() << ")";
2503  }
2504  } else {
2505  chunkSizes.resize(payload.size(),
2506  rewriter.getIndexAttr(getStaticChunkSizes()));
2507  }
2508 
2509  auto checkStructuredOpAndDimensions =
2510  [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2511  if (!linalgOp) {
2512  auto diag = emitSilenceableError() << "only applies to structured ops";
2513  diag.attachNote(loc) << "target op";
2514  return diag;
2515  }
2516 
2517  if (getDimension() >= linalgOp.getNumLoops()) {
2518  auto diag = emitSilenceableError() << "dimension " << getDimension()
2519  << " does not exist in target op";
2520  diag.attachNote(loc) << "target op";
2521  return diag;
2522  }
2523  return DiagnosedSilenceableFailure::success();
2524  };
2525 
2526  auto checkFailureInSplitting =
2527  [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2528  if (hasFailed) {
2529  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2530  diag.attachNote(loc) << "target op";
2531  return diag;
2532  }
2533  return DiagnosedSilenceableFailure::success();
2534  };
2535 
2536  SmallVector<Operation *> opList;
2537  if (isMultiwaySplit) {
2538 
2539  // Split a single target operation at multiple points.
2540  TilingInterface head, tail;
2541  Operation *target = payload.front();
2542 
2543  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2544 
2545  // Check that the target is a valid LinalgOp with correct dimensions.
2546  DiagnosedSilenceableFailure diag =
2547  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2548  if (diag.isSilenceableFailure())
2549  return diag;
2550 
2551  for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2552 
2553  if (idx > 0)
2554  target = tail.getOperation();
2555 
2556  if (!target)
2557  break;
2558 
2559  linalgOp = cast<LinalgOp>(target);
2560  Location loc = target->getLoc();
2561 
2562  rewriter.setInsertionPoint(linalgOp);
2563  std::tie(head, tail) = linalg::splitOp(
2564  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2565  getDimension(), chunkSize);
2566 
2567  // Propagate errors.
2568  DiagnosedSilenceableFailure diag =
2569  checkFailureInSplitting(!head && !tail, loc);
2570  if (diag.isDefiniteFailure())
2571  return diag;
2572 
2573  opList.push_back(head.getOperation());
2574  }
2575 
2576  // Append any leftover parts to the end of the result list.
2577  if (tail)
2578  opList.push_back(tail.getOperation());
2579 
2580  } else {
2581  // Split each target operation.
2582  SmallVector<Operation *> first, second;
2583  Operation *noSecondPart = nullptr;
2584  for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2585  Operation *target = std::get<0>(pair);
2586  Location loc = target->getLoc();
2587  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2588  DiagnosedSilenceableFailure diag =
2589  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2590 
2591  if (diag.isSilenceableFailure())
2592  return diag;
2593 
2594  rewriter.setInsertionPoint(linalgOp);
2595  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2596  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2597  getDimension(), std::get<1>(pair));
2598 
2599  // Propagate errors.
2600  DiagnosedSilenceableFailure diagSplit =
2601  checkFailureInSplitting(!first.back() && !second.back(), loc);
2602  if (diagSplit.isDefiniteFailure())
2603  return diagSplit;
2604 
2605  // Do not add null second parts.
2606  if (!second.back()) {
2607  noSecondPart = target;
2608  second.pop_back();
2609  }
2610  }
2611 
2612  if (second.size() != first.size() && !second.empty()) {
2613  auto diag = emitSilenceableError()
2614  << "splitting does not produce the second part for a subset "
2615  "of targets";
2616  diag.attachNote()
2617  << "expected splitting to produce the second part of all "
2618  "or none of the targets";
2619  diag.attachNote(noSecondPart->getLoc())
2620  << "first target with no second part";
2621  return diag;
2622  }
2623 
2624  opList.append(first);
2625  if (second.size())
2626  opList.append(second);
2627  }
2628  results.set(cast<OpResult>(getSplitList()), opList);
2629  return DiagnosedSilenceableFailure::success();
2630 }
2631 
2632 void SplitOp::getEffects(
2633  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2634  consumesHandle(getTargetMutable(), effects);
2635  if (getDynamicChunkSizes())
2636  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2637  producesHandle(getOperation()->getOpResults(), effects);
2638  modifiesPayload(effects);
2639 }
2640 
2641 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2642  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2643  IntegerAttr staticChunkSizes;
2644  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2645  return failure();
2646 
2647  OptionalParseResult dynamicPointParseResult =
2648  parser.parseOptionalOperand(dynamicChunkSizes);
2649  if (!dynamicPointParseResult.has_value()) {
2650  int64_t staticChunkSizesValue;
2651  if (failed(parser.parseInteger(staticChunkSizesValue)))
2652  return failure();
2653 
2654  staticChunkSizes =
2655  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2656  }
2657 
2658  Type targetType;
2659  if (parser.parseOptionalAttrDict(result.attributes) ||
2660  parser.parseColonType(targetType) ||
2661  parser.resolveOperand(target, targetType, result.operands)) {
2662  return failure();
2663  }
2664  if (dynamicPointParseResult.has_value()) {
2665  Type ChunkSizesType;
2666  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2667  parser.parseType(ChunkSizesType) ||
2668  parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
2669  result.operands)) {
2670  return failure();
2671  }
2672 
2673  staticChunkSizes =
2674  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2675  }
2676 
2677  result.addAttribute(
2678  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2679  staticChunkSizes);
2680  result.addTypes(targetType);
2681  return success();
2682 }
2683 
2684 void SplitOp::print(OpAsmPrinter &printer) {
2685  printer << " " << getTarget() << " after ";
2686  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2687  if (staticChunkSize != ShapedType::kDynamic)
2688  printer << staticChunkSize;
2689  else
2690  printer << getDynamicChunkSizes();
2691  printer << " ";
2692  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2693  {getStaticChunkSizesAttrName()});
2694  printer << " : " << getTarget().getType();
2695  if (staticChunkSize == ShapedType::kDynamic)
2696  printer << ", " << getDynamicChunkSizes().getType();
2697 }
2698 
2699 LogicalResult SplitOp::verify() {
2700  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2701  (getDynamicChunkSizes() == nullptr)) {
2702  return emitOpError() << "expects either a dynamic or a static split "
2703  "point to be provided";
2704  }
2705  return success();
2706 }
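// Editor's note (illustrative, assembly sketched from the parser/printer
// above): the op stores either a static chunk size or a dynamic one, never
// both. A static form would print roughly as
//   transform.structured.split %target after 16 {dimension = 0}
//       : !transform.any_op
// while a dynamic form prints the extra operand and its type, with the
// static_chunk_sizes attribute set to ShapedType::kDynamic:
//   transform.structured.split %target after %sz {dimension = 0}
//       : !transform.any_op, !transform.param<i64>
// The verifier above enforces exactly this exclusive-or.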
2707 
2708 //===----------------------------------------------------------------------===//
2709 // SplitReductionOp
2710 //===----------------------------------------------------------------------===//
2711 
2712 void transform::SplitReductionOp::build(
2713  OpBuilder &builder, OperationState &result, Value target,
2714  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2715  bool useScalingAlgorithm, bool useAlloc) {
2716  MLIRContext *ctx = builder.getContext();
2717  result.addOperands(target);
2718  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2719  builder.getI64IntegerAttr(splitFactor));
2720  result.addAttribute(
2721  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2722  builder.getI64IntegerAttr(insertSplitDimension));
2723  if (innerParallel) {
2724  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2725  builder.getUnitAttr());
2726  }
2727  if (useScalingAlgorithm) {
2728  result.addAttribute(
2729  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2730  builder.getUnitAttr());
2731  }
2732  if (useAlloc) {
2733  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2734  builder.getUnitAttr());
2735  }
2736  auto resultType = transform::AnyOpType::get(ctx);
2737  result.addTypes({resultType, resultType, resultType, resultType});
2738 }
2739 
2740 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2741  transform::TransformRewriter &rewriter, LinalgOp target,
2742  transform::ApplyToEachResultList &results,
2743  transform::TransformState &state) {
2744  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2745  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2746  unsigned(getInsertSplitDimension()),
2747  bool(getInnerParallel())};
2748  };
2749  rewriter.setInsertionPoint(target);
2750  FailureOr<SplitReductionResult> splitResult =
2751  (getUseScalingAlgorithm())
2752  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2753  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2754  if (failed(splitResult))
2755  return emitDefaultDefiniteFailure(target);
2756 
2757  results.push_back(splitResult->initOrAlloc);
2758  results.push_back(splitResult->fillOp);
2759  results.push_back(splitResult->splitLinalgOp);
2760  results.push_back(splitResult->resultCombiningLinalgOp);
2761  return DiagnosedSilenceableFailure::success();
2762 }
2763 
2764 //===----------------------------------------------------------------------===//
2765 // TileReductionUsingForOp
2766 //===----------------------------------------------------------------------===//
2767 
2768 void transform::TileReductionUsingForOp::build(
2769  OpBuilder &builder, OperationState &result, Value target,
2770  ArrayRef<int64_t> staticTileSizes) {
2771  // Call the default builder.
2772  // This is future-proof regarding mixed static-dynamic sizes and sets up the proper
2773  // operands segment sizes attributes for multiple variadic operands.
2774  // In the absence of this, horrible bugs ensue.
2775  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2776  MLIRContext *ctx = builder.getContext();
2777  auto opTy = transform::AnyOpType::get(ctx);
2778  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2779  build(builder, result,
2780  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2781  /*target=*/target,
2782  /*tile_sizes=*/staticTileSizesAttr);
2783 }
2784 
2785 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2786  transform::TransformRewriter &rewriter, Operation *target,
2787  transform::ApplyToEachResultList &results,
2788  transform::TransformState &state) {
2789  rewriter.setInsertionPoint(target);
2790 
2791  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2792  if (!partialReductionOp) {
2793  return emitSilenceableFailure(
2794  target->getLoc(),
2795  "Operation should implement PartialReductionOpInterface");
2796  }
2797  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2798  rewriter, partialReductionOp,
2799  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2800 
2801  if (failed(result))
2802  return emitDefaultSilenceableFailure(target);
2803  rewriter.replaceOp(target, result->mergeResult.replacements);
2804  for (Value initValue : result->initialValues)
2805  results.push_back(initValue.getDefiningOp());
2806  for (auto parallelTiledOp : result->tiledOps)
2807  results.push_back(parallelTiledOp);
2808  for (auto mergeOp : result->mergeResult.mergeOps)
2809  results.push_back(mergeOp);
2810  results.push_back(result->loops.front());
2811  return DiagnosedSilenceableFailure::success();
2812 }
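// Editor's note (illustrative): the results of the op above are filled in the
// order (1) the defining ops of the initial/neutral values, (2) the parallel
// tiled ops, (3) the merge/reduction ops, and (4) the outermost generated
// loop; for a typical single-result reduction this yields exactly one handle
// per result declared by the builder above.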
2813 
2814 //===----------------------------------------------------------------------===//
2815 // TileReductionUsingForallOp
2816 //===----------------------------------------------------------------------===//
2817 
2818 void transform::TileReductionUsingForallOp::build(
2819  OpBuilder &builder, OperationState &result, Value target,
2820  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2821  ArrayAttr mapping) {
2822  // Call the default builder.
2823  // This is future-proof regarding mixed static-dynamic sizes and sets up the proper
2824  // operands segment sizes attributes for multiple variadic operands.
2825  // In the absence of this, horrible bugs ensue.
2826  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2827  MLIRContext *ctx = builder.getContext();
2828  auto opTy = transform::AnyOpType::get(ctx);
2829  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2830  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2831  build(builder, result,
2832  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2833  /*target=*/target,
2834  /*num_threads=*/staticNumThreadsAttr,
2835  /*tile_sizes=*/staticTileSizesAttr,
2836  /*mapping=*/mapping);
2837 }
2838 
2839 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2840  transform::TransformRewriter &rewriter, LinalgOp target,
2841  transform::ApplyToEachResultList &results,
2842  transform::TransformState &state) {
2843  rewriter.setInsertionPoint(target);
2844  SmallVector<OpFoldResult> numThreads =
2845  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2846  SmallVector<OpFoldResult> tileSizes =
2847  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2848  FailureOr<linalg::ForallReductionTilingResult> result =
2849  linalg::tileReductionUsingForall(
2850  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2851  numThreads, tileSizes, getMapping());
2852 
2853  if (failed(result)) {
2854  auto diag = emitSilenceableError() << "could not tile reduction";
2855  diag.attachNote(target.getLoc()) << "target operation";
2856  return diag;
2857  }
2858  for (Value initValue : result->initialValues)
2859  results.push_back(initValue.getDefiningOp());
2860  for (auto parallelTiledOp : result->parallelTiledOps)
2861  results.push_back(parallelTiledOp);
2862  for (auto mergeOp : result->mergeOps)
2863  results.push_back(mergeOp);
2864  results.push_back(result->loops);
2865  return DiagnosedSilenceableFailure::success();
2866 }
2867 
2868 //===----------------------------------------------------------------------===//
2869 // ContinuousTileSizesOp
2870 //===----------------------------------------------------------------------===//
2871 
2872 DiagnosedSilenceableFailure
2873 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2874  TransformResults &transformResults,
2875  TransformState &state) {
2876 
2877  SmallVector<Operation *> targetOps =
2878  llvm::to_vector(state.getPayloadOps(getTarget()));
2879 
2880  if (!llvm::hasSingleElement(targetOps)) {
2881  return mlir::emitSilenceableFailure(getLoc())
2882  << "requires exactly one target (got " << llvm::range_size(targetOps)
2883  << ")";
2884  }
2885 
2886  Operation *target = *targetOps.begin();
2887  auto linalgOp = dyn_cast<LinalgOp>(target);
2888  auto tileableOp = dyn_cast<TilingInterface>(target);
2889 
2890  if (!linalgOp)
2891  return emitDefiniteFailure() << "expected Linalg Op";
2892 
2893  OpBuilder builder(linalgOp.getContext());
2894 
2895  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
2896  if (linalgOp.hasDynamicShape()) {
2897  auto diag = emitSilenceableError()
2898  << "cannot compute parametric tile sizes for dynamically "
2899  "shaped payload op";
2900  diag.attachNote(linalgOp->getLoc()) << "payload op";
2901  return diag;
2902  }
2903 
2904  FailureOr<StaticContinuousTileSizeSpecification> spec =
2905  computeStaticContinuousTileSizes(linalgOp, getDimension(),
2906  getTargetSize());
2907  if (failed(spec)) {
2908  return emitSilenceableError()
2909  << "failed to compute multi-size tiling sizes";
2910  }
2911 
2912  SmallVector<int64_t> chunkSizes;
2913 
2914  for (auto &&[tileSize, tripCount] :
2915  llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2916  chunkSizes.push_back(tileSize * tripCount);
2917 
2918  auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2919  return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2920  return builder.getI64IntegerAttr(value);
2921  });
2922  };
2923  transformResults.setParams(cast<OpResult>(getTileSizes()),
2924  getI64AttrsFromI64(spec->tileSizes));
2925  transformResults.setParams(cast<OpResult>(getChunkSizes()),
2926  getI64AttrsFromI64(chunkSizes));
2927 
2928  return DiagnosedSilenceableFailure::success();
2929  }
2930 
2931  builder.setInsertionPoint(linalgOp);
2932 
2933  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2934  unsigned dimension = getDimension();
2935 
2936  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2937  builder, tileableOp, dimension, targetSize, true);
2938  if (failed(spec)) {
2939  return emitSilenceableError() << "could not generate tile size computation";
2940  }
2941 
2942  AffineExpr s0 = builder.getAffineSymbolExpr(0);
2943  AffineExpr s1 = builder.getAffineSymbolExpr(1);
2944  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2945  return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2946  ofrs);
2947  };
2948 
2949  SmallVector<Value> chunkSizes;
2950  Value splitPoint;
2951  for (auto &&[tileSize, tripCount] :
2952  llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2953  splitPoint = apply(s0 * s1, {tileSize, tripCount});
2954  chunkSizes.push_back(splitPoint);
2955  }
2956 
2957  auto getDefiningOps = [&](ArrayRef<Value> values) {
2958  return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2959  return value.getDefiningOp();
2960  });
2961  };
2962 
2963  transformResults.set(cast<OpResult>(getTileSizes()),
2964  getDefiningOps(spec->tileSizes));
2965  transformResults.set(cast<OpResult>(getChunkSizes()),
2966  getDefiningOps(chunkSizes));
2967 
2968  return DiagnosedSilenceableFailure::success();
2969 }
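// Editor's note (illustrative): in both branches above the i-th chunk size is
// tileSizes[i] * tripCounts[i]; e.g. a specification with tile sizes
// [32, 16, 8] and trip counts [3, 1, 1] yields chunk sizes [96, 16, 8],
// returned either as i64 parameters (static case) or as the defining ops of
// affine apply results (dynamic case).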
2970 
2971 LogicalResult transform::ContinuousTileSizesOp::verify() {
2972 
2973  if (getTileSizes().getType() != getChunkSizes().getType()) {
2974  return emitOpError() << "expects all result types to be the same";
2975  }
2976 
2977  return success();
2978 }
2979 
2980 void transform::ContinuousTileSizesOp::getEffects(
2981  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2982  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2983  onlyReadsPayload(effects);
2984  else
2985  modifiesPayload(effects);
2986  onlyReadsHandle(getTargetMutable(), effects);
2987  producesHandle(getOperation()->getOpResults(), effects);
2988 }
2989 
2990 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2991  Type targetType, Type tile_sizes,
2992  Type) {
2993  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
2994 }
2995 
2996 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2997  Type &targetType,
2998  Type &tileSizesType,
2999  Type &chunkSizesType) {
3000  FunctionType funcType;
3001  llvm::SMLoc typeLoc = parser.getCurrentLocation();
3002  if (failed(parser.parseType<FunctionType>(funcType)))
3003  return failure();
3004 
3005  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
3006  parser.emitError(typeLoc) << "expects a trailing functional type with one "
3007  "argument and one result";
3008  }
3009  targetType = funcType.getInput(0);
3010  tileSizesType = chunkSizesType = funcType.getResult(0);
3011 
3012  return success();
3013 }
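// Editorial note (illustrative sketch, not the op's full assembly format): the
// parser above expects a trailing functional type with exactly one input and
// one result, e.g. something of the form
//   (!transform.any_op) -> !transform.param<i64>
// and reuses the single result type for both the tile_sizes and chunk_sizes
// results.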
3014 
3015 //===----------------------------------------------------------------------===//
3016 // TileUsingForOp
3017 //===----------------------------------------------------------------------===//
3018 
3019 void transform::TileUsingForOp::build(
3020  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3021  Value target, ArrayRef<int64_t> staticTileSizes,
3022  ArrayRef<int64_t> interchange,
3023  std::optional<ArrayRef<bool>> scalableSizes) {
3024  return build(builder, result, loopTypes,
3025  /*target=*/target,
3026  /*mixedTileSizes=*/
3027  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3028  interchange, scalableSizes);
3029 }
3030 
3031 void transform::TileUsingForOp::build(
3032  OpBuilder &builder, OperationState &result, Value target,
3033  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
3034  std::optional<ArrayRef<bool>> scalableSizes) {
3035  build(builder, result, target,
3036  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3037  interchange, scalableSizes);
3038 }
3039 
3040 void transform::TileUsingForOp::build(
3041  OpBuilder &builder, OperationState &result, Value target,
3042  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
3043  std::optional<ArrayRef<bool>> scalableSizes) {
3044  // Loop types are automatically splat by the callee, so setting up one is
3045  // enough.
3046  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
3047  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
3048  scalableSizes);
3049 }
3050 
3051 void transform::TileUsingForOp::build(
3052  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
3053  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
3054  ArrayRef<int64_t> interchange,
3055  std::optional<ArrayRef<bool>> scalableSizes) {
3056  SmallVector<int64_t> staticTileSizes;
3057  SmallVector<Value> dynamicTileSizes;
3058  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3059  // Call the default builder which sets up the proper operand segment size
3060  // attributes for multiple variadic operands. In the absence of this,
3061  // horrible bugs ensue.
3062  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3063  unsigned numExpectedLoops =
3064  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
3065  SmallVector<Type> resultTypes;
3066  resultTypes.reserve(numExpectedLoops);
3067  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
3068  "expected one loop type or as many as loops");
3069  if (loopTypes.size() == 1)
3070  resultTypes.append(numExpectedLoops, loopTypes[0]);
3071  else
3072  llvm::append_range(resultTypes, loopTypes);
3073  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
3074  if (scalableSizes.has_value())
3075  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
3076  build(builder, result, /*tiled_linalg_op=*/target.getType(),
3077  /*loops=*/resultTypes,
3078  /*target=*/target,
3079  /*dynamic_sizes=*/dynamicTileSizes,
3080  /*static_sizes=*/staticTileSizesAttr,
3081  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
3082  /*scalable_sizes=*/expandedScalableSizes);
3083 }
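// Editorial example (illustrative only): for mixed tile sizes {4, %sz, 0},
// dispatchIndexOpFoldResults yields static_sizes = [4, kDynamic, 0] and
// dynamic_sizes = [%sz]; a zero entry means "do not tile that loop", so
// numExpectedLoops above evaluates to 2 and two loop handles are produced.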
3084 
3085 LogicalResult transform::TileUsingForOp::verify() {
3086  if (getMixedSizes().size() != getScalableSizes().size())
3087  return emitOpError("expected same number of sizes (")
3088  << getMixedSizes().size() << ") and scalable sizes ("
3089  << getScalableSizes().size() << ")";
3090  ArrayRef<int64_t> staticSizes = getStaticSizes();
3091  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3092  if (getLoops().size() != numExpectedLoops)
3093  return emitOpError("expected number of loops to tile (")
3094  << numExpectedLoops << ") to match number of `loops` results ("
3095  << getLoops().size() << ")";
3096  return success();
3097 }
3098 
3099 DiagnosedSilenceableFailure
3100 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3101  TransformResults &transformResults,
3102  TransformState &state) {
3103  ArrayRef<int64_t> tileSizes = getStaticSizes();
3104 
3105  SmallVector<Operation *> targets =
3106  llvm::to_vector(state.getPayloadOps(getTarget()));
3107  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
3108  SmallVector<SmallVector<int64_t>> paramSizes;
3109  dynamicSizeProducers.reserve(getDynamicSizes().size());
3110  paramSizes.reserve(getDynamicSizes().size());
3111  for (Value transformValue : getDynamicSizes()) {
3112  if (isa<ParamType>(transformValue.getType())) {
3113  dynamicSizeProducers.push_back({});
3114  ArrayRef<Attribute> params = state.getParams(transformValue);
3115  paramSizes.push_back(
3116  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3117  return cast<IntegerAttr>(attr).getValue().getSExtValue();
3118  })));
3119 
3120  if (paramSizes.back().size() != targets.size()) {
3121  DiagnosedSilenceableFailure diag =
3122  emitSilenceableError()
3123  << "expected as many parameter values ("
3124  << paramSizes.back().size() << ") as target ops ("
3125  << targets.size() << ")";
3126  diag.attachNote(transformValue.getLoc()) << "for this parameter";
3127  return diag;
3128  }
3129 
3130  continue;
3131  }
3132  paramSizes.push_back({});
3133  dynamicSizeProducers.push_back(
3134  llvm::to_vector(state.getPayloadOps(transformValue)));
3135 
3136  if (dynamicSizeProducers.back().size() != targets.size()) {
3137  DiagnosedSilenceableFailure diag =
3138  emitSilenceableError()
3139  << "expected as many dynamic size-producing operations ("
3140  << dynamicSizeProducers.back().size() << ") as target ops ("
3141  << targets.size() << ")";
3142  diag.attachNote(transformValue.getLoc()) << "for this handle";
3143  return diag;
3144  }
3145 
3146  for (Operation *op : dynamicSizeProducers.back()) {
3147  if (op->getNumResults() == 1 &&
3148  isa<IndexType>(op->getResult(0).getType())) {
3149  continue;
3150  }
3151 
3152  DiagnosedSilenceableFailure diag =
3153  emitSilenceableError() << "expected sizes to be produced by ops "
3154  "with a single index-type result";
3155  diag.attachNote(op->getLoc()) << "size producer op";
3156  diag.attachNote(transformValue.getLoc()) << "for this handle";
3157  return diag;
3158  }
3159  }
3160 
3161  SmallVector<Operation *> tiled;
3162  SmallVector<SmallVector<Operation *, 4>, 4> loops;
3163  loops.resize(getLoops().size());
3164  auto scalableSizes = getScalableSizes();
3165  for (auto [i, op] : llvm::enumerate(targets)) {
3166  auto tilingInterface = dyn_cast<TilingInterface>(op);
3167  if (!tilingInterface) {
3168  DiagnosedSilenceableFailure diag =
3169  emitSilenceableError()
3170  << "only ops implementing TilingInterface are supported";
3171  diag.attachNote(op->getLoc()) << "target op";
3172  return diag;
3173  }
3174  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3175  DiagnosedSilenceableFailure diag =
3176  emitSilenceableError()
3177  << "too many tiles provided, expected at most "
3178  << tilingInterface.getLoopIteratorTypes().size() << ", found "
3179  << tileSizes.size();
3180  diag.attachNote(op->getLoc()) << "target op";
3181  return diag;
3182  }
3183 
3184  scf::SCFTilingOptions tilingOptions;
3185  if (tileSizes.empty()) {
3186  tilingOptions.setTileSizeComputationFunction(
3187  [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3188  return {};
3189  });
3190  } else {
3191  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3192  Operation *) {
3193  SmallVector<OpFoldResult> sizes;
3194  sizes.reserve(tileSizes.size());
3195  unsigned dynamicIdx = 0;
3196 
3197  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3198  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3199  if (scalableSizes[ofrIdx]) {
3200  auto val = b.create<arith::ConstantIndexOp>(
3201  getLoc(), cast<IntegerAttr>(attr).getInt());
3202  Value vscale =
3203  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
3204  sizes.push_back(
3205  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
3206  } else {
3207  sizes.push_back(attr);
3208  }
3209  continue;
3210  }
3211  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3212  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3213  ++dynamicIdx;
3214  assert((dynamicSizes.empty() ^ params.empty()) &&
3215  "expected either dynamic sizes or parameters");
3216  if (!params.empty()) {
3217  sizes.push_back(b.getIndexAttr(params[index]));
3218  } else {
3219  sizes.push_back(dynamicSizes[index]->getResult(0));
3220  }
3221  }
3222  return sizes;
3223  });
3224  }
3225 
3226  tilingOptions.setInterchange(getInterchange());
3227  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3228  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3229  if (failed(maybeTilingResult))
3230  return DiagnosedSilenceableFailure::definiteFailure();
3231 
3232  rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
3233 
3234  tiled.append(maybeTilingResult->tiledOps);
3235  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3236  loops[en2.index()].push_back(en2.value());
3237  }
3238 
3239  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3240  for (const auto &en : llvm::enumerate(loops))
3241  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3242 
3243  return DiagnosedSilenceableFailure::success();
3244 }
3245 
3246 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3247  ValueRange dynamic = getDynamicSizes();
3248  ArrayRef<int64_t> tileSizes = getStaticSizes();
3249  SmallVector<OpFoldResult> results;
3250  results.reserve(tileSizes.size());
3251  unsigned dynamicPos = 0;
3252  Builder builder(getContext());
3253  for (int64_t size : tileSizes) {
3254  if (size == ShapedType::kDynamic) {
3255  results.push_back(dynamic[dynamicPos++]);
3256  } else {
3257  results.push_back(builder.getIndexAttr(size));
3258  }
3259  }
3260  return results;
3261 }
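// Editorial example (illustrative only): with static_sizes = [8, kDynamic, 0]
// and one dynamic operand %n, getMixedSizes() returns {attr(8), %n, attr(0)},
// substituting dynamic operands in their order of appearance.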
3262 
3263 void transform::TileUsingForOp::getEffects(
3264  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3265  consumesHandle(getTargetMutable(), effects);
3266  onlyReadsHandle(getDynamicSizesMutable(), effects);
3267  producesHandle(getOperation()->getOpResults(), effects);
3268  modifiesPayload(effects);
3269 }
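// Editorial sketch (hypothetical transform IR; the exact assembly syntax may
// differ between MLIR versions): tiling a payload op with sizes [8, 16]
// typically looks like
//   %tiled, %loops:2 = transform.structured.tile_using_for %op tile_sizes [8, 16]
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// where one loop handle is produced per non-zero tile size, matching the
// numExpectedLoops logic in the builder and verifier above.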
3270 
3271 //===----------------------------------------------------------------------===//
3272 // TileUsingForallOp
3273 //===----------------------------------------------------------------------===//
3274 
3275 void transform::TileUsingForallOp::build(OpBuilder &builder,
3276  OperationState &result, Value target,
3277  ArrayRef<int64_t> staticTileSizes,
3278  transform::TileSizesSpec,
3279  ArrayAttr mapping) {
3280  return build(builder, result,
3281  /*target=*/target,
3282  /*mixedTileSizes=*/
3283  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3284  /*_=*/TileSizesSpec(),
3285  /*mapping=*/mapping);
3286 }
3287 
3288 void transform::TileUsingForallOp::build(OpBuilder &builder,
3289  OperationState &result, Value target,
3290  ArrayRef<OpFoldResult> mixedTileSizes,
3291  transform::TileSizesSpec,
3292  ArrayAttr mapping) {
3293  SmallVector<int64_t> staticTileSizes;
3294  SmallVector<Value> dynamicTileSizes;
3295  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3296  // Call the default builder which sets up the proper operand segment size
3297  // attributes for multiple variadic operands. In the absence of this,
3298  // horrible bugs ensue.
3299  MLIRContext *ctx = builder.getContext();
3300  auto operationType = transform::AnyOpType::get(ctx);
3301  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3302  build(builder, result,
3303  /*resultTypes=*/TypeRange{operationType, operationType},
3304  /*target=*/target,
3305  /*num_threads=*/ValueRange{},
3306  /*tile_sizes=*/dynamicTileSizes,
3307  /*packed_num_threads=*/Value(),
3308  /*packed_tile_sizes=*/Value(),
3309  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3310  /*static_tile_sizes=*/staticTileSizesAttr,
3311  /*mapping=*/mapping);
3312 }
3313 
3314 void transform::TileUsingForallOp::build(OpBuilder &builder,
3315  OperationState &result, Value target,
3316  ArrayRef<int64_t> staticNumThreads,
3317  transform::NumThreadsSpec,
3318  ArrayAttr mapping) {
3319  return build(builder, result, target,
3320  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3321  NumThreadsSpec(), mapping);
3322 }
3323 
3324 void transform::TileUsingForallOp::build(OpBuilder &builder,
3325  OperationState &result, Value target,
3326  ArrayRef<OpFoldResult> mixedNumThreads,
3327  transform::NumThreadsSpec,
3328  ArrayAttr mapping) {
3329  SmallVector<int64_t> staticNumThreads;
3330  SmallVector<Value> dynamicNumThreads;
3331  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3332  staticNumThreads);
3333  // Call the default builder which sets up the proper operand segment size
3334  // attributes for multiple variadic operands. In the absence of this,
3335  // horrible bugs ensue.
3336  MLIRContext *ctx = builder.getContext();
3337  auto operationType = transform::AnyOpType::get(ctx);
3338  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3339  build(builder, result,
3340  /*resultTypes=*/TypeRange{operationType, operationType},
3341  /*target=*/target,
3342  /*num_threads=*/dynamicNumThreads,
3343  /*tile_sizes=*/ValueRange{},
3344  /*packed_num_threads=*/Value(),
3345  /*packed_tile_sizes=*/Value(),
3346  /*static_num_threads=*/staticNumThreadsAttr,
3347  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3348  /*mapping=*/mapping);
3349 }
3350 
3351 /// Given the `lbs`, `ubs` and `steps` of loops, return (for each loop) the
3352 /// normalized upper bound.
3353 static SmallVector<OpFoldResult>
3354 normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3355  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3356  ArrayRef<OpFoldResult> steps) {
3357  AffineExpr s0, s1, s2;
3358  bindSymbols(rewriter.getContext(), s0, s1, s2);
3359  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3360  SmallVector<OpFoldResult> normalizedUbs;
3361  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3362  OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
3363  rewriter, loc, normalizedUbExpr, {lb, ub, step});
3364  normalizedUbs.push_back(normalizedUb);
3365  }
3366  return normalizedUbs;
3367 }
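// Editorial example (illustrative only): for lb = 4, ub = 15, step = 3 the
// expression above yields ceilDiv(15 - 4, 3) = 4, i.e. the normalized loop
// iterates over [0, 4) with step 1.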
3368 
3369 /// When a loop is normalized, the uses of the induction variable within the
3370 /// loop need to be replaced with `original_lb + old_iv * original_step`.
3371 static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3372  Location loc, ValueRange ivs,
3373  ArrayRef<OpFoldResult> lbs,
3374  ArrayRef<OpFoldResult> steps) {
3375  AffineExpr s0, s1;
3376  AffineExpr d0;
3377  bindSymbols(rewriter.getContext(), s0, s1);
3378  bindDims(rewriter.getContext(), d0);
3379  AffineExpr denormExpr = s0 + d0 * s1;
3380  SmallVector<Value> denormalizedIvs;
3381 
3382  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3383  OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
3384  rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3385  denormalizedIvs.push_back(
3386  getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3387  }
3388  return denormalizedIvs;
3389 }
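// Editorial example (illustrative only): continuing the lb = 4, step = 3 case,
// the normalized IVs 0, 1, 2, 3 are mapped back to 4, 7, 10, 13 via
// lb + iv * step.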
3390 
3391 /// Given a `scf.forall` loop, return a loop op with the loop bounds
3392 /// normalized.
3393 /// TODO: Replace this with a general utility to normalize `scf.forall`.
3394 /// At the time of writing, this wasn't done since adding it to the `scf`
3395 /// dialect would disallow the use of `affine.apply` operations due
3396 /// to cyclic dependencies. To avoid churn in lit tests
3397 /// with the change this was added with, defer that to a follow-up.
3398 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3399  scf::ForallOp loop) {
3400  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3401  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3402  SmallVector<OpFoldResult> steps = loop.getMixedStep();
3403 
3404  if (llvm::all_of(
3405  lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
3406  llvm::all_of(
3407  steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
3408  return loop;
3409  }
3410 
3411  Location loc = loop.getLoc();
3412  SmallVector<OpFoldResult> normalizedUbs =
3413  normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3414  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3415  rewriter.getIndexAttr(0));
3416  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3417  rewriter.getIndexAttr(1));
3418 
3419  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
3420  loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3421  loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
3422 
3423  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3424  OpBuilder::InsertionGuard g(rewriter);
3425  Block *normalizedLoopBlock = normalizedForallOp.getBody();
3426  rewriter.setInsertionPointToStart(normalizedLoopBlock);
3427 
3428  SmallVector<Value> argValues =
3429  denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3430  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3431  normalizedForallOp.getRegionIterArgs().end());
3432  Block *origLoopBlock = loop.getBody();
3433  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3434 
3435  rewriter.replaceOp(loop, normalizedForallOp);
3436  return normalizedForallOp;
3437 }
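// Editorial sketch (illustrative only, syntax approximate): a loop such as
//   scf.forall (%i) = (4) to (15) step (3) { ... use %i ... }
// is rewritten by normalizeForallLoopOp into a zero-based, unit-step loop
//   scf.forall (%j) = (0) to (4) step (1) { ... }
// with the original induction variable reconstructed inside the body as
// 4 + %j * 3 by denormalizeIndVar.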
3438 
3439 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
3440  RewriterBase &rewriter, transform::TransformState &state,
3441  TransformOpInterface transformOp, Operation *target,
3442  ArrayRef<OpFoldResult> mixedNumThreads,
3443  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3444  scf::SCFTilingResult &tilingResult) {
3445  // Transform all targets one by one.
3446  auto tileableOp = dyn_cast<TilingInterface>(target);
3447  if (!tileableOp) {
3449  transformOp.emitSilenceableError()
3450  << "only TilingInterface ops are supported";
3451  diag.attachNote(target->getLoc()) << "target op";
3452  return diag;
3453  }
3454  rewriter.setInsertionPoint(tileableOp);
3455  scf::SCFTilingOptions options;
3456  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3457  if (!mixedNumThreads.empty()) {
3458  options.setNumThreads(mixedNumThreads);
3459  } else {
3460  options.setTileSizes(mixedTileSizes);
3461  }
3462  if (mapping) {
3463  options.setMapping(mapping.value().getValue());
3464  }
3465  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3466  scf::tileUsingSCF(rewriter, tileableOp, options);
3467 
3468  if (failed(maybeTilingResult))
3469  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3470 
3471  rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3472 
3473  tilingResult = *maybeTilingResult;
3474 
3475  if (mixedNumThreads.empty()) {
3476  auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3477  OpBuilder::InsertionGuard g(rewriter);
3478  rewriter.setInsertionPoint(generatedForallOp);
3479  scf::ForallOp normalizedForallOp =
3480  normalizeForallLoopOp(rewriter, generatedForallOp);
3481  tilingResult.loops.front() = normalizedForallOp;
3482  }
3483 
3484  return DiagnosedSilenceableFailure::success();
3485 }
3486 
3487 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3488  transform::TransformRewriter &rewriter,
3489  transform::TransformResults &transformResults,
3490  transform::TransformState &state) {
3491  auto transformOp = cast<TransformOpInterface>(getOperation());
3492 
3493  // Result payload ops.
3494  SmallVector<Operation *> tileOps;
3495  SmallVector<Operation *> tiledOps;
3496 
3497  // Unpack handles.
3498  SmallVector<OpFoldResult> mixedNumThreads;
3499  DiagnosedSilenceableFailure status =
3500  getPackedNumThreads()
3501  ? unpackSingleIndexResultPayloadOperations(
3502  state, transformOp, mixedNumThreads, getPackedNumThreads())
3503  : unpackSingleIndexResultPayloadOperations(
3504  state, transformOp, mixedNumThreads, getMixedNumThreads());
3505  if (!status.succeeded())
3506  return status;
3507  SmallVector<OpFoldResult> mixedTileSizes;
3508  status = getPackedTileSizes()
3509  ? unpackSingleIndexResultPayloadOperations(
3510  state, transformOp, mixedTileSizes, getPackedTileSizes())
3511  : unpackSingleIndexResultPayloadOperations(
3512  state, transformOp, mixedTileSizes, getMixedTileSizes());
3513  if (!status.succeeded())
3514  return status;
3515 
3516  for (Operation *target : state.getPayloadOps(getTarget())) {
3517  scf::SCFTilingResult tilingResult;
3518  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3519  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3520  getMapping(), tilingResult);
3521  if (!diag.succeeded())
3522  return diag;
3523  tileOps.push_back(tilingResult.loops.front());
3524  tiledOps.append(tilingResult.tiledOps);
3525  }
3526 
3527  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3528  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3529 
3530  return DiagnosedSilenceableFailure::success();
3531 }
3532 
3533 void transform::TileUsingForallOp::getEffects(
3534  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3535  consumesHandle(getTargetMutable(), effects);
3536  onlyReadsHandle(getTileSizesMutable(), effects);
3537  onlyReadsHandle(getNumThreadsMutable(), effects);
3538  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3539  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3540  producesHandle(getOperation()->getOpResults(), effects);
3541  modifiesPayload(effects);
3542 }
3543 
3544 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3545  Builder b(getContext());
3546  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3547 }
3548 
3549 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3550  Builder b(getContext());
3551  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3552 }
3553 
3554 LogicalResult TileUsingForallOp::verify() {
3555  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3556  static_cast<int>(getPackedNumThreads() != Value());
3557  if (numThreadsSpec > 1)
3558  return emitOpError(
3559  "num_threads and packed_num_threads are mutually exclusive");
3560  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3561  static_cast<int>(getPackedTileSizes() != Value());
3562  if (tileSizesSpec > 1)
3563  return emitOpError(
3564  "tile_sizes and packed_tile_sizes are mutually exclusive");
3565  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3566  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3567  "must be specified");
3568  return success();
3569 }
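// Editorial note (restating the checks above with an example): at most one of
// num_threads / packed_num_threads and at most one of tile_sizes /
// packed_tile_sizes may be present, and at least one of the four must be; e.g.
// providing both num_threads and packed_num_threads is rejected, while
// providing only tile_sizes verifies.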
3570 
3571 //===----------------------------------------------------------------------===//
3572 // VectorizeChildrenAndApplyPatternsOp
3573 //===----------------------------------------------------------------------===//
3574 
3575 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3576  OpBuilder &builder, OperationState &result, Value target,
3577  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3578  result.addOperands(target);
3579  if (vectorizePadding) {
3580  result.addAttribute(
3581  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3582  result.name),
3583  builder.getUnitAttr());
3584  }
3585  if (vectorizeExtract) {
3586  result.addAttribute(
3587  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3588  result.name),
3589  builder.getUnitAttr());
3590  }
3591  if (flatten1DDepthwiseConv) {
3592  result.addAttribute(
3593  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3594  result.name),
3595  builder.getUnitAttr());
3596  }
3597  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3598 }
3599 
3600 namespace {
3601 /// This is a helper used only to call vectorize via a pattern inside of
3602 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3603 struct VectorizationPattern : public RewritePattern {
3604  explicit VectorizationPattern(MLIRContext *context,
3605  bool vectorizeExtract = false,
3606  bool flattenConv = false)
3607  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3608  vectorizeNDExtract(vectorizeExtract),
3609  flatten1DDepthwiseConv(flattenConv) {}
3610  LogicalResult matchAndRewrite(Operation *op,
3611  PatternRewriter &rewriter) const override {
3612  if (!linalg::hasVectorizationImpl(op))
3613  return rewriter.notifyMatchFailure(op,
3614  "Unsupported Op, cannot vectorize");
3615  return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3616  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3617  flatten1DDepthwiseConv);
3618  }
3619 
3620 private:
3621  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3622  /// rank >= 2.
3623  bool vectorizeNDExtract = false;
3624  /// Controls whether to "flatten" the channel dimension when vectorising 1D
3625  /// depthwise convolutions. This should lead to better vectorization for
3626  /// tensors with a low number of channel dimensions.
3627  bool flatten1DDepthwiseConv = false;
3628 };
3629 } // namespace
3630 
3631 DiagnosedSilenceableFailure
3632 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3633  transform::TransformRewriter &rewriter, Operation *target,
3634  transform::ApplyToEachResultList &results,
3635  transform::TransformState &state) {
3636  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3637  auto diag = this->emitOpError("requires isolated-from-above targets");
3638  diag.attachNote(target->getLoc()) << "non-isolated target";
3639  return DiagnosedSilenceableFailure::definiteFailure();
3640  }
3641 
3642  MLIRContext *ctx = getContext();
3643  RewritePatternSet patterns(ctx);
3644  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3645  getFlatten_1dDepthwiseConv());
3646 
3647  if (!getDisableTransferPermutationMapLoweringPatterns())
3648  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3649 
3650  if (!getDisableMultiReductionToContractPatterns())
3651  vector::populateVectorReductionToContractPatterns(patterns);
3652 
3654 
3657  /*benefit=*/2);
3658  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3659  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3661 
3663 
3664  if (getVectorizePadding()) {
3665  linalg::populatePadOpVectorizationPatterns(patterns);
3666  // This creates an alternative path for lowering tensor.pad - by
3667  // decomposing it into e.g. linalg.fill.
3669  }
3671 
3672  TrackingListener listener(state, *this);
3673  if (failed(
3674  applyPatternsGreedily(target, std::move(patterns),
3675  GreedyRewriteConfig().setListener(&listener))))
3676  return emitDefaultDefiniteFailure(target);
3677 
3678  results.push_back(target);
3679  return DiagnosedSilenceableFailure::success();
3680 }
3681 
3682 //===----------------------------------------------------------------------===//
3683 // VectorizeOp
3684 //===----------------------------------------------------------------------===//
3685 
3686 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3687  transform::TransformRewriter &rewriter,
3688  mlir::transform::TransformResults &transformResults,
3689  mlir::transform::TransformState &state) {
3690  auto targets = state.getPayloadOps(getTarget());
3691  if (std::empty(targets))
3692  return DiagnosedSilenceableFailure::success();
3693  auto transformOp = cast<TransformOpInterface>(getOperation());
3694  SmallVector<int64_t> vectorSizes;
3695  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3696  state, transformOp, getMixedVectorSizes(), vectorSizes);
3697  if (!status.succeeded())
3698  return status;
3699 
3700  // TODO: Check that the correct number of vectorSizes was provided.
3701  for (Operation *target : targets) {
3702  if (!linalg::hasVectorizationImpl(target)) {
3703  return mlir::emitSilenceableFailure(target->getLoc())
3704  << "Unsupported Op, cannot vectorize";
3705  }
3706 
3707  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3708  getScalableSizes(),
3709  getVectorizeNdExtract().value_or(false)))) {
3710  return mlir::emitSilenceableFailure(target->getLoc())
3711  << "Attempted to vectorize, but failed";
3712  }
3713  }
3714 
3715  return DiagnosedSilenceableFailure::success();
3716 }
3717 
3718 void transform::VectorizeOp::getEffects(
3719  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3720  consumesHandle(getTargetMutable(), effects);
3721  onlyReadsHandle(getVectorSizesMutable(), effects);
3722  modifiesPayload(effects);
3723 }
3724 
3725 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3726  OpBuilder b(getContext());
3727  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3728 }
3729 
3730 LogicalResult transform::VectorizeOp::verify() {
3731  if (getStaticVectorSizes().size() != getScalableSizes().size())
3732  return emitOpError("expected same number of vector sizes (")
3733  << getStaticVectorSizes().size() << ") and scalable sizes ("
3734  << getScalableSizes().size() << ")";
3735  return success();
3736 }
3737 
3738 //===----------------------------------------------------------------------===//
3739 // HoistRedundantVectorTransfersOp
3740 //===----------------------------------------------------------------------===//
3741 
3742 DiagnosedSilenceableFailure
3743 transform::HoistRedundantVectorTransfersOp::applyToOne(
3744  transform::TransformRewriter &rewriter, func::FuncOp target,
3745  transform::ApplyToEachResultList &results,
3746  transform::TransformState &state) {
3747  // WARNING: This hoisting does not model parallelism and is generally
3748  // incorrect when used on distributed loops with memref semantics!
3749  // TODO: obsolete and should be retired.
3750  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3751  results.push_back(target);
3752  return DiagnosedSilenceableFailure::success();
3753 }
3754 
3755 //===----------------------------------------------------------------------===//
3756 // HoistRedundantVectorBroadcastsOp
3757 //===----------------------------------------------------------------------===//
3758 
3759 DiagnosedSilenceableFailure
3760 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3761  transform::TransformRewriter &rewriter, mlir::Operation *target,
3762  transform::ApplyToEachResultList &results,
3763  transform::TransformState &state) {
3764  rewriter.setInsertionPoint(target);
3765  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3766  results.push_back(target);
3767  return DiagnosedSilenceableFailure::success();
3768 }
3769 
3770 //===----------------------------------------------------------------------===//
3771 // ConvertConv2DToImg2ColOp.
3772 //===----------------------------------------------------------------------===//
3773 
3774 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3775  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3776  transform::ApplyToEachResultList &results,
3777  transform::TransformState &state) {
3778  rewriter.setInsertionPoint(target);
3779  auto maybeTransformed =
3780  TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3781  target)
3782  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3783  return rewriteInIm2Col(rewriter, op);
3784  })
3785  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3786  return rewriteInIm2Col(rewriter, op);
3787  })
3788  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3789  return rewriteInIm2Col(rewriter, op);
3790  })
3791  .Case([&](linalg::Conv2DNchwFchwOp op) {
3792  return rewriteInIm2Col(rewriter, op);
3793  })
3794  .Default([&](Operation *op) {
3795  return rewriter.notifyMatchFailure(op, "not supported");
3796  });
3797  if (failed(maybeTransformed))
3798  return emitDefaultSilenceableFailure(target);
3799  // Handle to the operation producing the img2col tensor.
3800  results.push_back(maybeTransformed->first);
3801  // Handle to the operation that replaces the original convolution.
3802  results.push_back(maybeTransformed->second);
3803  return DiagnosedSilenceableFailure::success();
3804 }
3805 
3806 //===----------------------------------------------------------------------===//
3807 // FlattenElementwiseLinalgOp.
3808 //===----------------------------------------------------------------------===//
3809 
3810 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3811  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3812  transform::ApplyToEachResultList &results,
3813  transform::TransformState &state) {
3814  rewriter.setInsertionPoint(target);
3815  if (!isElementwise(target))
3816  return mlir::emitSilenceableFailure(target->getLoc())
3817  << "only elementwise flattening is supported";
3818 
3819  // If rank <= 1, do nothing
3820  if (target.getNumLoops() <= 1) {
3821  results.push_back(target);
3822  return DiagnosedSilenceableFailure::success();
3823  }
3824 
3825  // Attempt to flatten all dims to one.
3826  ReassociationIndices reassociation(target.getNumLoops());
3827  std::iota(reassociation.begin(), reassociation.end(), 0);
3828  auto maybeFlattened =
3829  collapseOpIterationDims(target, reassociation, rewriter);
3830  if (failed(maybeFlattened))
3831  return mlir::emitSilenceableFailure(target->getLoc())
3832  << "attempted to flatten, but failed";
3833  results.push_back(maybeFlattened->collapsedOp);
3834  rewriter.replaceOp(target, maybeFlattened->results);
3835  return DiagnosedSilenceableFailure::success();
3836 }
3837 
3838 //===----------------------------------------------------------------------===//
3839 // TransposeConv2DOp
3840 //===----------------------------------------------------------------------===//
3841 
3842 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3843  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3844  transform::ApplyToEachResultList &results,
3845  transform::TransformState &state) {
3846  rewriter.setInsertionPoint(target);
3847  auto maybeTransformed =
3848  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3849  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3850  return transposeConv2D(rewriter, op);
3851  })
3852  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3853  return transposeConv2D(rewriter, op);
3854  })
3855  .Default([&](Operation *op) {
3856  return rewriter.notifyMatchFailure(op, "not supported");
3857  });
3858  if (failed(maybeTransformed))
3859  return emitDefaultSilenceableFailure(target);
3860  // Handle to the new Conv2D operation with transposed filters
3861  results.push_back(*maybeTransformed);
3862  return DiagnosedSilenceableFailure::success();
3863 }
3864 
3865 //===----------------------------------------------------------------------===//
3866 // TransposeMatmulOp
3867 //===----------------------------------------------------------------------===//
3868 
3869 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3870  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3871  transform::ApplyToEachResultList &results,
3872  transform::TransformState &state) {
3873  rewriter.setInsertionPoint(target);
3874  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3875  auto maybeTransformed =
3876  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3877  .Case([&](linalg::MatmulOp op) {
3878  return transposeMatmul(rewriter, op, transposeLHS);
3879  })
3880  .Case([&](linalg::BatchMatmulOp op) {
3881  return transposeBatchMatmul(rewriter, op, transposeLHS);
3882  })
3883  .Default([&](Operation *op) { return failure(); });
3884  if (failed(maybeTransformed))
3885  return emitSilenceableFailure(target->getLoc()) << "not supported";
3886  // Handle to the new Matmul operation with one operand transposed
3887  results.push_back(*maybeTransformed);
3888  return DiagnosedSilenceableFailure::success();
3889 }
3890 
3891 //===----------------------------------------------------------------------===//
3892 // InsertSliceToCopyOp
3893 //===----------------------------------------------------------------------===//
3894 template <typename OpTy>
3895 static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3896  transform::ApplyToEachResultList &results,
3897  transform::TransformState &state) {
3898  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3899  tensor::ParallelInsertSliceOp>() &&
3900  "wrong op type");
3901 
3902  if (auto copySource =
3903  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3904  results.push_back(copySource);
3905  return DiagnosedSilenceableFailure::success();
3906  }
3907 
3908  // If we are inside an InParallel region, temporarily set the insertion point
3909  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3910  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3911  rewriter.setInsertionPoint(
3912  target->template getParentOfType<scf::InParallelOp>());
3913  }
3914 
3915  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3916  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3917  target.getMixedSizes(), target.getMixedStrides());
3918  Value copied = rewriter
3919  .create<linalg::CopyOp>(target.getLoc(),
3920  target.getSource(), extracted)
3921  .getResult(0);
3922  // Reset the insertion point.
3923  rewriter.setInsertionPoint(target);
3924  rewriter.replaceOpWithNewOp<OpTy>(
3925  target, copied, target.getDest(), target.getMixedOffsets(),
3926  target.getMixedSizes(), target.getMixedStrides());
3927 
3928  results.push_back(copied.getDefiningOp());
3929  return DiagnosedSilenceableFailure::success();
3930 }
3931 
3932 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3933  transform::TransformRewriter &rewriter, Operation *targetOp,
3934  transform::ApplyToEachResultList &results,
3935  transform::TransformState &state) {
3936 
3937  rewriter.setInsertionPoint(targetOp);
3938  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3939  return doit(rewriter, target, results, state);
3940  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3941  return doit(rewriter, target, results, state);
3942 
3943  DiagnosedSilenceableFailure diag =
3944  emitSilenceableError()
3945  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3946  diag.attachNote(targetOp->getLoc()) << "target op";
3947  return diag;
3948 }
3949 
3950 //===----------------------------------------------------------------------===//
3951 // MapCopyToThreadsOp
3952 //===----------------------------------------------------------------------===//
3953 
3954 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3955  transform::TransformRewriter &rewriter, Operation *target,
3956  transform::ApplyToEachResultList &results,
3957  transform::TransformState &state) {
3958  // Check if the op is supported.
3959  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3960  DiagnosedSilenceableFailure diag =
3961  emitSilenceableError()
3962  << "only linalg.copy and tensor.pad target ops are supported";
3963  diag.attachNote(target->getLoc()) << "target op";
3964  return diag;
3965  }
3966  assert(target->getNumResults() == 1 && "expected single result");
3967  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3968  if (!resultShapedType.hasStaticShape()) {
3969  DiagnosedSilenceableFailure diag =
3970  emitSilenceableError()
3971  << "only statically sized ops of rank <= 3 are supported";
3972  diag.attachNote(target->getLoc()) << "target op";
3973  return diag;
3974  }
3975 
3976  // Conservatively set the minimum viable desired bitwidth alignment.
3977  int64_t desiredBitAlignment = getDesiredBitAlignment();
3978  int64_t eltBitwidth =
3979  resultShapedType.getElementType().getIntOrFloatBitWidth();
3980  if (desiredBitAlignment % eltBitwidth != 0) {
3981  desiredBitAlignment = eltBitwidth;
3982  }
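  // Editorial example (illustrative only): with getDesiredBitAlignment() = 128
  // and an f32 element type (32 bits), 128 % 32 == 0 and the alignment is
  // kept; for an f64 element (64 bits) and a desired alignment of 96, the
  // fallback above resets the alignment to 64.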
3983 
3984  gpu::CopyMappingInfo mapping(
3985  /*ctx=*/getContext(),
3986  /*totalNumThreads=*/getTotalNumThreads(),
3987  /*alignment=*/desiredBitAlignment,
3988  /*sizes=*/resultShapedType.getShape(),
3989  /*favorPredication=*/false,
3990  /*elementalBitwidth=*/
3991  resultShapedType.getElementType().getIntOrFloatBitWidth());
3992  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3993  DiagnosedSilenceableFailure diag =
3994  emitSilenceableError()
3995  << "too few threads to map copy op to threads on the most minor "
3996  "dimension, given alignment and vector size constraints, try "
3997  "smaller tile size or mapping to more threads";
3998  diag.attachNote(target->getLoc()) << "target op";
3999  return diag;
4000  }
4001 
4002  // OpBuilder only used to compute attributes.
4003  OpBuilder b(getContext());
4004  scf::SCFTilingResult tilingResult;
4005  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
4006  /*rewriter=*/rewriter,
4007  /*state=*/state,
4008  /*transformOp=*/*this,
4009  /*target=*/target,
4010  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
4011  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
4012  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
4013  /*tilingResult=*/tilingResult);
4014  if (!diag.succeeded())
4015  return diag;
4016 
4017  results.push_back(tilingResult.loops.front());
4018  for (auto op : tilingResult.tiledOps)
4019  results.push_back(op);
4020  return DiagnosedSilenceableFailure::success();
4021 }
4022 
4023 //===----------------------------------------------------------------------===//
4024 // WinogradConv2DOp
4025 //===----------------------------------------------------------------------===//
4026 
4027 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
4028  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
4029  transform::ApplyToEachResultList &results,
4030  transform::TransformState &state) {
4031  rewriter.setInsertionPoint(target);
4032  FailureOr<Operation *> maybeTransformed = failure();
4033  bool supported = TypeSwitch<Operation *, bool>(target)
4034  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
4035  maybeTransformed =
4036  winogradConv2D(rewriter, op, getM(), getR());
4037  return true;
4038  })
4039  .Default([&](Operation *op) { return false; });
4040 
4041  if (!supported) {
4042  return emitSilenceableError()
4043  << "this operation is not supported to convert to Winograd Conv2D";
4044  }
4045 
4046  if (failed(maybeTransformed)) {
4047  return emitSilenceableError() << "apply Winograd Conv2D failed";
4048  }
4049 
4050  results.push_back(*maybeTransformed);
4051  return DiagnosedSilenceableFailure::success();
4052 }
4053 
4054 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
4055  transform::TransformRewriter &rewriter, Operation *target,
4056  transform::ApplyToEachResultList &results,
4057  transform::TransformState &state) {
4058  rewriter.setInsertionPoint(target);
4059  FailureOr<Operation *> maybeTransformed = failure();
4060  bool supported =
4061  TypeSwitch<Operation *, bool>(target)
4062  .Case([&](linalg::WinogradFilterTransformOp op) {
4063  maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
4064  return true;
4065  })
4066  .Case([&](linalg::WinogradInputTransformOp op) {
4067  maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
4068  return true;
4069  })
4070  .Case([&](linalg::WinogradOutputTransformOp op) {
4071  maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
4072  return true;
4073  })
4074  .Default([&](Operation *op) { return false; });
4075 
4076  if (!supported) {
4077  DiagnosedSilenceableFailure diag =
4078  emitSilenceableError()
4079  << "this operation is not supported to decompose into other operations";
4080  diag.attachNote(target->getLoc()) << "target op";
4081  return diag;
4082  }
4083 
4084  if (failed(maybeTransformed)) {
4085  DiagnosedSilenceableFailure diag =
4086  emitSilenceableError() << "decompose Winograd operations failed";
4087  diag.attachNote(target->getLoc()) << "target op";
4088  return diag;
4089  }
4090 
4091  results.push_back(*maybeTransformed);
4092  return DiagnosedSilenceableFailure::success();
4093 }
4094 
4095 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4096 
4097 #define GET_OP_CLASSES
4098 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
#define DOWNSCALE(trans)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static bool sameOrEquivalentIterArg(Value src, Value dst)
Given two operands coming from a loop iter arg, 'src' and 'dst', return true if the operand 'src' is ...
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
#define DBGS()
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:364
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:88
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:302
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
This class represents a saved insertion point.
Definition: Builders.h:324
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:334
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:313
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:428
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:317
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:409
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:433
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:204
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1270
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1174
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1224
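A minimal usage sketch (not taken from the documented file), assuming an in-scope OpBuilder b, Location loc, and two index-typed OpFoldResults lhs and rhs; it builds d0 + d1 and folds it over the two operands:
// Build the affine expression d0 + d1 and apply it with constant folding.
AffineExpr d0, d1;
bindDims(b.getContext(), d0, d1);
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1);
OpFoldResult sum =
    affine::makeComposedFoldedAffineApply(b, loc, map, {lhs, rhs});
If both operands fold to constants, sum is an IntegerAttr; otherwise an affine.apply composed with any producing affine.apply ops is materialized.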
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
Definition: Padding.cpp:153
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:469
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:679
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:260
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:360
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:860
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:510
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:494
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:485
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
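A hedged fragment showing the intended call sequence, assuming an in-scope RewriterBase rewriter and a payload Operation *candidate inside a function returning LogicalResult:
// Only attempt vectorization for ops the Linalg vectorizer knows about.
if (!linalg::hasVectorizationImpl(candidate))
  return failure();
// Use the default behavior: vector sizes are inferred from the static shapes.
if (failed(linalg::vectorize(rewriter, candidate)))
  return failure();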
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:398
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
Definition: Promotion.cpp:502
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:223
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:50
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:242
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:162
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:111
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:97
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition: Tiling.cpp:594
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:770
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:482
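A short sketch of packing with fixed inner tile sizes, assuming an in-scope RewriterBase rewriter and a linalg::LinalgOp linalgOp (the sizes 8/16/32 are illustrative only):
// Build index-typed OpFoldResults for the packed sizes, one per packed loop.
SmallVector<OpFoldResult> packedSizes;
for (int64_t size : {8, 16, 32})
  packedSizes.push_back(getAsIndexOpFoldResult(rewriter.getContext(), size));
FailureOr<linalg::PackResult> packed =
    linalg::pack(rewriter, linalgOp, packedSizes);
if (failed(packed))
  return failure();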
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:420
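A minimal promotion sketch, assuming an in-scope OpBuilder b and a LinalgOp op; the promoted operand indices and alignment are illustrative:
// Promote the first two operands into aligned local buffers.
linalg::LinalgPromotionOptions options;
options.setOperandsToPromote({0, 1})
    .setAlignment(16)
    .setUseFullTileBuffersByDefault(true);
if (failed(linalg::promoteSubviewsPrecondition(op, options)))
  return failure();
FailureOr<linalg::LinalgOp> promoted = linalg::promoteSubViews(b, op, options);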
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:442
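A sketch of driving splitReduction with a constant control function, assuming an in-scope RewriterBase rewriter and LinalgOp op; the SplitReductionOptions field names (ratio, index, innerParallel) are taken from Transforms.h and the chosen values are illustrative:
// Always split the reduction dimension by a factor of 4, inserting the new
// parallel dimension at position 0.
linalg::ControlSplitReductionFn control = [](linalg::LinalgOp) {
  return linalg::SplitReductionOptions{/*ratio=*/4, /*index=*/0,
                                       /*innerParallel=*/false};
};
FailureOr<linalg::SplitReductionResult> split =
    linalg::splitReduction(rewriter, op, control);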
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
Definition: Promotion.cpp:478
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:202
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:268
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:224
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
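A hedged sketch of tiling through this entry point, assuming an in-scope RewriterBase rewriter and a TilingInterface tilingInterfaceOp; the tile sizes 32/64 are illustrative:
// Configure static tile sizes as index OpFoldResults.
scf::SCFTilingOptions options;
SmallVector<OpFoldResult> tileSizes;
for (int64_t ts : {32, 64})
  tileSizes.push_back(getAsIndexOpFoldResult(rewriter.getContext(), ts));
options.setTileSizes(tileSizes);
FailureOr<scf::SCFTilingResult> tiled =
    scf::tileUsingSCF(rewriter, tilingInterfaceOp, options);
if (failed(tiled))
  return failure();
// tiled->tiledOps and tiled->loops now hold the tiled op and the loop nest.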
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:604
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
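These helpers are typically combined in a transform op's getEffects(). A sketch for a hypothetical MyTransformOp that consumes its operand handles, produces its result handles, and mutates the payload:
// MyTransformOp is illustrative; real transform ops in this file follow the
// same shape.
void MyTransformOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getOperation()->getOpOperands(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}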
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:23
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
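A small driver sketch, assuming an in-scope Region &region inside a function returning LogicalResult; the populated pattern sets are examples of the populate* entry points listed on this page:
MLIRContext *ctx = region.getContext();
RewritePatternSet patterns(ctx);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
// The RewritePatternSet is frozen implicitly when passed by rvalue.
if (failed(applyPatternsGreedily(region, std::move(patterns))))
  return failure();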
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
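A round-trip sketch between the split static/dynamic form and the OpFoldResult form; the accessors getStaticOffsets()/getOffsets() are hypothetical stand-ins for an op using the usual ShapedType::kDynamic convention:
// Merge static values and dynamic SSA values into one OpFoldResult list ...
SmallVector<OpFoldResult> mixed =
    getMixedValues(op.getStaticOffsets(), op.getOffsets(), op.getContext());
// ... and split them back apart.
SmallVector<Value> dynamicVec;
SmallVector<int64_t> staticVec;
dispatchIndexOpFoldResults(mixed, dynamicVec, staticVec);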
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:282
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:421
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:422
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:473
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1505
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1437
Match and rewrite for the pattern:
Definition: Transforms.h:1627
Match and rewrite for the pattern:
Definition: Transforms.h:1655
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:379
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:385
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:398
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:418
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:368
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:392
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:408
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:357
Split Reduction options.
Definition: Transforms.h:427
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.