//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;

#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
template <typename PatternTy, typename... Args>
static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
  // Check if the given operation has the type expected by the pattern.
  using OpTy = typename llvm::function_traits<
      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
  auto op = dyn_cast<OpTy>(operation);
  if (!op)
    return failure();

  // Apply the pattern directly to the op.
  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
  // We want to discourage direct use of PatternRewriter in APIs, but in this
  // very specific case, an IRRewriter is not enough.
  struct TrivialPatternRewriter : public PatternRewriter {
  public:
    explicit TrivialPatternRewriter(MLIRContext *context)
        : PatternRewriter(context) {}
  };
  TrivialPatternRewriter rewriter(operation->getContext());
  rewriter.setInsertionPoint(operation);
  auto result = pattern.returningMatchAndRewrite(op, rewriter);
  if (failed(result))
    return failure();
  return cast<LinalgOp>(result->getOperation());
}
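
// Illustrative sketch (not from the upstream file; `payloadOp` is a
// hypothetical payload operation): a caller holding a 2-D convolution can run
// a single decomposition pattern through `tryApply`:
//
//   FailureOr<LinalgOp> decomposed =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(
//           payloadOp);
//   if (succeeded(decomposed)) {
//     // `*decomposed` is the rewritten 1-D convolution.
//   }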
88 
89 /// Assuming that `ofr` is an index attr or a param of index type
90 /// or a transform dialect handle mapped to exactly one op
91 /// with one index result, return that value.
93  transform::TransformState &state, TransformOpInterface transformOp,
95  for (OpFoldResult ofr : ofrs) {
96  if (ofr.is<Attribute>()) {
97  if (!isa<IntegerAttr>(ofr.get<Attribute>()))
98  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
99  result.push_back(ofr);
100  continue;
101  }
102 
103  Value transformValue = ofr.get<Value>();
104  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
105  ArrayRef<Attribute> params = state.getParams(transformValue);
106  if (params.size() != 1)
107  return transformOp.emitDefiniteFailure()
108  << "requires exactly one parameter associated";
109  result.push_back(params[0]);
110  continue;
111  }
112 
113  auto payloadOps = state.getPayloadOps(transformValue);
114  if (!llvm::hasSingleElement(payloadOps)) {
116  transformOp.emitSilenceableError()
117  << "handle must be mapped to exactly one payload op";
118  diag.attachNote(transformValue.getLoc())
119  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
120  return diag;
121  }
122 
123  Operation *op = *payloadOps.begin();
124  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
126  transformOp.emitSilenceableError()
127  << "payload op must have exactly 1 index result";
128  diag.attachNote(op->getLoc())
129  << "has " << op->getNumResults() << " results";
130  return diag;
131  }
132  result.push_back(op->getResult(0));
133  }
134 
136 }
137 
// Given a list of params that are index attrs, or a list of OpFoldResults
// that are either index attrs or op handles, return a list of OpFoldResults
// of index attrs, or a list of OpFoldResults where all op handles are
// replaced with the first (and only) OpResult of that payload op.
// (There must be exactly one parameter associated with the AnyParamType, or
// one mapped payload op which must have exactly one index result.)
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
    transform::TransformState &state, TransformOpInterface transformOp,
    SmallVector<OpFoldResult> &result, Value packedHandle) {
  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
    ArrayRef<Attribute> params = state.getParams(packedHandle);
    for (auto param : params) {
      if (!isa<IntegerAttr>(param))
        return transformOp.emitDefiniteFailure()
               << "expected the parameter to be associated with an integer "
                  "attribute";
      result.push_back(param);
    }
    return DiagnosedSilenceableFailure::success();
  }

  for (Operation *op : state.getPayloadOps(packedHandle)) {
    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
      DiagnosedSilenceableFailure diag =
          transformOp.emitSilenceableError()
          << "payload op must have exactly 1 index result";
      diag.attachNote(op->getLoc())
          << "has " << op->getNumResults() << " results";
      return diag;
    }
    result.push_back(op->getResult(0));
  }

  return DiagnosedSilenceableFailure::success();
}
/// When possible, converts each `OpFoldResult` in `mixedResults` to
/// an integer if the value can be statically inferred. If a result
/// is a `Value` then it must be either a `ParamType` or a handle
/// to a constant-like op.
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
    TransformState &state, TransformOpInterface &transformOp,
    ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
  for (OpFoldResult paramOrHandle : mixedResults) {
    if (isa<Attribute>(paramOrHandle)) {
      reified.push_back(
          cast<IntegerAttr>(paramOrHandle.get<Attribute>()).getInt());
      continue;
    } else if (isa<ParamType>(paramOrHandle.get<Value>().getType())) {
      ArrayRef<Attribute> params = state.getParams(paramOrHandle.get<Value>());
      if (params.size() != 1)
        return transformOp.emitSilenceableError() << "expected a single param";
      reified.push_back(
          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
      continue;
    }

    Value handle = paramOrHandle.get<Value>();
    if (!isa<TransformHandleTypeInterface>(handle.getType()))
      return transformOp.emitSilenceableError() << "unexpected value handle";
    auto payload = state.getPayloadOps(handle);
    if (!llvm::hasSingleElement(payload))
      return transformOp.emitSilenceableError()
             << "requires param or handle that is mapped to 1 payload op";

    Operation *paramOrHandlePayloadOp = *payload.begin();
    if (paramOrHandlePayloadOp->getNumResults() != 1 ||
        !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
      return transformOp.emitSilenceableError()
             << "requires param or handle to be result of op with 1 index "
                "result";
    }

    IntegerAttr attr;
    if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
      return transformOp.emitSilenceableError()
             << "requires param or handle to be the result of a constant like "
                "op";

    reified.push_back(attr.getInt());
  }
  return DiagnosedSilenceableFailure::success();
}
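
// Illustrative transform IR (sketch; handle names are hypothetical): a size
// consumed through the helpers above may be a literal index attribute, a
// !transform.param, or a handle that maps to exactly one payload op with a
// single index result, e.g. a matched arith.constant:
//
//   %c = transform.structured.match ops{["arith.constant"]} in %func
//       : (!transform.any_op) -> !transform.any_op
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, %c]
//       : (!transform.any_op, !transform.any_op)
//         -> !transform.op<"linalg.generic">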
//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
}

void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::ControlDropUnitDims options;
  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
}

void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::ControlDropUnitDims options;
  options.rankReductionStrategy =
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
}

void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}
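
// Illustrative transform IR (sketch): the pattern-populating ops above are
// meant to be nested under transform.apply_patterns, e.g.:
//
//   transform.apply_patterns to %func {
//     transform.apply_patterns.linalg.tiling_canonicalization
//     transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices
//   } : !transform.any_op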
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               OperationState &result,
                                               Value target,
                                               Attribute memorySpace) {
  SmallVector<Type> resultTypes;
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
               /*resultTypes=*/resultTypes,
               /*target=*/target,
               /*memorySpace=*/memorySpace);
}

void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                               OperationState &result,
                                               Value target,
                                               int64_t memorySpace) {
  SmallVector<Type> resultTypes;
  resultTypes.push_back(b.getType<transform::AnyValueType>());
  resultTypes.push_back(b.getType<transform::AnyOpType>());
  return build(b, result,
               /*resultTypes=*/resultTypes,
               /*target=*/target,
               /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
}
namespace {
class NewOpsListener : public RewriterBase::ForwardingListener {
public:
  using RewriterBase::ForwardingListener::ForwardingListener;

  SmallVector<Operation *> getNewOps() const {
    return SmallVector<Operation *>(newOps.begin(), newOps.end());
  }

private:
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override {
    ForwardingListener::notifyOperationInserted(op, previous);
    // We only care about newly created ops.
    if (previous.isSet())
      return;
    auto inserted = newOps.insert(op);
    (void)inserted;
    assert(inserted.second && "expected newly created op");
  }

  void notifyOperationErased(Operation *op) override {
    ForwardingListener::notifyOperationErased(op);
    op->walk([&](Operation *op) { newOps.erase(op); });
  }

  DenseSet<Operation *> newOps;
};
} // namespace
DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  // Attach listener to keep track of newly created ops.
  OpBuilder::Listener *previousListener = rewriter.getListener();
  auto resetListener =
      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
  NewOpsListener newOpsListener(previousListener);
  rewriter.setListener(&newOpsListener);

  linalg::BufferizeToAllocationOptions options;
  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
    options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
        MaterializeInDestination;
  } else if (getMemcpyOp() == "memref.copy") {
    options.memcpyOp =
        linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
  } else if (getMemcpyOp() == "linalg.copy") {
    options.memcpyOp =
        linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
  } else {
    llvm_unreachable("invalid memcpy op");
  }
  if (getAllocOp() == "memref.alloc") {
    options.allocOp =
        linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
  } else if (getAllocOp() == "memref.alloca") {
    options.allocOp =
        linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
  } else {
    llvm_unreachable("invalid alloc op");
  }
  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
  options.emitDealloc = getEmitDealloc();

  // Bufferize ops.
  Attribute memorySpace =
      getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
  SmallVector<Value> allocatedBuffers;
  for (Operation *op : state.getPayloadOps(getTarget())) {
    Value buffer =
        linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
    if (!buffer) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
                                         << "failed to bufferize operation";
      diag.attachNote(op->getLoc()) << "target payload op";
      return diag;
    }
    allocatedBuffers.push_back(buffer);
  }

  // Set results.
  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
  return DiagnosedSilenceableFailure::success();
}
void transform::BufferizeToAllocationOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  if (getBufferizeDestinationOnly()) {
    // The destination is replaced with a newly allocated buffer, but the op
    // itself remains in place.
    onlyReadsHandle(getTargetMutable(), effects);
  } else {
    consumesHandle(getTargetMutable(), effects);
  }
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}

LogicalResult transform::BufferizeToAllocationOp::verify() {
  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
      getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
    return emitOpError() << "unsupported memcpy op";
  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
    return emitOpError() << "unsupported alloc op";
  return success();
}
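
// Illustrative transform IR (sketch; operand/attribute spelling abbreviated):
//
//   %buffer, %new_ops = transform.structured.bufferize_to_allocation %target
//       {memory_space = 4, alloc_op = "memref.alloca"} : !transform.any_op
//
// This yields a handle to the allocated buffer value and a handle to the
// newly created ops, as tracked by NewOpsListener above.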
//===----------------------------------------------------------------------===//
// DecomposeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
                                   LinalgOp target,
                                   transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
#define DOWNSCALE(trans)                                                       \
  {                                                                            \
    FailureOr<LinalgOp> res = tryApply<trans>(target);                         \
    if (succeeded(res)) {                                                      \
      results.push_back(*res);                                                 \
      return DiagnosedSilenceableFailure::success();                           \
    }                                                                          \
  }

#define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
#define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))

  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
  DOWNSCALE(DownscaleConv2DOp)
#undef DOWNSCALE_NORMAL
#undef DOWNSCALE_CALL
#undef DOWNSCALE
  return emitDefaultSilenceableFailure(target);
}
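
// Illustrative transform IR (sketch): decompose a 2-D op with a size-one
// window into its 1-D counterpart:
//
//   %1 = transform.structured.decompose %conv
//       : (!transform.any_op) -> !transform.any_op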
//===----------------------------------------------------------------------===//
// DecomposeInterfaceOp
//===----------------------------------------------------------------------===//

// Decompose the target operation if it implements the AggregatedOpInterface.
// Push the decomposed operations (the ones that replace the values produced by
// \p target) into `results`.
DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
  if (!decomposableOp) {
    (void)rewriter.notifyMatchFailure(target,
                                      "payload is not a decomposable op");
    return emitDefaultSilenceableFailure(target);
  }

  FailureOr<SmallVector<Value>> maybeNewResults =
      decomposableOp.decomposeOperation(rewriter);
  if (failed(maybeNewResults))
    return emitDefaultSilenceableFailure(target);

  rewriter.replaceOp(decomposableOp, *maybeNewResults);
  for (Value val : *maybeNewResults) {
    Operation *definition = val.getDefiningOp();
    if (definition)
      results.push_back(definition);
  }
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// EliminateLinalgOpAnchoredEmptyTensorsOp
//===----------------------------------------------------------------------===//

void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  modifiesPayload(effects);
}

DiagnosedSilenceableFailure
transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
    transform::TransformRewriter &rewriter, TransformResults &transformResults,
    TransformState &state) {
  bufferization::OneShotBufferizationOptions options;
  options.allowReturnAllocsFromLoops = true;

  for (Operation *target : state.getPayloadOps(getTarget())) {
    bufferization::OneShotAnalysisState state(target, options);
    if (failed(analyzeOp(target, state)))
      return mlir::emitSilenceableFailure(target->getLoc())
             << "failed to analyze op";
    if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
            rewriter, target, state)))
      return mlir::emitSilenceableFailure(target->getLoc())
             << "failed to eliminate LinalgOp anchored tensor.empty ops";
  }
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// FuseOp
//===----------------------------------------------------------------------===//

/// Apply a tiling transformation to all payload ops and store both the
/// tiled operation as well as the created tile loops.
template <typename Range>
static LogicalResult applyTilingToAll(
    RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
    unsigned numLoops, transform::TransformResults &transformResults,
    function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
        applyFn) {
  SmallVector<Operation *> tiledLinalgOps;
  SmallVector<SmallVector<Operation *>> loopOps(numLoops);

  for (Operation *target : payloadOps) {
    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
    if (!tilingInterfaceOp)
      return transformOp->emitError("only TilingInterface ops are supported");

    rewriter.setInsertionPoint(target);
    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
        applyFn(tilingInterfaceOp);
    if (failed(tiledResults))
      return failure();

    // Perform the replacement of tiled and fused values.
    SmallVector<Operation *> opsToReplace{target};
    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
    for (Operation *toReplace : opsToReplace) {
      for (OpResult res : toReplace->getResults())
        if (auto replacement = tiledResults->replacements.lookup(res))
          rewriter.replaceAllUsesWith(res, replacement);
      if (toReplace->use_empty()) {
        rewriter.eraseOp(toReplace);
      }
    }

    // Report back the relevant handles to the transform op.
    tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
    assert(tiledResults->loops.size() == numLoops &&
           "Mismatched number of loops, tile and fuse transform should have "
           "failed");
    for (unsigned int i = 0; i < numLoops; ++i)
      loopOps[i].push_back(tiledResults->loops[i]);
  }

  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
  for (unsigned int i = 0; i < numLoops; ++i)
    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);

  return success();
}
DiagnosedSilenceableFailure
transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                         mlir::transform::TransformResults &transformResults,
                         mlir::transform::TransformState &state) {
  SmallVector<int64_t> tileSizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  SmallVector<int64_t> tileInterchange =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());

  scf::SCFTilingOptions tilingOptions;
  tilingOptions.interchangeVector = tileInterchange;
  SmallVector<OpFoldResult> tileSizesOfr =
      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.tilingOptions = tilingOptions;
  LogicalResult result = applyTilingToAll(
      rewriter, getOperation(), state.getPayloadOps(getTarget()),
      tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
      [&](TilingInterface tilingInterfaceOp)
          -> FailureOr<scf::SCFTileAndFuseResult> {
        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
                                                    tileAndFuseOptions);
      });
  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                        : DiagnosedSilenceableFailure::success();
}
LogicalResult transform::FuseOp::verify() {
  SmallVector<int64_t> permutation =
      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(),
                           permutation.begin(), permutation.end())) {
    return emitOpError() << "expects interchange to be a permutation, found "
                         << getTileInterchange();
  }

  SmallVector<int64_t> sizes =
      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
  if (numExpectedLoops != getNumResults() - 1)
    return emitOpError() << "expects " << numExpectedLoops << " loop results";

  return success();
}
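
// Illustrative transform IR (sketch): tile a matmul by [4, 8] and fuse its
// producers into the resulting loop nest; one loop handle is returned per
// non-zero tile size:
//
//   %tiled, %loops:2 = transform.structured.fuse %matmul
//       {tile_sizes = [4, 8], tile_interchange = [1, 0]}
//       : (!transform.any_op)
//         -> (!transform.any_op, !transform.any_op, !transform.any_op)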
//===----------------------------------------------------------------------===//
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//

void transform::FuseIntoContainingOp::build(OpBuilder &builder,
                                            OperationState &result,
                                            Value producerOp,
                                            Value containingOp) {
  result.addOperands({producerOp, containingOp});
  auto resultType = transform::AnyOpType::get(builder.getContext());
  result.addTypes({resultType, resultType});
}
/// Add new operands to the forall op for users of the producerOp
/// that are dominated by the containing scf.forall op.
static Operation *replaceForAllWithNewSignature(
    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
    Operation *containingOp, TilingResult &tileAndFuseResult,
    int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
    SmallVector<OpFoldResult> &sizes) {

  // Count the number of users not including the containing op.
  SetVector<Operation *> dominatedUsers;
  DominanceInfo domInfo(containingOp);
  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
    if (!containingOp->isAncestor(user) &&
        (domInfo.dominates(containingOp, user))) {
      dominatedUsers.insert(user);
    }
  }
  if (dominatedUsers.empty())
    return nullptr;

  // Create a new scf.forall op.
  auto forallOp = cast<scf::ForallOp>(containingOp);
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(forallOp);

  // Get the new output.
  Location loc = forallOp.getLoc();
  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
  if (!genericOp)
    return nullptr;
  SmallVector<Value> outputs = genericOp.getOutputs();
  SmallVector<Value> newOuts(forallOp.getOutputs());
  newOuts.push_back(outputs[resultNumber]);

  // Create the new scf.forall op.
  auto newforallOp = rewriter.create<scf::ForallOp>(
      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
  rewriter.eraseBlock(newforallOp.getBody());
  newforallOp.getRegion().takeBody(forallOp.getRegion());

  // Add an additional block argument for the new value being returned
  // and replace all uses of the new output with the corresponding bbArg
  // inside the scf.forall to enable fusion into this new scf.forall.
  newforallOp.getBody()->addArgument(newOuts.back().getType(),
                                     newOuts.back().getLoc());
  auto bbArgs = newforallOp.getBody()->getArguments();
  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
                             [&](OpOperand &use) {
                               Operation *op = use.getOwner();
                               return newforallOp->isProperAncestor(op);
                             });

  // Fix the terminator.
  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
      terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
  Operation *firstYieldOp = yieldingOps.front();
  rewriter.setInsertionPoint(firstYieldOp);
  Value src = tileAndFuseResult.tiledValues[0];
  Value dst = newforallOp.getRegionIterArgs().back();
  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
                                                 dst, offsets, sizes, strides);

  for (auto result : llvm::enumerate(forallOp.getResults())) {
    rewriter.replaceAllUsesWith(result.value(),
                                newforallOp->getResult(result.index()));
  }
  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
                             newforallOp->getResults().back(),
                             [&](OpOperand &use) {
                               Operation *user = use.getOwner();
                               return dominatedUsers.contains(user);
                             });
  return newforallOp;
}
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return the tiled ops on success, or an empty vector if anything fails.
/// If the tiled op has uses that are dominated by `containingOp`, also return
/// a new `containingOp` with the results of the fused op appended to the
/// results of the `containingOp`, or nullptr if there are no dominated uses.
static std::tuple<SmallVector<Operation *>, Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
                           Operation *producerOp, Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer) {
    diag.attachNote(producerOp->getLoc())
        << "producer does not implement TilingInterface: " << *producerOp;
    return {};
  }

  // Search the producer slices accessed within the containing operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
  // evolve into an interface.
  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp && containingOp->isProperAncestor(sliceOp);
  });

  // Find a fusion opportunity.
  if (it == tileableProducer->getUsers().end()) {
    diag.attachNote(tileableProducer->getLoc())
        << "could not find fusion opportunity for: " << *tileableProducer;
    return {};
  }
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);

  // Try to fuse the producer in-place.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Tile the producer.
  int64_t resultNumber =
      cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();

  FailureOr<TilingResult> tileAndFuseResult =
      tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
                                               sizes);

  if (failed(tileAndFuseResult)) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;
    return {};
  }

#ifndef NDEBUG
  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
    LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
  }
#endif

  // Replace the extract op.
  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
  if (failed(maybeRankReduced)) {
    diag.attachNote(producerOp->getLoc())
        << "shape types don't match (missing canonicalization?):\nTiledOp: "
        << tileAndFuseResult->tiledValues[0]
        << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
    return {};
  }
  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

  // Add new outputs to the containing op, if required.
  Operation *newContainingOp = replaceForAllWithNewSignature(
      rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
      resultNumber, offsets, sizes);

  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
}
/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
/// it is exactly the `containingOp`, otherwise bail.
/// Then, find the first "extract" user of the tied block argument and tile it
/// right before its "extract" use. The tiled op is fused under the
/// `containingOp`.
/// Return the tiled ops on success, or an empty vector if anything fails.
static SmallVector<Operation *>
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
    Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");

  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer) {
    diag.attachNote(producerOp->getLoc())
        << "producer does not implement TilingInterface: " << *producerOp;
    return {};
  }

  // Search the first use by a "scf::ForallOp" user.
  scf::ForallOp forallOp;
  auto itProducerUses =
      llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
        forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
        return forallOp;
      });
  // If it's not from the containing op, return.
  if (!forallOp || forallOp != containingOp) {
    diag.attachNote(tileableProducer->getLoc())
        << "could not find a use by the containing op: " << *tileableProducer;
    return {};
  }

  // Search the producer slices accessed within the containing operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
  // evolve into an interface.
  OpOperand *pUse = &(*itProducerUses);
  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);

  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
    return sliceOp && containingOp->isProperAncestor(sliceOp);
  });

  // Find a fusion opportunity.
  if (itBBArgUsers == bbArg.getUsers().end()) {
    diag.attachNote(containingOp->getLoc())
        << "could not find fusion opportunity for bbArg: " << bbArg;
    return {};
  }
  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);

  // Try to fuse the producer in-place.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Replace the use in the tileableProducer before tiling: clone, replace and
  // then tile.
  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  // Gather destination tensors.
  SmallVector<Value> destinationTensors;
  if (failed(tensor::getOrCreateDestinations(
          rewriter, tileableProducer->getLoc(), tileableProducer,
          destinationTensors))) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to get destination tensors for: " << *tileableProducer;
    return {};
  }

  IRMapping bvm;
  bvm.map(destinationTensors[resultNumber], bbArg);
  auto tileableProducerClone =
      cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
  auto scopeGuard =
      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

  // Tile the producer.
  FailureOr<TilingResult> tileAndFuseResult =
      tileableProducerClone.generateResultTileValue(
          rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
          sliceOpToTile.getMixedSizes());
  if (failed(tileAndFuseResult)) {
    diag.attachNote(tileableProducer->getLoc())
        << "failed to tile producer op: " << *tileableProducer;
    return {};
  }

  // Replace the extract op.
  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
      rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
      cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
  assert(succeeded(maybeRankReduced) && "unexpected shape");
  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

  // Replace the use in containingOp.
  rewriter.modifyOpInPlace(containingOp, [&]() {
    containingOp->setOperand(pUse->getOperandNumber(),
                             destinationTensors.front());
  });

  return tileAndFuseResult->tiledOps;
}
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
                                       Operation *producerOp,
                                       Operation *containingOp) {
  LLVM_DEBUG(DBGS() << "Try to fuse a use by cloning\n");

  // Gather all uses inside the containing op.
  SmallVector<OpOperand *> uses;
  for (OpResult result : producerOp->getOpResults()) {
    for (OpOperand &use : result.getUses()) {
      if (containingOp->isProperAncestor(use.getOwner())) {
        uses.push_back(&use);
        continue;
      }
      // Cannot clone and fuse if the use is by the containing op itself: fail
      // immediately.
      if (containingOp == use.getOwner()) {
        diag.attachNote(producerOp->getLoc())
            << "producer op use by containing op cannot be fused by cloning";
        return nullptr;
      }
    }
  }

  // Check for a non-empty list of fusion opportunities.
  if (uses.empty()) {
    diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
    return nullptr;
  }

  // Clone and fuse inside the containing op.
  Operation *fusedOp = nullptr;
  OpOperand *use = uses.front();
  // Parallel insert slice is not a valid clone destination.
  // TODO: Generalize to other types of ops.
  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
         "Parallel insert slice is not a valid clone destination");
  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(use->getOwner());
  fusedOp = rewriter.clone(*producerOp);
  rewriter.modifyOpInPlace(
      use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });

  return fusedOp;
}
bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
  // Allow repeated handles since we are fusing everything anyway.
  return true;
}
DiagnosedSilenceableFailure
transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  SmallVector<Operation *> fusedOps;
  auto producerOps = state.getPayloadOps(getProducerOp());
  auto containingOps = state.getPayloadOps(getContainingOp());
  if (!llvm::hasSingleElement(containingOps)) {
    return emitDefiniteFailure()
           << "requires exactly one containing_op handle (got "
           << llvm::range_size(containingOps) << ")";
  }
  Operation *containingOp = *containingOps.begin();

  // If nothing to fuse, propagate success.
  if (std::empty(producerOps)) {
    results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
    results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
    return DiagnosedSilenceableFailure::success();
  }

  // Helper function to find the next producer that should be fused. Take any
  // producer that has a use inside the containing op.
  SetVector<Operation *> remainingProducers(producerOps.begin(),
                                            producerOps.end());
  auto getNextProducer = [&]() -> FailureOr<Operation *> {
    for (const auto &it : enumerate(remainingProducers)) {
      Operation *producerOp = it.value();
      // The containing op may be a user of producerOp: use isAncestor.
      int64_t numUsesInContainingOp =
          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
            return containingOp->isAncestor(op);
          });
      // TODO: When resolving the TODO below (no duplicate ops), take an op
      // that has no use among the remaining producers. This is a topological
      // sorting.
      if (numUsesInContainingOp > 0) {
        if (numUsesInContainingOp == 1)
          remainingProducers.erase(remainingProducers.begin() + it.index());
        return producerOp;
      }
    }
    return failure();
  };

  while (!remainingProducers.empty()) {
    auto nextProducer = getNextProducer();
    if (failed(nextProducer)) {
      auto diag = mlir::emitSilenceableFailure(getLoc())
                  << "could not find next producer to fuse into container";
      diag.attachNote(containingOp->getLoc()) << "containing op";
      return diag;
    }

    Operation *producerOp = *nextProducer;

    // Default diagnostic, to be complemented with more failure information.
    Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
    diag << "could not fuse " << *producerOp << " into " << *containingOp;

    // TODO: If there are multiple uses of the producer in the containing op,
    // we currently tile/clone the op multiple times (once per use). In some
    // cases, we can tile/clone once and reuse the value for each use.
    // Furthermore, producers should then be traversed according to a
    // topological sorting.
    auto [tiledOps, newContainingOp] =
        tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
    if (!tiledOps.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
      fusedOps.append(tiledOps);
      if (newContainingOp) {
        // Update handles associated with the containing op so we don't need to
        // invalidate them. This is a hack to support better composability
        // between tiling and fusion while a proper mechanism is being
        // investigated.
        //
        // DO NOT replicate this elsewhere unless you understand what you are
        // doing.
        LogicalResult replacementStatus =
            rewriter.notifyPayloadOperationReplaced(containingOp,
                                                    newContainingOp);
        (void)replacementStatus;
        assert(succeeded(replacementStatus) &&
               "unable to update transform state mapping");
        rewriter.eraseOp(containingOp);
        containingOp = newContainingOp;
      }
      continue;
    }

    SmallVector<Operation *> tiledContainingOpOperand =
        tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
            rewriter, diag, producerOp, containingOp);
    if (!tiledContainingOpOperand.empty()) {
      LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
                        << *containingOp);
      fusedOps.append(tiledContainingOpOperand);
      continue;
    }

    Operation *cloned =
        cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
    if (cloned) {
      LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
      fusedOps.push_back(cloned);
      continue;
    }
    return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
  }

  results.set(cast<OpResult>(getFusedOp()), fusedOps);
  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
  return DiagnosedSilenceableFailure::success();
}
void transform::FuseIntoContainingOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getProducerOpMutable(), effects);
  onlyReadsHandle(getContainingOpMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
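
// Illustrative transform IR (sketch): fuse a producer into an scf.forall
// obtained from a previous tiling step:
//
//   %fused, %new_containing =
//       transform.structured.fuse_into_containing_op %producer into %forall
//       : (!transform.any_op, !transform.any_op)
//         -> (!transform.any_op, !transform.any_op)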
//===----------------------------------------------------------------------===//
// GeneralizeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  // Exit early if no transformation is needed.
  if (isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
  if (succeeded(generic)) {
    results.push_back(generic->getOperation());
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}
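
// Illustrative transform IR (sketch): rewrite a named op into its
// linalg.generic form:
//
//   %generic = transform.structured.generalize %named
//       : (!transform.any_op) -> !transform.any_op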
//===----------------------------------------------------------------------===//
// SpecializeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
                                    LinalgOp target,
                                    transform::ApplyToEachResultList &results,
                                    transform::TransformState &state) {
  // Exit early if the operation is not a generic.
  if (!isa<GenericOp>(target)) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  rewriter.setInsertionPoint(target);
  FailureOr<LinalgOp> named =
      specializeGenericOp(rewriter, cast<GenericOp>(target));
  if (succeeded(named)) {
    results.push_back(named->getOperation());
    return DiagnosedSilenceableFailure::success();
  }
  return emitDefaultSilenceableFailure(target);
}
//===----------------------------------------------------------------------===//
// InterchangeOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
                                     GenericOp target,
                                     transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
  // Exit early if no transformation is needed.
  if (interchangeVector.empty()) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }

  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
  if (interchangeVector.size() != numLoops) {
    return emitSilenceableError()
           << getIteratorInterchangeAttrName() << " has length ("
           << interchangeVector.size()
           << ") different from the number of loops in the target operation ("
           << numLoops << ")";
  }
  FailureOr<GenericOp> res =
      interchangeGenericOp(rewriter, target,
                           SmallVector<unsigned>(interchangeVector.begin(),
                                                 interchangeVector.end()));
  if (failed(res))
    return emitDefiniteFailure() << "failed to apply";
  results.push_back(res->getOperation());
  return DiagnosedSilenceableFailure::success();
}
LogicalResult transform::InterchangeOp::verify() {
  ArrayRef<int64_t> permutation = getIteratorInterchange();
  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
  if (!std::is_permutation(sequence.begin(), sequence.end(),
                           permutation.begin(), permutation.end())) {
    return emitOpError()
           << "expects iterator_interchange to be a permutation, found "
           << getIteratorInterchange();
  }
  return success();
}
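
// Illustrative transform IR (sketch): swap the two iterators of a generic
// op; the interchange vector must be a permutation of seq(0, numLoops):
//
//   %1 = transform.structured.interchange %generic
//       iterator_interchange = [1, 0]
//       : (!transform.any_op) -> !transform.any_op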
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
    transform::TransformRewriter &rewriter, tensor::PackOp target,
    transform::ApplyToEachResultList &transformResults,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
  if (failed(res)) {
    return mlir::emitSilenceableFailure(target->getLoc())
           << "cannot lower to pad + expand + transpose";
  }
  transformResults.push_back(res->padOp);
  transformResults.push_back(res->expandShapeOp);
  transformResults.push_back(res->transposeOp);
  return DiagnosedSilenceableFailure::success();
}
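
// Illustrative transform IR (sketch): rewrite a tensor.pack into
// tensor.pad + tensor.expand_shape + linalg.transpose, returning one handle
// per created op:
//
//   %pad, %expand, %transpose = transform.structured.lower_pack %pack
//       : (!transform.op<"tensor.pack">)
//         -> (!transform.op<"tensor.pad">,
//             !transform.op<"tensor.expand_shape">,
//             !transform.op<"linalg.transpose">)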
//===----------------------------------------------------------------------===//
// LowerUnPackOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
    transform::TransformRewriter &rewriter, tensor::UnPackOp target,
    transform::ApplyToEachResultList &transformResults,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
  if (failed(res)) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "cannot lower to transpose + collapse + extract";
    diag.attachNote(target->getLoc()) << "target payload op";
    return diag;
  }
  transformResults.push_back(res->emptyOp);
  transformResults.push_back(res->transposeOp);
  transformResults.push_back(res->collapseShapeOp);
  transformResults.push_back(res->extractSliceOp);
  return DiagnosedSilenceableFailure::success();
}
//===---------------------------------------------------------------------===//
// MatchOp
//===---------------------------------------------------------------------===//

void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
                               Value target, ArrayRef<StringRef> opNames) {
  result.addOperands(target);
  result.addAttribute(MatchOp::getOpsAttrName(result.name),
                      builder.getStrArrayAttr(opNames));
  result.addTypes(transform::AnyOpType::get(builder.getContext()));
}

void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange resultTypes, Value target,
                               ArrayRef<StringRef> opNames) {
  result.addOperands(target);
  result.addAttribute(MatchOp::getOpsAttrName(result.name),
                      builder.getStrArrayAttr(opNames));
  result.addTypes(resultTypes);
}
DiagnosedSilenceableFailure
transform::MatchOp::apply(transform::TransformRewriter &rewriter,
                          transform::TransformResults &results,
                          transform::TransformState &state) {
  llvm::StringSet<> strs;
  if (getOps().has_value())
    strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
                getOps()->getAsValueRange<StringAttr>().end());

  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps)) {
    return emitDefiniteFailure("requires exactly one target handle");
  }

  SmallVector<Operation *> res;
  bool incorrectNumOperandTypes = false;
  auto matchFun = [&](Operation *op) {
    if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
      return;

    // Interfaces cannot be matched by name, just by ID.
    // So we specifically encode the interfaces we care about for this op.
    if (getInterface().has_value()) {
      auto iface = getInterface().value();
      if (iface == transform::MatchInterfaceEnum::LinalgOp &&
          !isa<LinalgOp>(op))
        return;
      if (iface == transform::MatchInterfaceEnum::TilingInterface &&
          !isa<TilingInterface>(op))
        return;
      if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
          !isa<LoopLikeOpInterface>(op))
        return;
    }

    // Check if all specified attributes match.
    if (getOpAttrs().has_value()) {
      DictionaryAttr opAttrs = getOpAttrs().value();
      for (NamedAttribute attr : opAttrs) {
        if (attr.getName() == getInterfaceAttrName() ||
            attr.getName() == getOpsAttrName())
          continue;
        if (!op->hasAttr(attr.getName()))
          return;
        if (op->getAttr(attr.getName()) != attr.getValue())
          return;
      }
    }

    if (getFilterResultType().has_value()) {
      Type t = getFilterResultType().value();
      if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
        return;
    }

    if (getFilterOperandTypes().has_value()) {
      mlir::ArrayAttr types = getFilterOperandTypes().value();
      auto operandTypes = op->getOperandTypes();

      if (types.size() == 1) {
        // All the operands must be equal to the specified type.
        auto typeattr =
            dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
        Type t = cast<::mlir::Type>(typeattr.getValue());
        if (!llvm::all_of(op->getOperandTypes(),
                          [&](Type operandType) { return operandType == t; }))
          return;
      } else {
        // The operand types must match all the types in the list (in the same
        // order in which they are specified).
        if (types.size() != operandTypes.size()) {
          incorrectNumOperandTypes = true;
          return;
        }

        for (auto [attr, operandType] :
             llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
          auto typeattr = cast<mlir::TypeAttr>(attr);
          Type type = cast<::mlir::Type>(typeattr.getValue());

          if (type != operandType)
            return;
        }
      }
    }

    // All constraints are satisfied.
    res.push_back(op);
    return;
  };

  (*payloadOps.begin())->walk(matchFun);
  if (incorrectNumOperandTypes)
    return emitDefiniteFailure("if filter_operand_types contains more than "
                               "one type, then it must contain as many types "
                               "as the number of operands in the target ops");
  results.set(cast<OpResult>(getResult()), res);
  return DiagnosedSilenceableFailure::success();
}
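
// Illustrative transform IR: select payload ops by name and/or interface
// within the single op mapped to the target handle:
//
//   %matmuls = transform.structured.match ops{["linalg.matmul"]} in %func
//       : (!transform.any_op) -> !transform.any_op
//   %tilable = transform.structured.match interface{TilingInterface} in %func
//       : (!transform.any_op) -> !transform.any_op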
//===---------------------------------------------------------------------===//
// MultiTileSizesOp
//===---------------------------------------------------------------------===//

static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
                                     Type targetType, Type lowSizeType, Type,
                                     Type) {
  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
}

static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
                                            Type &targetType, Type &lowSizeType,
                                            Type &highSizeType,
                                            Type &splitPointType) {
  FunctionType funcType;
  llvm::SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseType<FunctionType>(funcType)))
    return failure();

  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
    parser.emitError(typeLoc) << "expects a trailing functional type with one "
                                 "argument and one result";
    return failure();
  }
  targetType = funcType.getInput(0);
  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);

  return success();
}
DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp target,
    transform::ApplyToEachResultList &results, TransformState &state) {
  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
    if (target.hasDynamicShape()) {
      auto diag = emitSilenceableError()
                  << "cannot compute parametric tile sizes for dynamically "
                     "shaped payload op";
      diag.attachNote(target->getLoc()) << "payload op";
      return diag;
    }

    FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
        target, getDimension(), getTargetSize(), getDivisor());
    if (failed(spec)) {
      return emitSilenceableError()
             << "failed to compute multi-size tiling sizes";
    }

    Builder builder(target.getContext());
    results.assign(llvm::map_range(
        ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
                           spec->lowTileSize * spec->lowTripCount}),
        [&builder, this](int64_t value) {
          return builder.getIntegerAttr(
              cast<ParamType>(getLowSize().getType()).getType(), value);
        }));
    return DiagnosedSilenceableFailure::success();
  }

  OpBuilder builder(target.getContext());
  builder.setInsertionPoint(target);
  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
      builder, target, getDimension(), targetSize, divisor);
  if (failed(spec)) {
    return emitSilenceableError() << "could not generate tile size computation";
  }

  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  AffineExpr s1 = builder.getAffineSymbolExpr(1);
  Operation *splitPoint =
      affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
                                      {spec->lowTileSize, spec->lowTripCount});
  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
  Operation *highTileSize = spec->highTileSize.getDefiningOp();
  assert(lowTileSize && highTileSize && splitPoint &&
         "tile sizes are not produced by operations");
  results.reserve(results.size() + 3);
  results.push_back(lowTileSize);
  results.push_back(highTileSize);
  results.push_back(splitPoint);
  return DiagnosedSilenceableFailure::success();
}
void transform::MultiTileSizesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  onlyReadsHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
    onlyReadsPayload(effects);
  else
    modifiesPayload(effects);
}

LogicalResult transform::MultiTileSizesOp::verify() {
  if (getLowSize().getType() != getHighSize().getType() ||
      getLowSize().getType() != getSplitPoint().getType()) {
    return emitOpError() << "expects all result types to be the same";
  }
  return success();
}
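
// Illustrative transform IR (sketch): compute two tile sizes for dimension 0
// whose multiples cover the static trip count; with param-typed results the
// computation happens at transform time:
//
//   %low, %high, %split = transform.structured.multitile_sizes %matmul
//       {dimension = 0, target_size = 32, divisor = 8}
//       : (!transform.any_op) -> !transform.param<i64>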
//===---------------------------------------------------------------------===//
// PackOp
//===---------------------------------------------------------------------===//

void transform::PackOp::build(OpBuilder &builder, OperationState &result,
                              Value target,
                              ArrayRef<OpFoldResult> mixedPackedSizes) {
  SmallVector<int64_t> staticPackedSizes;
  SmallVector<Value> dynamicPackedSizes;
  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
                             staticPackedSizes);
  // Call the default builder which sets up the proper operands segment sizes
  // attributes for multiple variadic operands. In the absence of this,
  // horrible bugs ensue.
  Type linalgOpHType = transform::OperationType::get(
      builder.getContext(), GenericOp::getOperationName());
  build(builder, result,
        /*resultType=*/linalgOpHType,
        /*target=*/target,
        /*dynamic_sizes=*/dynamicPackedSizes,
        /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
}

SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
  Builder b(getContext());
  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
}
DiagnosedSilenceableFailure
transform::PackOp::apply(transform::TransformRewriter &rewriter,
                         transform::TransformResults &transformResults,
                         transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  // If nothing to pack, propagate success.
  if (std::empty(targetOps)) {
    transformResults.set(cast<OpResult>(getPackedOp()),
                         ArrayRef<Operation *>({}));
    return DiagnosedSilenceableFailure::success();
  }
  // Fail on multi-op handles.
  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
    return emitSilenceableError()
           << "requires target to map to exactly 1 LinalgOp (got "
           << llvm::range_size(targetOps) << ")";
  }
  // Fail on mismatched number of pack sizes.
  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
    return emitSilenceableError()
           << "requires number of packed sizes match the number of loops ("
           << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
           << ")";
  }

  // Unpack handles to constants or actual SSA index values.
  SmallVector<OpFoldResult> packedSizes;
  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
      state, *this, packedSizes, getMixedPackedSizes());
  if (!status.succeeded())
    return status;

  rewriter.setInsertionPoint(linalgOp);
  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
  if (failed(maybeResult))
    return emitDefiniteFailure("data tiling failed");

  transformResults.set(cast<OpResult>(getPackedOp()),
                       {maybeResult->packedLinalgOp.getOperation()});
  return DiagnosedSilenceableFailure::success();
}
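
// Illustrative transform IR (sketch): pack a matmul with blocking factors
// [8, 16, 32]; the number of packed sizes must equal the number of loops
// (3 for matmul):
//
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, 32]
//       : (!transform.any_op) -> !transform.op<"linalg.generic">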
void transform::PackOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}
//===---------------------------------------------------------------------===//
// PackGreedilyOp
//===---------------------------------------------------------------------===//

LogicalResult transform::PackGreedilyOp::verify() {
  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
    return emitOpError() << getMatmulInnerDimsOrderAttrName()
                         << " is not a valid permutation";
  }
  // TODO: relax to allow empty once we have another strategy than just matmul.
  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
    for (auto [s, nmo] :
         llvm::zip_equal(getMixedMatmulPackedSizes(),
                         getMatmulPaddedSizesNextMultipleOf())) {
      std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
      if (nmo != 0 &&
          (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
        return emitOpError() << "at most one of the packed_size and the "
                                "padded_sizes_next_multiple_of can be nonzero "
                                "for the matmul strategy";
      }
    }
  }
  return success();
}
1513 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1514  transform::TransformResults &transformResults,
1515  transform::TransformState &state) {
1516  SmallVector<Operation *> results;
1517  for (Operation *op : state.getPayloadOps(getTarget())) {
1518  auto linalgOp = dyn_cast<LinalgOp>(op);
1519  if (!linalgOp)
1520  continue;
1521  // linalgOp will be replaced and the insertion point may be invalidated if
1522  // we set it before the op, so set the insertion point after it instead.
1523  rewriter.setInsertionPointAfter(linalgOp);
1524  // Failing to pack greedily is perfectly fine.
1525  // In the future we will want to order packings according to some metric.
1526  FailureOr<PackResult> packResult = packMatmulGreedily(
1527  /*rewriter=*/rewriter,
1528  /*linalgOp=*/linalgOp,
1529  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1530  /*mnkPaddedSizesNextMultipleOf=*/
1531  getMatmulPaddedSizesNextMultipleOf(),
1532  /*mnkOrder=*/getMatmulInnerDimsOrder());
1533  if (succeeded(packResult)) {
1534  results.push_back(packResult->packedLinalgOp);
1535  continue;
1536  }
1537  results.push_back(linalgOp);
1538  }
1539  transformResults.set(cast<OpResult>(getPackedOp()), results);
1540  return DiagnosedSilenceableFailure::success();
1541 }
1542 
1543 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1544  Builder b(getContext());
1545  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1546  b);
1547 }
1548 
1549 void transform::PackGreedilyOp::getEffects(
1550  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1551  transform::consumesHandle(getTargetMutable(), effects);
1552  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1553  transform::producesHandle(getOperation()->getOpResults(), effects);
1554  transform::modifiesPayload(effects);
1555 }
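// For reference, a minimal sketch (hypothetical handles; exact assembly
// format may differ across MLIR versions) that packs every matmul-like op
// reachable from the handle with (m, n, k) sizes of (8, 16, 32):
//
//   %packed = transform.structured.pack_greedily %func
//       matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [1, 2, 0]
//     : (!transform.any_op) -> !transform.op<"linalg.generic">
//
// Ops that fail to pack are passed through unchanged, matching the loop
// above that pushes the original linalgOp into the results on failure.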
1556 
1557 //===---------------------------------------------------------------------===//
1558 // PackTransposeOp
1559 //===---------------------------------------------------------------------===//
1560 
1561 LogicalResult transform::PackTransposeOp::verify() {
1562  if (!isPermutationVector(getInnerPerm())) {
1563  return emitOpError() << getInnerPermAttrName()
1564  << " is not a valid permutation";
1565  }
1566  if (!isPermutationVector(getOuterPerm())) {
1567  return emitOpError() << getOuterPermAttrName()
1568  << " is not a valid permutation";
1569  }
1570  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1571  return emitOpError() << "at least one of " << getInnerPermAttrName()
1572  << " or " << getOuterPermAttrName()
1573  << " must be specified";
1574  }
1575  return success();
1576 }
1577 
1578 namespace {
1579 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1580 } // namespace
1581 
1582 /// Return true if `permutation` is a valid permutation of the
1583 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1584 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
1585 /// This is the case when the `permutation` rank matches the rank expected by
1586 /// `op` and `permutation` is itself a permutation vector.
1587 /// Return true if either `op` or `permutation` are empty to allow a simpler
1588 /// polymorphic implementation.
1589 template <typename RelayoutOpTy>
1590 static bool isValidPackingPermutation(
1591  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1592  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1593  static_assert(
1594  llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1595  "applies to only pack or unpack operations");
1596  if (!op || permutation.empty())
1597  return true;
1598  size_t innerRank = op.getInnerDimsPos().size();
1599  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1600  return permutation.size() == innerRank && isPermutationVector(permutation);
1601  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1602  // Don't rely on it.
1603  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1604  return permutation.size() == op.getSourceRank() &&
1605  isPermutationVector(permutation);
1606  }
1607  return permutation.size() == op.getDestRank() &&
1608  isPermutationVector(permutation);
1609 }
1610 
1611 DiagnosedSilenceableFailure
1612 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1613  transform::TransformResults &transformResults,
1614  transform::TransformState &state) {
1615  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1616  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1617  // Step 1. If nothing to pack, propagate success.
1618  if (std::empty(packOrUnpackOps)) {
1619  transformResults.set(cast<OpResult>(getPackedOp()), {});
1620  transformResults.set(cast<OpResult>(getPackOp()), {});
1621  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1622  return DiagnosedSilenceableFailure::success();
1623  }
1624 
1625  // Step 2. A bunch of runtime sanity checks and error messages.
1626  // Step 2.1. Fail on multi-op handles.
1627  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1628  !llvm::hasSingleElement(linalgOps)) {
1629  return emitSilenceableError()
1630  << "requires target to map to exactly 1 "
1631  "packing op and 1 packed op ("
1632  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1633  << llvm::range_size(linalgOps) << ")";
1634  }
1635 
1636  // Step 2.2. Fail on wrong type.
1637  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1638  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1639  if (!packOp && !unPackOp) {
1640  return emitSilenceableError() << "requires target to map to a "
1641  "tensor.pack or tensor.unpack";
1642  }
1643  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1644  if (!linalgOpTarget)
1645  return emitSilenceableError() << "requires a LinalgOp target";
1646 
1647  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1648  LinalgOp linalgOp;
1649  if (packOp && packOp.getResult().hasOneUse())
1650  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1651  else if (unPackOp)
1652  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1653  if (linalgOp != linalgOpTarget) {
1654  auto errorMsg =
1655  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1656  : StringLiteral{"not produced by the LinalgOp target"};
1657  return emitSilenceableError() << errorMsg;
1658  }
1659 
1660  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1661  // PackOp.
1662  if (unPackOp) {
1663  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1664  OpOperand *packUse = linalgOp.getDpsInitOperand(
1665  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1666  packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
1667  if (!packOp || !packOp.getResult().hasOneUse())
1668  return emitSilenceableError() << "could not find matching pack op";
1669  }
1670 
1671  // Step 2.5. Fail if any permutation does not validate.
1672  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1673  ArrayRef<int64_t> perm =
1674  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1675  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1676  ? StringLiteral{"invalid outer_perm"}
1677  : StringLiteral{"invalid inner_perm"};
1678  if (!isValidPackingPermutation(packOp, perm, permType) ||
1679  !isValidPackingPermutation(unPackOp, perm, permType)) {
1680  Operation *packOrUnpackOp =
1681  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1682  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1683  }
1684  }
1685 
1686  // From here on, packOp and linalgOp are always present, unPackOp may or may
1687  // not be present.
1688  assert(packOp && linalgOp && "unexpected null op");
1689 
1690  // Step 3. Actually transpose the ops.
1691  FailureOr<PackTransposeResult> res = packTranspose(
1692  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1693  // Preconditions have been checked, it is an error to fail here.
1694  assert(succeeded(res) && "unexpected packTranspose failure");
1695 
1696  // Step 4. Return results.
1697  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1698  transformResults.set(cast<OpResult>(getPackedOp()),
1699  {res->transposedLinalgOp});
1700  if (unPackOp) {
1701  transformResults.set(cast<OpResult>(getUnPackOp()),
1702  {res->transposedUnPackOp});
1703  } else {
1704  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1705  }
1706 
1707  return DiagnosedSilenceableFailure::success();
1708 }
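// For reference, a minimal sketch (hypothetical handles; exact assembly
// format may differ across MLIR versions) that transposes a matched
// pack/linalg/unpack triple:
//
//   %pack_t, %op_t, %unpack_t = transform.structured.pack_transpose %unpack
//       with_compute_op(%linalg) outer_perm = [1, 0]
//     : (!transform.any_op, !transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)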
1709 
1710 //===---------------------------------------------------------------------===//
1711 // PadOp
1712 //===---------------------------------------------------------------------===//
1713 
1714 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1715  ArrayRef<int64_t> paddingDimensions,
1716  ArrayRef<int64_t> padToMultipleOf,
1717  ArrayRef<int64_t> packPaddings,
1718  ArrayRef<Attribute> transposePaddings,
1719  StringRef copyBackOp) {
1720  auto resultType = transform::AnyOpType::get(b.getContext());
1721  return build(/*builder=*/b,
1722  /*result=*/result,
1723  /*types=*/TypeRange{resultType, resultType},
1724  /*target=*/target,
1725  /*paddingValues=*/ArrayAttr(), // let inference handle this
1726  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1727  /*padToMultipleOf=*/ValueRange{},
1728  /*padToMultipleOf=*/
1729  (padToMultipleOf.empty()
1730  ? DenseI64ArrayAttr()
1731  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1732  /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
1733  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1734  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1735 }
1736 
1737 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1738  ArrayRef<int64_t> paddingDimensions,
1739  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1740  ArrayRef<int64_t> packPaddings,
1741  ArrayRef<Attribute> transposePaddings,
1742  StringRef copyBackOp) {
1743  auto resultType = transform::AnyOpType::get(b.getContext());
1744  SmallVector<int64_t> staticPadToMultipleOf;
1745  SmallVector<Value> dynamicPadToMultipleOf;
1746  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1747  staticPadToMultipleOf);
1748  return build(/*builder=*/b,
1749  /*result=*/result,
1750  /*types=*/TypeRange{resultType, resultType},
1751  /*target=*/target,
1752  /*paddingValues=*/ArrayAttr(), // let inference handle this
1753  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1754  /*padToMultipleOf=*/dynamicPadToMultipleOf,
1755  /*padToMultipleOf=*/staticPadToMultipleOf,
1756  /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
1757  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1758  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1759 }
1760 
1761 void PadOp::getEffects(
1762  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1763  consumesHandle(getTargetMutable(), effects);
1764  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1765  producesHandle(getOperation()->getOpResults(), effects);
1766  modifiesPayload(effects);
1767 }
1768 
1769 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1770  Builder b(getContext());
1771  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1772 }
1773 
1774 DiagnosedSilenceableFailure
1775 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1776  transform::TransformResults &results,
1777  transform::TransformState &state) {
1778  auto transformOp = cast<TransformOpInterface>(getOperation());
1779  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1780 
1781  for (Operation *target : state.getPayloadOps(getTarget())) {
1782  auto linalgTarget = dyn_cast<LinalgOp>(target);
1783  if (!linalgTarget) {
1784  auto diag = emitSilenceableError() << "expected LinalgOp target";
1785  diag.attachNote(target->getLoc()) << "target op";
1786  return diag;
1787  }
1788 
1789  // Convert the integer packing flags to booleans.
1790  SmallVector<bool> packPaddings;
1791  for (int64_t packPadding :
1792  extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
1793  packPaddings.push_back(static_cast<bool>(packPadding));
1794 
1795  // Convert the padding values to attributes.
1796  SmallVector<Attribute> paddingValues;
1797  for (auto const &it :
1798  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1799  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1800  if (!attr) {
1801  emitOpError("expects padding values to be typed attributes");
1802  return DiagnosedSilenceableFailure::definiteFailure();
1803  }
1804  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1805  // Try to parse string attributes to obtain an attribute of element type.
1806  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1807  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1808  stringAttr, getContext(), elementType,
1809  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
1810  if (!parsedAttr || parsedAttr.getType() != elementType) {
1811  auto diag = this->emitOpError("expects a padding that parses to ")
1812  << elementType << ", got " << std::get<0>(it);
1813  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1814  return DiagnosedSilenceableFailure::definiteFailure();
1815  }
1816  paddingValues.push_back(parsedAttr);
1817  continue;
1818  }
1819  // Otherwise, add the attribute directly.
1820  if (attr.getType() != elementType) {
1821  auto diag = this->emitOpError("expects a padding value of type ")
1822  << elementType << ", got " << attr;
1823  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1824  return DiagnosedSilenceableFailure::definiteFailure();
1825  }
1826  paddingValues.push_back(attr);
1827  }
1828 
1829  // Extract the transpose vectors.
1830  SmallVector<SmallVector<int64_t>> transposePaddings;
1831  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1832  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1833  cast<ArrayAttr>(transposeVector)));
1834 
1835  LinalgOp paddedOp;
1836  LinalgPaddingOptions options;
1837  options.paddingDimensions =
1838  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1839 
1840  SmallVector<int64_t> padToMultipleOf;
1841  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
1842  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1843  if (!status.succeeded())
1844  return status;
1845  if (padToMultipleOf.empty())
1846  padToMultipleOf =
1847  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
1848 
1849  options.padToMultipleOf = padToMultipleOf;
1850  options.paddingValues = paddingValues;
1851  options.packPaddings = packPaddings;
1852  if (getCopyBackOp() ==
1853  bufferization::MaterializeInDestinationOp::getOperationName()) {
1854  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
1855  BufferizationMaterializeInDestination;
1856  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1857  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
1858  } else if (getCopyBackOp() == kCopyOpNone) {
1859  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
1860  } else {
1861  llvm_unreachable("unsupported copy_back op");
1862  }
1863 
1864  SmallVector<Value> replacements;
1865  SmallVector<tensor::PadOp> newPadOps;
1866  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
1867  replacements, newPadOps))) {
1868  auto diag = emitSilenceableError() << "failed to pad op";
1869  diag.attachNote(target->getLoc()) << "target op";
1870  return diag;
1871  }
1872 
1873  // We need to perform our own replacement here because this API is still
1874  // used in patterns that "pad and hoist", for which the replacement values
1875  // need to be different.
1876  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1877  // that we have more composable abstractions.
1878  rewriter.replaceOp(linalgTarget, replacements);
1879  paddedOps.push_back(paddedOp);
1880  padOps.append(newPadOps.begin(), newPadOps.end());
1881  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
1882  for (Value v : replacements) {
1883  Operation *copyBackOp = v.getDefiningOp();
1884  if (!llvm::is_contained(copyBackOps, copyBackOp))
1885  copyBackOps.push_back(copyBackOp);
1886  }
1887  }
1888  }
1889 
1890  results.set(cast<OpResult>(getPadded()), paddedOps);
1891  results.set(cast<OpResult>(getPad()), padOps);
1892  results.set(cast<OpResult>(getCopy()), copyBackOps);
1893  return DiagnosedSilenceableFailure::success();
1894 }
1895 
1896 LogicalResult transform::PadOp::verify() {
1897  SmallVector<int64_t> packPaddings =
1898  extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
1899  if (any_of(packPaddings, [](int64_t packPadding) {
1900  return packPadding != 0 && packPadding != 1;
1901  })) {
1902  return emitOpError()
1903  << "expects pack_paddings to contain booleans (0/1), found "
1904  << getPackPaddings();
1905  }
1906 
1907  SmallVector<int64_t> paddingDimensions =
1908  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1909  if (any_of(paddingDimensions,
1910  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
1911  return emitOpError() << "expects padding_dimensions to contain "
1912  "non-negative integers, found "
1913  << getPaddingDimensions();
1914  }
1915  if (!getMixedPadToMultipleOf().empty()) {
1916  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
1917  return emitOpError() << "expects as many multiples as padding_dimensions";
1918  }
1919  }
1920  ArrayAttr transposes = getTransposePaddings();
1921  for (Attribute attr : transposes) {
1922  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
1923  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1924  if (!std::is_permutation(sequence.begin(), sequence.end(),
1925  transpose.begin(), transpose.end())) {
1926  return emitOpError()
1927  << "expects transpose_paddings to be a permutation, found "
1928  << attr;
1929  }
1930  }
1931  if (getCopyBackOp() !=
1932  bufferization::MaterializeInDestinationOp::getOperationName() &&
1933  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1934  getCopyBackOp() != kCopyOpNone)
1935  return emitOpError() << "invalid copy_back_op";
1936  return success();
1937 }
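// For reference, a minimal sketch (hypothetical handle and values; exact
// assembly format may differ across MLIR versions) padding a matmul to
// multiples of 32 and hoisting the first two pads into packed buffers:
//
//   %padded, %pad, %copy = transform.structured.pad %matmul {
//       padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
//       padding_dimensions = [0, 1, 2],
//       pad_to_multiple_of = [32, 32, 32],
//       pack_paddings = [1, 1, 0]
//     } : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)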
1938 
1939 //===---------------------------------------------------------------------===//
1940 // HoistPadOp
1941 //===---------------------------------------------------------------------===//
1942 
1943 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
1944  transform::TransformRewriter &rewriter,
1945  transform::TransformResults &transformResults,
1946  transform::TransformState &state) {
1947  auto targetOps = state.getPayloadOps(getTarget());
1948  auto loopOps = state.getPayloadOps(getLoop());
1949  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1950  return emitDefiniteFailure()
1951  << "requires exactly one target and one loop handle (got "
1952  << llvm::range_size(targetOps) << " and "
1953  << llvm::range_size(loopOps) << ")";
1954  }
1955 
1956  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1957  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1958  if (!padOp || !loopOp)
1959  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
1960 
1961  FailureOr<linalg::detail::PackingResult> result =
1962  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
1963  getTranspose());
1964  if (failed(result))
1965  return emitDefiniteFailure() << "could not build packing loop nest";
1966 
1967  if (result->clonedLoopIvs.empty()) {
1968  transformResults.set(cast<OpResult>(getPackingLoop()),
1969  {result->hoistedPadOp.getOperation()});
1970  return DiagnosedSilenceableFailure::success();
1971  }
1972  auto outerPackedLoop =
1973  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
1974  transformResults.set(cast<OpResult>(getPackingLoop()),
1975  {outerPackedLoop.getOperation()});
1976  return DiagnosedSilenceableFailure::success();
1977 }
1978 
1979 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
1980  ArrayRef<int64_t> transpose = getTranspose();
1981  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1982  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
1983  transpose.end())) {
1984  return emitOpError() << "expects transpose to be a permutation, found "
1985  << getTranspose();
1986  }
1987  return success();
1988 }
1989 
1990 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
1991  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1992  transform::onlyReadsHandle(getTargetMutable(), effects);
1993  transform::onlyReadsHandle(getLoopMutable(), effects);
1994  transform::producesHandle(getOperation()->getOpResults(), effects);
1995  transform::modifiesPayload(effects);
1996 }
1997 
1998 DiagnosedSilenceableFailure
1999 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2000  tensor::PadOp target,
2001  transform::ApplyToEachResultList &results,
2002  transform::TransformState &state) {
2003  tensor::PadOp hoistedPadOp;
2004  SmallVector<GenericOp> transposeOps;
2005  FailureOr<Value> result =
2006  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2007  hoistedPadOp, transposeOps);
2008  if (succeeded(result)) {
2009  // We need to perform our own replacement here because this API is still
2010  // used in patterns that "pad and hoist", for which the replacement values
2011  // need to be different.
2012  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2013  // that we have more composable abstractions.
2014  rewriter.replaceOp(target, *result);
2015  results.push_back(hoistedPadOp);
2016  return DiagnosedSilenceableFailure::success();
2017  }
2018  return emitDefaultSilenceableFailure(target);
2019 }
2020 
2021 LogicalResult transform::HoistPadOp::verify() {
2022  ArrayRef<int64_t> transpose = getTranspose();
2023  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2024  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2025  transpose.end())) {
2026  return emitOpError() << "expects transpose to be a permutation, found "
2027  << getTranspose();
2028  }
2029  return success();
2030 }
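// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) hoisting a tensor.pad out of one
// enclosing loop:
//
//   %hoisted = transform.structured.hoist_pad %pad by 1 loops
//     : (!transform.any_op) -> !transform.any_op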
2031 
2032 //===----------------------------------------------------------------------===//
2033 // PromoteOp
2034 //===----------------------------------------------------------------------===//
2035 
2036 DiagnosedSilenceableFailure
2037 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2038  LinalgOp target,
2039  transform::ApplyToEachResultList &results,
2040  transform::TransformState &state) {
2041  LinalgPromotionOptions promotionOptions;
2042  if (!getOperandsToPromote().empty())
2043  promotionOptions = promotionOptions.setOperandsToPromote(
2044  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2045  if (getUseFullTilesByDefault())
2046  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2047  getUseFullTilesByDefault());
2048  if (getUseAlloca())
2049  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2050  if (!getUseFullTileBuffers().empty())
2051  promotionOptions = promotionOptions.setUseFullTileBuffers(
2052  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2053  if (getAlignment().has_value())
2054  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2055  if (getMemorySpace().has_value())
2056  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2057 
2058  if (getMapping().has_value()) {
2059  // The mapping should contain at most one element.
2060  auto mapping = *getMapping();
2061  if (mapping.size() > 1)
2062  return emitDefaultDefiniteFailure(target);
2063 
2064  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2065 
2066  if (addressSpace.getAddressSpace() ==
2067  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2068  promotionOptions =
2069  promotionOptions
2070  .setAllocationDeallocationFns(allocateWorkgroupMemory,
2071  deallocateWorkgroupMemory)
2072  .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
2073  .setUseFullTileBuffers({false, false});
2074  } else if (addressSpace.getAddressSpace() ==
2075  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2076  promotionOptions =
2077  promotionOptions
2078  .setAllocationDeallocationFns(allocateGPUPrivateMemory,
2079  deallocateGPUPrivateMemory)
2080  .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
2081  .setUseFullTileBuffers({false, false});
2082  } else {
2083  return emitDefaultDefiniteFailure(target);
2084  }
2085  }
2086 
2087  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2088  return emitDefaultDefiniteFailure(target);
2089 
2090  rewriter.setInsertionPoint(target);
2091  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2092  if (failed(res))
2093  return emitDefaultDefiniteFailure(target);
2094  results.push_back(target);
2095  return DiagnosedSilenceableFailure::success();
2096 }
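// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) promoting the first two operands
// of a tiled op into local buffers:
//
//   %promoted = transform.structured.promote %tiled
//       { operands_to_promote = [0, 1], use_full_tiles_by_default }
//     : (!transform.any_op) -> !transform.any_op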
2097 
2098 //===----------------------------------------------------------------------===//
2099 // ReplaceOp
2100 //===----------------------------------------------------------------------===//
2101 
2102 DiagnosedSilenceableFailure
2103 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2104  TransformResults &transformResults,
2105  TransformState &state) {
2106  auto payload = state.getPayloadOps(getTarget());
2107 
2108  // Check for invalid targets.
2109  for (Operation *target : payload) {
2110  if (target->getNumOperands() > 0)
2111  return emitDefiniteFailure() << "expected target without operands";
2112  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2113  target->getNumRegions() > 0)
2114  return emitDefiniteFailure()
2115  << "expected target that is isolated from above";
2116  }
2117 
2118  // Clone and replace.
2119  Operation *pattern = &getBodyRegion().front().front();
2120  SmallVector<Operation *> replacements;
2121  for (Operation *target : payload) {
2122  if (getOperation()->isAncestor(target))
2123  continue;
2124  rewriter.setInsertionPoint(target);
2125  Operation *replacement = rewriter.clone(*pattern);
2126  rewriter.replaceOp(target, replacement->getResults());
2127  replacements.push_back(replacement);
2128  }
2129  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2130  return DiagnosedSilenceableFailure::success();
2131 }
2132 
2133 void transform::ReplaceOp::getEffects(
2134  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2135  consumesHandle(getTargetMutable(), effects);
2136  producesHandle(getOperation()->getOpResults(), effects);
2137  modifiesPayload(effects);
2138 }
2139 
2140 LogicalResult transform::ReplaceOp::verify() {
2141  if (!getBodyRegion().hasOneBlock())
2142  return emitOpError() << "expected one block";
2143  if (std::distance(getBodyRegion().front().begin(),
2144  getBodyRegion().front().end()) != 1)
2145  return emitOpError() << "expected one operation in block";
2146  Operation *replacement = &getBodyRegion().front().front();
2147  if (replacement->getNumOperands() > 0)
2148  return replacement->emitOpError()
2149  << "expected replacement without operands";
2150  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2151  replacement->getNumRegions() > 0)
2152  return replacement->emitOpError()
2153  << "expected op that is isolated from above";
2154  return success();
2155 }
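// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions): every payload op is replaced by
// a clone of the single op in the body region, which must have no operands.
//
//   %new = transform.structured.replace %dummies {
//     arith.constant 0 : i32
//   } : (!transform.any_op) -> !transform.any_op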
2156 
2157 //===----------------------------------------------------------------------===//
2158 // ScalarizeOp
2159 //===----------------------------------------------------------------------===//
2160 
2161 DiagnosedSilenceableFailure
2162 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2163  LinalgOp target,
2164  transform::ApplyToEachResultList &results,
2165  transform::TransformState &state) {
2166  scf::SCFTilingOptions tilingOptions;
2167  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2168  SmallVector<OpFoldResult> tileSizes;
2169  Location loc = target.getLoc();
2170  SmallVector<OpFoldResult> allShapeSizes =
2171  target.createFlatListOfOperandDims(b, loc);
2172  AffineMap map = target.getShapesToLoopsMap();
2173  if (!map)
2174  return tileSizes;
2175  SmallVector<OpFoldResult> shapeSizes =
2176  affine::makeComposedFoldedMultiResultAffineApply(b, loc, map,
2177  allShapeSizes);
2178  // If the shape size is dynamic, tile by 1.
2179  // Otherwise, do not tile (i.e. tile size 0).
2180  for (OpFoldResult shapeSize : shapeSizes) {
2181  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2182  : b.getIndexAttr(1));
2183  }
2184  return tileSizes;
2185  });
2186  SmallVector<int64_t> emptyTileSizes;
2187  rewriter.setInsertionPoint(target);
2188  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2189  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2190  if (failed(maybeTilingResult))
2191  return emitDefaultDefiniteFailure(target);
2192 
2193  if (target->getNumResults())
2194  rewriter.replaceOp(target, maybeTilingResult->replacements);
2195  else
2196  rewriter.eraseOp(target);
2197 
2198  results.reserve(maybeTilingResult->tiledOps.size());
2199  for (Operation *tiled : maybeTilingResult->tiledOps)
2200  results.push_back(tiled);
2201  return DiagnosedSilenceableFailure::success();
2202 }
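// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions). Per the tile-size callback above,
// dynamic dimensions are tiled by 1 and static ones are left untiled:
//
//   %scalarized = transform.structured.scalarize %target
//     : (!transform.any_op) -> !transform.any_op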
2203 
2204 //===----------------------------------------------------------------------===//
2205 // ConvertToLoopsOp
2206 //===----------------------------------------------------------------------===//
2207 
2208 DiagnosedSilenceableFailure
2209 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2210  transform::TransformResults &results,
2211  transform::TransformState &state) {
2212  SmallVector<Operation *> loops;
2213  for (Operation *target : state.getPayloadOps(getTarget())) {
2214  auto tilingOp = dyn_cast<TilingInterface>(*target);
2215  if (!tilingOp) {
2216  DiagnosedSilenceableFailure diag =
2217  emitSilenceableError()
2218  << "expected the payload to implement TilingInterface";
2219  diag.attachNote(target->getLoc()) << "payload op";
2220  return diag;
2221  }
2222  rewriter.setInsertionPoint(target);
2223  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2224  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2225  if (failed(generatedLoops))
2226  return emitDefaultDefiniteFailure(target);
2227  for (scf::ForOp &loop : *generatedLoops) {
2228  loops.push_back(loop.getOperation());
2229  }
2230  rewriter.eraseOp(target);
2231  }
2232  results.set(cast<OpResult>(getResult()), loops);
2233  return DiagnosedSilenceableFailure::success();
2234 }
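// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) lowering each TilingInterface
// payload op to a nest of scf.for loops:
//
//   %loops = transform.structured.convert_to_loops %tiled
//     : (!transform.any_op) -> !transform.any_op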
2235 
2236 //===----------------------------------------------------------------------===//
2237 // RewriteInDestinationPassingStyleOp
2238 //===----------------------------------------------------------------------===//
2239 
2240 DiagnosedSilenceableFailure
2241 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2242  transform::TransformRewriter &rewriter, Operation *target,
2243  transform::ApplyToEachResultList &results,
2244  transform::TransformState &state) {
2246  rewriter.setInsertionPoint(target);
2247  FailureOr<Operation *> maybeResult =
2248  llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2249  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2250  [&rewriter](auto op) {
2251  return rewriteInDestinationPassingStyle(rewriter, op);
2252  });
2253  if (failed(maybeResult))
2254  return emitDefaultSilenceableFailure(target);
2255  results.push_back(*maybeResult);
2256  return DiagnosedSilenceableFailure::success();
2257 }
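// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) rewriting a tensor.pad into
// destination-passing style:
//
//   %dps = transform.structured.rewrite_in_destination_passing_style %pad
//     : (!transform.any_op) -> !transform.any_op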
2258 
2259 //===----------------------------------------------------------------------===//
2260 // SplitOp
2261 //===----------------------------------------------------------------------===//
2262 
2263 DiagnosedSilenceableFailure
2264 SplitOp::apply(transform::TransformRewriter &rewriter,
2265  TransformResults &results, TransformState &state) {
2266  // Collect the dynamic split points if provided.
2267  SmallVector<Operation *> payload =
2268  llvm::to_vector(state.getPayloadOps(getTarget()));
2269 
2270  bool isMultiwaySplit = getMultiway();
2271 
2272  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2273  return mlir::emitSilenceableFailure(getLoc())
2274  << "requires exactly one target when "
2275  "multiway split is enabled (got "
2276  << llvm::range_size(payload) << ")";
2277  }
2278 
2279  SmallVector<OpFoldResult> chunkSizes;
2280 
2281  if (!isMultiwaySplit)
2282  chunkSizes.reserve(payload.size());
2283 
2284  if (getDynamicChunkSizes()) {
2285  auto diag = DiagnosedSilenceableFailure::success();
2286  if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2287  chunkSizes = llvm::to_vector(llvm::map_range(
2288  state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2289  if (op->getNumResults() != 1 ||
2290  !op->getResult(0).getType().isIndex()) {
2291  diag = emitSilenceableError()
2292  << "expected dynamic split point handle to point to a "
2293  "single-result index-typed op";
2294  diag.attachNote(op->getLoc()) << "dynamic split point";
2295  }
2296  return OpFoldResult(op->getResult(0));
2297  }));
2298  } else {
2299  chunkSizes = llvm::to_vector(
2300  llvm::map_range(state.getParams(getDynamicChunkSizes()),
2301  [](Attribute attr) { return OpFoldResult(attr); }));
2302  }
2303  if (diag.isSilenceableFailure())
2304  return diag;
2305 
2306  // For multiway split, a single payload is expected to have multiple
2307  // split points.
2308  if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2309  return emitDefiniteFailure()
2310  << "expected the dynamic split point handle to point to as "
2311  "many operations ("
2312  << chunkSizes.size() << ") as the target handle ("
2313  << payload.size() << ")";
2314  }
2315  } else {
2316  chunkSizes.resize(payload.size(),
2317  rewriter.getIndexAttr(getStaticChunkSizes()));
2318  }
2319 
2320  auto checkStructuredOpAndDimensions =
2321  [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2322  if (!linalgOp) {
2323  auto diag = emitSilenceableError() << "only applies to structured ops";
2324  diag.attachNote(loc) << "target op";
2325  return diag;
2326  }
2327 
2328  if (getDimension() >= linalgOp.getNumLoops()) {
2329  auto diag = emitSilenceableError() << "dimension " << getDimension()
2330  << " does not exist in target op";
2331  diag.attachNote(loc) << "target op";
2332  return diag;
2333  }
2334  return DiagnosedSilenceableFailure::success();
2335  };
2336 
2337  auto checkFailureInSplitting =
2338  [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2339  if (hasFailed) {
2340  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2341  diag.attachNote(loc) << "target op";
2342  return diag;
2343  }
2344  return DiagnosedSilenceableFailure::success();
2345  };
2346 
2347  if (isMultiwaySplit) {
2348 
2349  // Split a single target operation at multiple points.
2350  SmallVector<Operation *> opList;
2351  TilingInterface head, tail;
2352  Operation *target = payload.front();
2353 
2354  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2355 
2356  // Check that the target is a valid LinalgOp with correct dimensions.
2357  DiagnosedSilenceableFailure diag =
2358  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2359  if (diag.isSilenceableFailure())
2360  return diag;
2361 
2362  for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2363 
2364  if (idx > 0)
2365  target = tail.getOperation();
2366 
2367  if (!target)
2368  break;
2369 
2370  linalgOp = cast<LinalgOp>(target);
2371  Location loc = target->getLoc();
2372 
2373  rewriter.setInsertionPoint(linalgOp);
2374  std::tie(head, tail) = linalg::splitOp(
2375  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2376  getDimension(), chunkSize);
2377 
2378  // Propagate errors.
2379  DiagnosedSilenceableFailure diag =
2380  checkFailureInSplitting(!head && !tail, loc);
2381  if (diag.isDefiniteFailure())
2382  return diag;
2383 
2384  opList.push_back(head.getOperation());
2385  }
2386 
2387  // Append any leftover parts to the end of the result list.
2388  if (tail)
2389  opList.push_back(tail.getOperation());
2390  results.set(cast<OpResult>(getFirst()), opList);
2391  results.set(cast<OpResult>(getSecond()), {});
2392 
2393  } else {
2394  // Split each target operation.
2395  SmallVector<Operation *> first, second;
2396  Operation *noSecondPart = nullptr;
2397  for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2398  Operation *target = std::get<0>(pair);
2399  Location loc = target->getLoc();
2400  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2401  DiagnosedSilenceableFailure diag =
2402  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2403 
2404  if (diag.isSilenceableFailure())
2405  return diag;
2406 
2407  rewriter.setInsertionPoint(linalgOp);
2408  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2409  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2410  getDimension(), std::get<1>(pair));
2411 
2412  // Propagate errors.
2413  DiagnosedSilenceableFailure diagSplit =
2414  checkFailureInSplitting(!first.back() && !second.back(), loc);
2415  if (diagSplit.isDefiniteFailure())
2416  return diagSplit;
2417 
2418  // Do not add null second parts.
2419  if (!second.back()) {
2420  noSecondPart = target;
2421  second.pop_back();
2422  }
2423  }
2424 
2425  if (second.size() != first.size() && !second.empty()) {
2426  auto diag = emitSilenceableError()
2427  << "splitting does not produce the second part for a subset "
2428  "of targets";
2429  diag.attachNote()
2430  << "expected splitting to produce the second part of all "
2431  "or none of the targets";
2432  diag.attachNote(noSecondPart->getLoc())
2433  << "first target with no second part";
2434  return diag;
2435  }
2436 
2437  results.set(cast<OpResult>(getFirst()), first);
2438  results.set(cast<OpResult>(getSecond()), second);
2439  }
2440  return DiagnosedSilenceableFailure::success();
2441 }
2442 
2443 void SplitOp::getEffects(
2444  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2445  consumesHandle(getTargetMutable(), effects);
2446  if (getDynamicChunkSizes())
2447  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2448  producesHandle(getOperation()->getOpResults(), effects);
2449  modifiesPayload(effects);
2450 }
2451 
2452 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2453  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2454  IntegerAttr staticChunkSizes;
2455  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2456  return failure();
2457 
2458  OptionalParseResult dynamicPointParseResult =
2459  parser.parseOptionalOperand(dynamicChunkSizes);
2460  if (!dynamicPointParseResult.has_value()) {
2461  int64_t staticChunkSizesValue;
2462  if (failed(parser.parseInteger(staticChunkSizesValue)))
2463  return failure();
2464 
2465  staticChunkSizes =
2466  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2467  }
2468 
2469  Type targetType;
2470  if (parser.parseOptionalAttrDict(result.attributes) ||
2471  parser.parseColonType(targetType) ||
2472  parser.resolveOperand(target, targetType, result.operands)) {
2473  return failure();
2474  }
2475  if (dynamicPointParseResult.has_value()) {
2476  Type chunkSizesType;
2477  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2478  parser.parseType(chunkSizesType) ||
2479  parser.resolveOperand(dynamicChunkSizes, chunkSizesType,
2480  result.operands)) {
2481  return failure();
2482  }
2483 
2484  staticChunkSizes =
2485  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2486  }
2487 
2488  result.addAttribute(
2489  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2490  staticChunkSizes);
2491  result.addTypes({targetType, targetType});
2492  return success();
2493 }
2494 
2495 void SplitOp::print(OpAsmPrinter &printer) {
2496  printer << " " << getTarget() << " after ";
2497  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2498  if (staticChunkSize != ShapedType::kDynamic)
2499  printer << staticChunkSize;
2500  else
2501  printer << getDynamicChunkSizes();
2502  printer << " ";
2503  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2504  {getStaticChunkSizesAttrName()});
2505  printer << " : " << getTarget().getType();
2506  if (staticChunkSize == ShapedType::kDynamic)
2507  printer << ", " << getDynamicChunkSizes().getType();
2508 }
2509 
2510 LogicalResult SplitOp::verify() {
2511  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2512  (getDynamicChunkSizes() == nullptr)) {
2513  return emitOpError() << "expects either a dynamic or a static split "
2514  "point to be provided";
2515  }
2516  return success();
2517 }
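// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) splitting iteration-space
// dimension 0 after the first 42 iterations:
//
//   %first, %second = transform.structured.split %target after 42
//       { dimension = 0 } : !transform.any_op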
2518 
2519 //===----------------------------------------------------------------------===//
2520 // SplitReductionOp
2521 //===----------------------------------------------------------------------===//
2522 
2523 void transform::SplitReductionOp::build(
2524  OpBuilder &builder, OperationState &result, Value target,
2525  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2526  bool useScalingAlgorithm, bool useAlloc) {
2527  MLIRContext *ctx = builder.getContext();
2528  result.addOperands(target);
2529  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2530  builder.getI64IntegerAttr(splitFactor));
2531  result.addAttribute(
2532  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2533  builder.getI64IntegerAttr(insertSplitDimension));
2534  if (innerParallel) {
2535  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2536  builder.getUnitAttr());
2537  }
2538  if (useScalingAlgorithm) {
2539  result.addAttribute(
2540  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2541  builder.getUnitAttr());
2542  }
2543  if (useAlloc) {
2544  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2545  builder.getUnitAttr());
2546  }
2547  auto resultType = transform::AnyOpType::get(ctx);
2548  result.addTypes({resultType, resultType, resultType, resultType});
2549 }
2550 
2551 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2552  transform::TransformRewriter &rewriter, LinalgOp target,
2553  transform::ApplyToEachResultList &results,
2554  transform::TransformState &state) {
2555  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2556  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2557  unsigned(getInsertSplitDimension()),
2558  bool(getInnerParallel())};
2559  };
2560  rewriter.setInsertionPoint(target);
2561  FailureOr<SplitReductionResult> splitResult =
2562  (getUseScalingAlgorithm())
2563  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2564  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2565  if (failed(splitResult))
2566  return emitDefaultDefiniteFailure(target);
2567 
2568  results.push_back(splitResult->initOrAlloc);
2569  results.push_back(splitResult->fillOp);
2570  results.push_back(splitResult->splitLinalgOp);
2571  results.push_back(splitResult->resultCombiningLinalgOp);
2572  return DiagnosedSilenceableFailure::success();
2573 }
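// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) splitting a reduction by a factor
// of 4:
//
//   %init, %fill, %split, %combine =
//       transform.structured.split_reduction %reduction
//       { split_factor = 4, insert_split_dimension = 0 }
//     : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
//        !transform.any_op, !transform.any_op)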
2574 
2575 //===----------------------------------------------------------------------===//
2576 // TileReductionUsingForOp
2577 //===----------------------------------------------------------------------===//
2578 
2579 void transform::TileReductionUsingForOp::build(
2580  OpBuilder &builder, OperationState &result, Value target,
2581  ArrayRef<int64_t> staticTileSizes) {
2582  // Call the default builder.
2583  // This is future-proof regarding mixed static-dynamic sizes and sets up
2584  // the proper operand segment size attributes for multiple variadic operands.
2585  // In the absence of this, horrible bugs ensue.
2586  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2587  MLIRContext *ctx = builder.getContext();
2588  auto opTy = transform::AnyOpType::get(ctx);
2589  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2590  build(builder, result,
2591  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2592  /*target=*/target,
2593  /*tile_sizes=*/staticTileSizesAttr);
2594 }
2595 
2596 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2597  transform::TransformRewriter &rewriter, LinalgOp target,
2598  transform::ApplyToEachResultList &results,
2599  transform::TransformState &state) {
2600  rewriter.setInsertionPoint(target);
2601  FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
2602  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2603  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2604 
2605  if (failed(result))
2606  return emitDefaultSilenceableFailure(target);
2607  for (Value initValue : result->initialValues)
2608  results.push_back(initValue.getDefiningOp());
2609  for (auto parallelTiledOp : result->parallelTiledOps)
2610  results.push_back(parallelTiledOp);
2611  for (auto mergeOp : result->mergeOps)
2612  results.push_back(mergeOp);
2613  results.push_back(result->loops.front());
2614  return DiagnosedSilenceableFailure::success();
2615 }
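// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) tiling the reduction dimension by
// 128 with a serial scf.for loop:
//
//   %fill, %parallel, %merge, %loop =
//       transform.structured.tile_reduction_using_for %reduction
//       by tile_sizes = [0, 128]
//     : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
//        !transform.any_op, !transform.any_op)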
2616 
2617 //===----------------------------------------------------------------------===//
2618 // TileReductionUsingForallOp
2619 //===----------------------------------------------------------------------===//
2620 
2621 void transform::TileReductionUsingForallOp::build(
2622  OpBuilder &builder, OperationState &result, Value target,
2623  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2624  ArrayAttr mapping) {
2625  // Call the default builder.
2626  // This is future-proof regarding mixed static-dynamic sizes and sets up
2627  // the proper operand segment size attributes for multiple variadic operands.
2628  // In the absence of this, horrible bugs ensue.
2629  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2630  MLIRContext *ctx = builder.getContext();
2631  auto opTy = transform::AnyOpType::get(ctx);
2632  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2633  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2634  build(builder, result,
2635  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2636  /*target=*/target,
2637  /*num_threads=*/staticNumThreadsAttr,
2638  /*tile_sizes=*/staticTileSizesAttr,
2639  /*mapping=*/mapping);
2640 }
2641 
2642 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2643  transform::TransformRewriter &rewriter, LinalgOp target,
2644  transform::ApplyToEachResultList &results,
2645  transform::TransformState &state) {
2646  rewriter.setInsertionPoint(target);
2647  SmallVector<OpFoldResult> numThreads =
2648  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2649  SmallVector<OpFoldResult> tileSizes =
2650  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2651  FailureOr<linalg::ForallReductionTilingResult> result =
2652  linalg::tileReductionUsingForall(
2653  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2654  numThreads, tileSizes, getMapping());
2655 
2656  if (failed(result)) {
2657  auto diag = emitSilenceableError() << "could not tile reduction";
2658  diag.attachNote(target.getLoc()) << "target operation";
2659  return diag;
2660  }
2661  for (Value initValue : result->initialValues)
2662  results.push_back(initValue.getDefiningOp());
2663  for (auto parallelTiledOp : result->parallelTiledOps)
2664  results.push_back(parallelTiledOp);
2665  for (auto mergeOp : result->mergeOps)
2666  results.push_back(mergeOp);
2667  results.push_back(result->loops);
2668  return DiagnosedSilenceableFailure::success();
2669 }
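// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) distributing partial reductions
// across 8 threads of an scf.forall:
//
//   %fill, %parallel, %merge, %forall =
//       transform.structured.tile_reduction_using_forall %reduction
//       by num_threads = [0, 8], tile_sizes = []
//     : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
//        !transform.any_op, !transform.any_op)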
2670 
2671 //===----------------------------------------------------------------------===//
2672 // ContinuousTileSizesOp
2673 //===----------------------------------------------------------------------===//
2674 
2675 DiagnosedSilenceableFailure
2676 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2677  TransformResults &transformResults,
2678  TransformState &state) {
2679 
2680  SmallVector<Operation *> targetOps =
2681  llvm::to_vector(state.getPayloadOps(getTarget()));
2682 
2683  if (!llvm::hasSingleElement(targetOps)) {
2684  return mlir::emitSilenceableFailure(getLoc())
2685  << "requires exactly one target (got " << llvm::range_size(targetOps)
2686  << ")";
2687  }
2688 
2689  Operation *target = *targetOps.begin();
2690  auto linalgOp = dyn_cast<LinalgOp>(target);
2691  auto tileableOp = dyn_cast<TilingInterface>(target);
2692 
2693  if (!linalgOp)
2694  return emitDefiniteFailure() << "expected Linalg Op";
2695 
2696  OpBuilder builder(linalgOp.getContext());
2697 
2698  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
2699  if (linalgOp.hasDynamicShape()) {
2700  auto diag = emitSilenceableError()
2701  << "cannot compute parametric tile sizes for dynamically "
2702  "shaped payload op";
2703  diag.attachNote(linalgOp->getLoc()) << "payload op";
2704  return diag;
2705  }
2706 
2707  FailureOr<StaticContinuousTileSizeSpecification> spec =
2708  computeStaticContinuousTileSizes(linalgOp, getDimension(),
2709  getTargetSize());
2710  if (failed(spec)) {
2711  return emitSilenceableError()
2712  << "failed to compute multi-size tiling sizes";
2713  }
2714 
2715  SmallVector<int64_t> chunkSizes;
2716 
2717  for (auto &&[tileSize, tripCount] :
2718  llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2719  chunkSizes.push_back(tileSize * tripCount);
2720 
2721  auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2722  return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2723  return builder.getI64IntegerAttr(value);
2724  });
2725  };
2726  transformResults.setParams(cast<OpResult>(getTileSizes()),
2727  getI64AttrsFromI64(spec->tileSizes));
2728  transformResults.setParams(cast<OpResult>(getChunkSizes()),
2729  getI64AttrsFromI64(chunkSizes));
2730 
2731  return DiagnosedSilenceableFailure::success();
2732  }
2733 
2734  builder.setInsertionPoint(linalgOp);
2735 
2736  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2737  unsigned dimension = getDimension();
2738 
2739  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2740  builder, tileableOp, dimension, targetSize, /*emitAssertions=*/true);
2741  if (failed(spec)) {
2742  return emitSilenceableError() << "could not generate tile size computation";
2743  }
2744 
2745  AffineExpr s0 = builder.getAffineSymbolExpr(0);
2746  AffineExpr s1 = builder.getAffineSymbolExpr(1);
2747  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2748  return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2749  ofrs);
2750  };
2751 
2752  SmallVector<Value> chunkSizes;
2753  Value splitPoint;
2754  for (auto &&[tileSize, tripCount] :
2755  llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2756  splitPoint = apply(s0 * s1, {tileSize, tripCount});
2757  chunkSizes.push_back(splitPoint);
2758  }
2759 
2760  auto getDefiningOps = [&](ArrayRef<Value> values) {
2761  return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2762  return value.getDefiningOp();
2763  });
2764  };
2765 
2766  transformResults.set(cast<OpResult>(getTileSizes()),
2767  getDefiningOps(spec->tileSizes));
2768  transformResults.set(cast<OpResult>(getChunkSizes()),
2769  getDefiningOps(chunkSizes));
2770 
2771  return DiagnosedSilenceableFailure::success();
2772 }
2773 
2774 LogicalResult transform::ContinuousTileSizesOp::verify() {
2775 
2776  if (getTileSizes().getType() != getChunkSizes().getType()) {
2777  return emitOpError() << "expects all result types to be the same";
2778  }
2779 
2780  return success();
2781 }
2782 
2783 void transform::ContinuousTileSizesOp::getEffects(
2784  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2785  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2786  onlyReadsPayload(effects);
2787  else
2788  modifiesPayload(effects);
2789  onlyReadsHandle(getTargetMutable(), effects);
2790  producesHandle(getOperation()->getOpResults(), effects);
2791 }
2792 
2793 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2794  Type targetType, Type tile_sizes,
2795  Type) {
2796  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
2797 }
2798 
2799 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2800  Type &targetType,
2801  Type &tileSizesType,
2802  Type &chunkSizesType) {
2803  FunctionType funcType;
2804  llvm::SMLoc typeLoc = parser.getCurrentLocation();
2805  if (failed(parser.parseType<FunctionType>(funcType)))
2806  return failure();
2807 
2808  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2809  parser.emitError(typeLoc) << "expects a trailing functional type with one "
2810  "argument and one result";
 return failure();
2811  }
2812  targetType = funcType.getInput(0);
2813  tileSizesType = chunkSizesType = funcType.getResult(0);
2814 
2815  return success();
2816 }
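// For reference, a minimal sketch (hypothetical handle; exact assembly
// format may differ across MLIR versions) computing a decreasing sequence of
// tile sizes (e.g. 32, 16, 8, ...) and matching chunk sizes for dimension 0:
//
//   %tile_sizes, %chunk_sizes =
//       transform.structured.continuous_tile_sizes %target
//       { dimension = 0, target_size = 32 }
//     : (!transform.any_op) -> !transform.any_op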
2817 
2818 //===----------------------------------------------------------------------===//
2819 // TileUsingForOp
2820 //===----------------------------------------------------------------------===//
2821 
2822 void transform::TileUsingForOp::build(
2823  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2824  Value target, ArrayRef<int64_t> staticTileSizes,
2825  ArrayRef<int64_t> interchange,
2826  std::optional<ArrayRef<bool>> scalableSizes) {
2827  return build(builder, result, loopTypes,
2828  /*target=*/target,
2829  /*mixedTileSizes=*/
2830  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2831  interchange, scalableSizes);
2832 }
2833 
2834 void transform::TileUsingForOp::build(
2835  OpBuilder &builder, OperationState &result, Value target,
2836  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2837  std::optional<ArrayRef<bool>> scalableSizes) {
2838  build(builder, result, target,
2839  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2840  interchange, scalableSizes);
2841 }
2842 
2843 void transform::TileUsingForOp::build(
2844  OpBuilder &builder, OperationState &result, Value target,
2845  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2846  std::optional<ArrayRef<bool>> scalableSizes) {
2847  // Loop types are automatically splat by the callee; setting up one is
2848  // enough.
2849  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2850  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2851  scalableSizes);
2852 }
2853 
2854 void transform::TileUsingForOp::build(
2855  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2856  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2857  ArrayRef<int64_t> interchange,
2858  std::optional<ArrayRef<bool>> scalableSizes) {
2859  SmallVector<int64_t> staticTileSizes;
2860  SmallVector<Value> dynamicTileSizes;
2861  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2862  // Call the default builder which sets up the proper operands segment sizes
2863  // attributes for multiple variadic operands. In the absence of this,
2864  // horrible bugs ensue.
2865  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2866  unsigned numExpectedLoops =
2867  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2868  SmallVector<Type> resultTypes;
2869  resultTypes.reserve(numExpectedLoops);
2870  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2871  "expected one loop type or as many as loops");
2872  if (loopTypes.size() == 1)
2873  resultTypes.append(numExpectedLoops, loopTypes[0]);
2874  else
2875  llvm::append_range(resultTypes, loopTypes);
2876  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2877  if (scalableSizes.has_value())
2878  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2879  build(builder, result, /*tiled_linalg_op=*/target.getType(),
2880  /*loops=*/resultTypes,
2881  /*target=*/target,
2882  /*dynamic_sizes=*/dynamicTileSizes,
2883  /*static_sizes=*/staticTileSizesAttr,
2884  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2885  /*scalable_sizes=*/expandedScalableSizes);
2886 }
2887 
2888 LogicalResult transform::TileUsingForOp::verify() {
2889  if (getMixedSizes().size() != getScalableSizes().size())
2890  return emitOpError("expected same number of sizes (")
2891  << getMixedSizes().size() << ") and scalable sizes ("
2892  << getScalableSizes().size() << ")";
2893  ArrayRef<int64_t> staticSizes = getStaticSizes();
2894  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
2895  if (getLoops().size() != numExpectedLoops)
2896  return emitOpError("expected number of loops to tile (")
2897  << numExpectedLoops << ") to match number of `loops` results ("
2898  << getLoops().size() << ")";
2899  return success();
2900 }
2901 
2902 DiagnosedSilenceableFailure
2903 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2904  TransformResults &transformResults,
2905  TransformState &state) {
2906  ArrayRef<int64_t> tileSizes = getStaticSizes();
2907 
2908  SmallVector<Operation *> targets =
2909  llvm::to_vector(state.getPayloadOps(getTarget()));
2910  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2911  SmallVector<SmallVector<int64_t>> paramSizes;
2912  dynamicSizeProducers.reserve(getDynamicSizes().size());
2913  paramSizes.reserve(getDynamicSizes().size());
2914  for (Value transformValue : getDynamicSizes()) {
2915  if (isa<ParamType>(transformValue.getType())) {
2916  dynamicSizeProducers.push_back({});
2917  ArrayRef<Attribute> params = state.getParams(transformValue);
2918  paramSizes.push_back(
2919  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
2920  return cast<IntegerAttr>(attr).getValue().getSExtValue();
2921  })));
2922 
2923  if (paramSizes.back().size() != targets.size()) {
2924  DiagnosedSilenceableFailure diag =
2925  emitSilenceableError()
2926  << "expected as many parameter values ("
2927  << paramSizes.back().size() << ") as target ops ("
2928  << targets.size() << ")";
2929  diag.attachNote(transformValue.getLoc()) << "for this parameter";
2930  return diag;
2931  }
2932 
2933  continue;
2934  }
2935  paramSizes.push_back({});
2936  dynamicSizeProducers.push_back(
2937  llvm::to_vector(state.getPayloadOps(transformValue)));
2938 
2939  if (dynamicSizeProducers.back().size() != targets.size()) {
2940  DiagnosedSilenceableFailure diag =
2941  emitSilenceableError()
2942  << "expected as many dynamic size-producing operations ("
2943  << dynamicSizeProducers.back().size() << ") as target ops ("
2944  << targets.size() << ")";
2945  diag.attachNote(transformValue.getLoc()) << "for this handle";
2946  return diag;
2947  }
2948 
2949  for (Operation *op : dynamicSizeProducers.back()) {
2950  if (op->getNumResults() == 1 &&
2951  isa<IndexType>(op->getResult(0).getType())) {
2952  continue;
2953  }
2954 
2955  DiagnosedSilenceableFailure diag =
2956  emitSilenceableError() << "expected sizes to be produced by ops "
2957  "with a single index-type result";
2958  diag.attachNote(op->getLoc()) << "size producer op";
2959  diag.attachNote(transformValue.getLoc()) << "for this handle";
2960  return diag;
2961  }
2962  }
2963 
2964  SmallVector<Operation *> tiled;
2965  SmallVector<SmallVector<Operation *, 4>, 4> loops;
2966  loops.resize(getLoops().size());
2967  auto scalableSizes = getScalableSizes();
2968  for (auto [i, op] : llvm::enumerate(targets)) {
2969  auto tilingInterface = dyn_cast<TilingInterface>(op);
2970  if (!tilingInterface) {
2971  DiagnosedSilenceableFailure diag =
2972  emitSilenceableError()
2973  << "only ops implementing TilingInterface are supported";
2974  diag.attachNote(op->getLoc()) << "target op";
2975  return diag;
2976  }
2977  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
2978  DiagnosedSilenceableFailure diag =
2979  emitSilenceableError()
2980  << "too many tiles provided, expected at most "
2981  << tilingInterface.getLoopIteratorTypes().size() << " found "
2982  << tileSizes.size();
2983  diag.attachNote(op->getLoc()) << "target op";
2984  return diag;
2985  }
2986 
2987  scf::SCFTilingOptions tilingOptions;
2988  if (tileSizes.empty()) {
2989  tilingOptions.setTileSizeComputationFunction(
2990  [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
2991  return {};
2992  });
2993  } else {
2994  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
2995  Operation *) {
2996         SmallVector<OpFoldResult> sizes;
2997         sizes.reserve(tileSizes.size());
2998  unsigned dynamicIdx = 0;
2999 
3000  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3001  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3002  if (scalableSizes[ofrIdx]) {
3003  auto val = b.create<arith::ConstantIndexOp>(
3004  getLoc(), cast<IntegerAttr>(attr).getInt());
3005  Value vscale =
3006  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
3007  sizes.push_back(
3008  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
3009  } else {
3010  sizes.push_back(attr);
3011  }
3012  continue;
3013  }
3014  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3015  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3016  ++dynamicIdx;
3017  assert((dynamicSizes.empty() ^ params.empty()) &&
3018  "expected either dynamic sizes or parameters");
3019  if (!params.empty()) {
3020  sizes.push_back(b.getIndexAttr(params[index]));
3021  } else {
3022  sizes.push_back(dynamicSizes[index]->getResult(0));
3023  }
3024  }
3025  return sizes;
3026  });
3027  }
3028 
3029  tilingOptions.setInterchange(getInterchange());
3030  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3031  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3032  if (failed(maybeTilingResult))
3033       return DiagnosedSilenceableFailure::definiteFailure();
3034 
3035  rewriter.replaceOp(op, maybeTilingResult->replacements);
3036 
3037  tiled.append(maybeTilingResult->tiledOps);
3038  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3039  loops[en2.index()].push_back(en2.value());
3040  }
3041 
3042  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3043  for (const auto &en : llvm::enumerate(loops))
3044  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3045 
3046   return DiagnosedSilenceableFailure::success();
3047 }
3048 
3050  ValueRange dynamic = getDynamicSizes();
3051  ArrayRef<int64_t> tileSizes = getStaticSizes();
3052  SmallVector<OpFoldResult> results;
3053  results.reserve(tileSizes.size());
3054  unsigned dynamicPos = 0;
3055  Builder builder(getContext());
3056  for (int64_t size : tileSizes) {
3057  if (size == ShapedType::kDynamic) {
3058  results.push_back(dynamic[dynamicPos++]);
3059  } else {
3060  results.push_back(builder.getIndexAttr(size));
3061  }
3062  }
3063  return results;
3064 }
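
// Worked example for getMixedSizes() above: with static sizes
// [4, ShapedType::kDynamic, 8] and a single dynamic size operand %sz, the
// result is [IndexAttr(4), %sz, IndexAttr(8)]; each kDynamic sentinel
// consumes the next dynamic operand in order.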
3065 
3066 void transform::TileUsingForOp::getEffects(
3067     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3068   consumesHandle(getTargetMutable(), effects);
3069  onlyReadsHandle(getDynamicSizesMutable(), effects);
3070  producesHandle(getOperation()->getOpResults(), effects);
3071  modifiesPayload(effects);
3072 }
3073 
3074 //===----------------------------------------------------------------------===//
3075 // TileUsingForallOp
3076 //===----------------------------------------------------------------------===//
3077 
3078 void transform::TileUsingForallOp::build(OpBuilder &builder,
3079  OperationState &result, Value target,
3080  ArrayRef<int64_t> staticTileSizes,
3081                                          transform::TileSizesSpec,
3082                                          ArrayAttr mapping) {
3083  return build(builder, result,
3084  /*target=*/target,
3085  /*mixedTileSizes=*/
3086  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3087  /*_=*/TileSizesSpec(),
3088  /*mapping=*/mapping);
3089 }
3090 
3091 void transform::TileUsingForallOp::build(OpBuilder &builder,
3092  OperationState &result, Value target,
3093  ArrayRef<OpFoldResult> mixedTileSizes,
3094                                          transform::TileSizesSpec,
3095                                          ArrayAttr mapping) {
3096  SmallVector<int64_t> staticTileSizes;
3097  SmallVector<Value> dynamicTileSizes;
3098  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3099  // Call the default builder which sets up the proper operands segment sizes
3100  // attributes for multiple variadic operands. In the absence of this,
3101  // horrible bugs ensue.
3102  MLIRContext *ctx = builder.getContext();
3103  auto operationType = transform::AnyOpType::get(ctx);
3104  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3105  build(builder, result,
3106  /*resultTypes=*/TypeRange{operationType, operationType},
3107  /*target=*/target,
3108  /*num_threads=*/ValueRange{},
3109  /*tile_sizes=*/dynamicTileSizes,
3110  /*packed_num_threads=*/Value(),
3111  /*packed_tile_sizes=*/Value(),
3112  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3113  /*static_tile_sizes=*/staticTileSizesAttr,
3114  /*mapping=*/mapping);
3115 }
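
// Illustrative sketch of the split performed above: for mixedTileSizes
// [attr(8), %n, attr(4)], dispatchIndexOpFoldResults produces
// staticTileSizes = [8, ShapedType::kDynamic, 4] and dynamicTileSizes = (%n),
// which populate the static_tile_sizes attribute and the tile_sizes operands.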
3116 
3117 void transform::TileUsingForallOp::build(OpBuilder &builder,
3118  OperationState &result, Value target,
3119  ArrayRef<int64_t> staticNumThreads,
3120                                          transform::NumThreadsSpec,
3121                                          ArrayAttr mapping) {
3122  return build(builder, result, target,
3123  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3124  NumThreadsSpec(), mapping);
3125 }
3126 
3127 void transform::TileUsingForallOp::build(OpBuilder &builder,
3128  OperationState &result, Value target,
3129  ArrayRef<OpFoldResult> mixedNumThreads,
3130                                          transform::NumThreadsSpec,
3131                                          ArrayAttr mapping) {
3132  SmallVector<int64_t> staticNumThreads;
3133  SmallVector<Value> dynamicNumThreads;
3134  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3135  staticNumThreads);
3136  // Call the default builder which sets up the proper operands segment sizes
3137  // attributes for multiple variadic operands. In the absence of this,
3138  // horrible bugs ensue.
3139  MLIRContext *ctx = builder.getContext();
3140  auto operationType = transform::AnyOpType::get(ctx);
3141  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3142  build(builder, result,
3143  /*resultTypes=*/TypeRange{operationType, operationType},
3144  /*target=*/target,
3145  /*num_threads=*/dynamicNumThreads,
3146  /*tile_sizes=*/ValueRange{},
3147  /*packed_num_threads=*/Value(),
3148  /*packed_tile_sizes=*/Value(),
3149  /*static_num_threads=*/staticNumThreadsAttr,
3150  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3151  /*mapping=*/mapping);
3152 }
3153 
3154 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
3155     RewriterBase &rewriter, transform::TransformState &state,
3156  TransformOpInterface transformOp, Operation *target,
3157  ArrayRef<OpFoldResult> mixedNumThreads,
3158  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3159  linalg::ForallTilingResult &tilingResult) {
3160  // Transform all targets one by one.
3161  auto tileableOp = dyn_cast<TilingInterface>(target);
3162  if (!tileableOp) {
3163     DiagnosedSilenceableFailure diag =
3164         transformOp.emitSilenceableError()
3165  << "only TilingInterface ops are supported";
3166  diag.attachNote(target->getLoc()) << "target op";
3167  return diag;
3168  }
3169  rewriter.setInsertionPoint(tileableOp);
3170  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
3171  if (!mixedNumThreads.empty()) {
3172  maybeTilingResult =
3173  linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
3174  } else {
3175  maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
3176  rewriter, tileableOp, mixedTileSizes, mapping);
3177  }
3178 
3179  if (failed(maybeTilingResult))
3180  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3181  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
3182 
3183  tilingResult = *maybeTilingResult;
3184   return DiagnosedSilenceableFailure::success();
3185 }
3186 
3187 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3188  transform::TransformRewriter &rewriter,
3189  transform::TransformResults &transformResults,
3190  transform::TransformState &state) {
3191  auto transformOp = cast<TransformOpInterface>(getOperation());
3192 
3193  // Result payload ops.
3194  SmallVector<Operation *> tileOps;
3195  SmallVector<Operation *> tiledOps;
3196 
3197  // Unpack handles.
3198  SmallVector<OpFoldResult> mixedNumThreads;
3199   DiagnosedSilenceableFailure status =
3200       getPackedNumThreads()
3201           ? unpackSingleIndexResultPayloadOperations(
3202                 state, transformOp, mixedNumThreads, getPackedNumThreads())
3203           : unpackSingleIndexResultPayloadOperations(
3204                 state, transformOp, mixedNumThreads, getMixedNumThreads());
3205  if (!status.succeeded())
3206  return status;
3207  SmallVector<OpFoldResult> mixedTileSizes;
3208   status = getPackedTileSizes()
3209                ? unpackSingleIndexResultPayloadOperations(
3210                      state, transformOp, mixedTileSizes, getPackedTileSizes())
3211                : unpackSingleIndexResultPayloadOperations(
3212                      state, transformOp, mixedTileSizes, getMixedTileSizes());
3213  if (!status.succeeded())
3214  return status;
3215 
3216  for (Operation *target : state.getPayloadOps(getTarget())) {
3217  linalg::ForallTilingResult tilingResult;
3218     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3219         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3220  getMapping(), tilingResult);
3221  if (!diag.succeeded())
3222  return diag;
3223  tileOps.push_back(tilingResult.tileOp);
3224  tiledOps.push_back(tilingResult.tiledOp);
3225  }
3226 
3227  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3228  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3229 
3230   return DiagnosedSilenceableFailure::success();
3231 }
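
// Minimal usage sketch (assumed transform-IR assembly and mapping attributes;
// details vary between versions): tile a payload op to an scf.forall mapped
// onto GPU threads.
//
//   %res:2 = transform.structured.tile_using_forall %op
//       num_threads [8, 4] (mapping = [#gpu.thread<y>, #gpu.thread<x>])
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)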
3232 
3233 void transform::TileUsingForallOp::getEffects(
3234     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3235   consumesHandle(getTargetMutable(), effects);
3236  onlyReadsHandle(getTileSizesMutable(), effects);
3237  onlyReadsHandle(getNumThreadsMutable(), effects);
3238  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3239  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3240  producesHandle(getOperation()->getOpResults(), effects);
3241  modifiesPayload(effects);
3242 }
3243 
3244 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3245  Builder b(getContext());
3246  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3247 }
3248 
3249 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3250  Builder b(getContext());
3251  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3252 }
3253 
3254 LogicalResult TileUsingForallOp::verify() {
3255  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3256  static_cast<int>(getPackedNumThreads() != Value());
3257  if (numThreadsSpec > 1)
3258  return emitOpError(
3259  "num_threads and packed_num_threads are mutually exclusive");
3260  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3261  static_cast<int>(getPackedTileSizes() != Value());
3262  if (tileSizesSpec > 1)
3263  return emitOpError(
3264  "tile_sizes and packed_tile_sizes are mutually exclusive");
3265  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3266  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3267  "must be specified");
3268  return success();
3269 }
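
// Illustrative combinations for the verifier above (assumed assembly):
//
//   transform.structured.tile_using_forall %op num_threads [8, 4]   // ok
//   transform.structured.tile_using_forall %op tile_sizes [16, 32]  // ok
//
// Passing both num_threads and packed_num_threads (or both tile-size forms),
// or specifying neither threads nor tile sizes, is rejected.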
3270 
3271 //===----------------------------------------------------------------------===//
3272 // VectorizeChildrenAndApplyPatternsOp
3273 //===----------------------------------------------------------------------===//
3274 
3275 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3276  OpBuilder &builder, OperationState &result, Value target,
3277  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3278  result.addOperands(target);
3279  if (vectorizePadding) {
3280  result.addAttribute(
3281  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3282  result.name),
3283  builder.getUnitAttr());
3284  }
3285  if (vectorizeExtract) {
3286  result.addAttribute(
3287  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3288  result.name),
3289  builder.getUnitAttr());
3290  }
3291  if (flatten1DDepthwiseConv) {
3292  result.addAttribute(
3293  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3294  result.name),
3295  builder.getUnitAttr());
3296  }
3297  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3298 }
3299 
3300 namespace {
3301 /// This is a helper used only to call vectorize via a pattern inside of
3302 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3303 struct VectorizationPattern : public RewritePattern {
3304  explicit VectorizationPattern(MLIRContext *context,
3305  bool vectorizeExtract = false,
3306  bool flattenConv = false)
3307  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3308  vectorizeNDExtract(vectorizeExtract),
3309  flatten1DDepthwiseConv(flattenConv) {}
3310  LogicalResult matchAndRewrite(Operation *op,
3311  PatternRewriter &rewriter) const override {
3312  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
3313  if (!linalgOp)
3314  return rewriter.notifyMatchFailure(op, "expected Linalg Op");
3315  return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
3316  /*scalableVecDims=*/{}, vectorizeNDExtract,
3317  flatten1DDepthwiseConv);
3318  }
3319 
3320 private:
3321  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3322  /// rank >= 2.
3323  bool vectorizeNDExtract = false;
3324   /// Controls whether to "flatten" the channel dimension when vectorizing 1D
3325   /// depthwise convolutions. This should lead to better vectorization for
3326  /// tensors with a low number of channel dimensions.
3327  bool flatten1DDepthwiseConv = false;
3328 };
3329 } // namespace
3330 
3331 DiagnosedSilenceableFailure
3332 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3333  transform::TransformRewriter &rewriter, Operation *target,
3334     transform::ApplyToEachResultList &results,
3335     transform::TransformState &state) {
3336  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3337  auto diag = this->emitOpError("requires isolated-from-above targets");
3338  diag.attachNote(target->getLoc()) << "non-isolated target";
3339     return DiagnosedSilenceableFailure::definiteFailure();
3340   }
3341 
3342  MLIRContext *ctx = getContext();
3343  RewritePatternSet patterns(ctx);
3344  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3345  getFlatten_1dDepthwiseConv());
3346 
3347  if (!getDisableTransferPermutationMapLoweringPatterns())
3348     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3349 
3350  if (!getDisableMultiReductionToContractPatterns())
3351     vector::populateVectorReductionToContractPatterns(patterns);
3352 
3353   vector::populateSinkVectorBroadcastPatterns(patterns);
3354 
3355   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3356                linalg::LinalgCopyVTWForwardingPattern>(ctx,
3357                                                        /*benefit=*/2);
3358  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3359  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3360   tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
3361 
3362  patterns.add<CopyVectorizationPattern>(ctx);
3363 
3364  if (getVectorizePadding())
3365     linalg::populatePadOpVectorizationPatterns(patterns);
3366 
3367  TrackingListener listener(state, *this);
3368  GreedyRewriteConfig config;
3369  config.listener = &listener;
3370  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
3371  return emitDefaultDefiniteFailure(target);
3372 
3373  results.push_back(target);
3374   return DiagnosedSilenceableFailure::success();
3375 }
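
// Minimal usage sketch (assumed assembly): vectorize all Linalg children of
// an isolated-from-above op and run the cleanup patterns configured above.
//
//   %vectorized = transform.structured.vectorize_children_and_apply_patterns
//       %func {vectorize_padding} : (!transform.any_op) -> !transform.any_op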
3376 
3377 //===----------------------------------------------------------------------===//
3378 // VectorizeOp
3379 //===----------------------------------------------------------------------===//
3380 
3381 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3382  transform::TransformRewriter &rewriter,
3383  mlir::transform::TransformResults &transformResults,
3384     mlir::transform::TransformState &state) {
3385   auto targets = state.getPayloadOps(getTarget());
3386  if (std::empty(targets))
3387     return DiagnosedSilenceableFailure::success();
3388   auto transformOp = cast<TransformOpInterface>(getOperation());
3389  SmallVector<int64_t> vectorSizes;
3390   DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3391       state, transformOp, getMixedVectorSizes(), vectorSizes);
3392  if (!status.succeeded())
3393  return status;
3394 
3395  // TODO: Check that the correct number of vectorSizes was provided.
3396  for (Operation *target : targets) {
3397  if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
3398  target)) {
3399  return mlir::emitSilenceableFailure(target->getLoc())
3400  << "Unsupported Op, cannot vectorize";
3401  }
3402 
3403  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3404  getScalableSizes(),
3405  getVectorizeNdExtract().has_value()
3406  ? getVectorizeNdExtract().value()
3407  : false))) {
3408  return mlir::emitSilenceableFailure(target->getLoc())
3409  << "Attempted to vectorize, but failed";
3410  }
3411  }
3412 
3413   return DiagnosedSilenceableFailure::success();
3414 }
3415 
3416 void transform::VectorizeOp::getEffects(
3417     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3418   consumesHandle(getTargetMutable(), effects);
3419  onlyReadsHandle(getVectorSizesMutable(), effects);
3420  modifiesPayload(effects);
3421 }
3422 
3423 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3424  OpBuilder b(getContext());
3425  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3426 }
3427 
3428 LogicalResult transform::VectorizeOp::verify() {
3429  if (getStaticVectorSizes().size() != getScalableSizes().size())
3430  return emitOpError("expected same number of vector sizes (")
3431  << getStaticVectorSizes().size() << ") and scalable sizes ("
3432  << getScalableSizes().size() << ")";
3433  return success();
3434 }
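
// Illustrative sketch of the invariant checked above (assumed assembly): each
// vector size carries a matching scalable flag, conventionally written with
// brackets, so both lists always have the same length.
//
//   transform.structured.vectorize %op vector_sizes [8, [16]]
//       : !transform.any_op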
3435 
3436 //===----------------------------------------------------------------------===//
3437 // HoistRedundantVectorTransfersOp
3438 //===----------------------------------------------------------------------===//
3439 
3440 DiagnosedSilenceableFailure
3441 transform::HoistRedundantVectorTransfersOp::applyToOne(
3442  transform::TransformRewriter &rewriter, func::FuncOp target,
3443     transform::ApplyToEachResultList &results,
3444     transform::TransformState &state) {
3445  // WARNING: This hoisting does not model parallelism and is generally
3446  // incorrect when used on distributed loops with memref semantics!
3447  // TODO: obsolete and should be retired.
3448   linalg::hoistRedundantVectorTransfers(target);
3449   results.push_back(target);
3450   return DiagnosedSilenceableFailure::success();
3451 }
3452 
3453 //===----------------------------------------------------------------------===//
3454 // HoistRedundantVectorBroadcastsOp
3455 //===----------------------------------------------------------------------===//
3456 
3457 DiagnosedSilenceableFailure
3458 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3459  transform::TransformRewriter &rewriter, mlir::Operation *target,
3460     transform::ApplyToEachResultList &results,
3461     transform::TransformState &state) {
3462  rewriter.setInsertionPoint(target);
3463  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3464  results.push_back(target);
3465   return DiagnosedSilenceableFailure::success();
3466 }
3467 
3468 //===----------------------------------------------------------------------===//
3469 // ConvertConv2DToImg2ColOp.
3470 //===----------------------------------------------------------------------===//
3471 
3472 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3473  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3474     transform::ApplyToEachResultList &results,
3475     transform::TransformState &state) {
3476  rewriter.setInsertionPoint(target);
3477  auto maybeTransformed =
3478       TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3479           target)
3480  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3481  return rewriteInIm2Col(rewriter, op);
3482  })
3483  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3484  return rewriteInIm2Col(rewriter, op);
3485  })
3486  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3487  return rewriteInIm2Col(rewriter, op);
3488  })
3489  .Case([&](linalg::Conv2DNchwFchwOp op) {
3490  return rewriteInIm2Col(rewriter, op);
3491  })
3492  .Default([&](Operation *op) {
3493  return rewriter.notifyMatchFailure(op, "not supported");
3494  });
3495  if (failed(maybeTransformed))
3496  return emitDefaultSilenceableFailure(target);
3497  // Handle to the operation producing the img2col tensor.
3498  results.push_back(maybeTransformed->first);
3499  // Handle to the operation that replaces the original convolution.
3500  results.push_back(maybeTransformed->second);
3501   return DiagnosedSilenceableFailure::success();
3502 }
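
// Worked shape example for the im2col rewrite above: a 1x16x16x16 NHWC input
// convolved with a 3x3x16x32 HWCF filter unfolds into a 1x196x144 patch
// tensor (196 = 14*14 output positions, 144 = 3*3*16 patch elements), which
// is contracted against the filter viewed as 144x32 to yield 1x196x32, i.e.
// the 1x14x14x32 convolution result.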
3503 
3504 //===----------------------------------------------------------------------===//
3505 // FlattenElementwiseLinalgOp.
3506 //===----------------------------------------------------------------------===//
3507 
3508 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3509  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3510     transform::ApplyToEachResultList &results,
3511     transform::TransformState &state) {
3512  rewriter.setInsertionPoint(target);
3513  if (!isElementwise(target))
3514  return mlir::emitSilenceableFailure(target->getLoc())
3515  << "only elementwise flattening is supported";
3516 
3517  // If rank <= 1, do nothing
3518  if (target.getNumLoops() <= 1) {
3519  results.push_back(target);
3520     return DiagnosedSilenceableFailure::success();
3521   }
3522 
3523  // Attempt to flatten all dims to one.
3524  ReassociationIndices reassociation(target.getNumLoops());
3525  std::iota(reassociation.begin(), reassociation.end(), 0);
3526  auto maybeFlattened =
3527  collapseOpIterationDims(target, reassociation, rewriter);
3528  if (failed(maybeFlattened))
3529  return mlir::emitSilenceableFailure(target->getLoc())
3530  << "attempted to flatten, but failed";
3531  results.push_back(maybeFlattened->collapsedOp);
3532  rewriter.replaceOp(target, maybeFlattened->results);
3533   return DiagnosedSilenceableFailure::success();
3534 }
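
// Worked example for the flattening above: for a rank-3 elementwise op,
// std::iota yields the reassociation [0, 1, 2], so the iteration space
// (d0, d1, d2) collapses into a single loop of extent d0 * d1 * d2, e.g.
// operands of type tensor<2x3x4xf32> become tensor<24xf32>.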
3535 
3536 //===----------------------------------------------------------------------===//
3537 // TransposeConv2DOp
3538 //===----------------------------------------------------------------------===//
3539 
3540 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3541  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3542     transform::ApplyToEachResultList &results,
3543     transform::TransformState &state) {
3544  rewriter.setInsertionPoint(target);
3545  auto maybeTransformed =
3546       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3547           .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3548  return transposeConv2D(rewriter, op);
3549  })
3550  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3551  return transposeConv2D(rewriter, op);
3552  })
3553  .Default([&](Operation *op) {
3554  return rewriter.notifyMatchFailure(op, "not supported");
3555  });
3556  if (failed(maybeTransformed))
3557  return emitDefaultSilenceableFailure(target);
3558  // Handle to the new Conv2D operation with transposed filters
3559  results.push_back(*maybeTransformed);
3560   return DiagnosedSilenceableFailure::success();
3561 }
3562 
3563 //===----------------------------------------------------------------------===//
3564 // TransposeMatmulOp
3565 //===----------------------------------------------------------------------===//
3566 
3567 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3568  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3569     transform::ApplyToEachResultList &results,
3570     transform::TransformState &state) {
3571  rewriter.setInsertionPoint(target);
3572  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3573  auto maybeTransformed =
3574       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3575           .Case([&](linalg::MatmulOp op) {
3576  return transposeMatmul(rewriter, op, transposeLHS);
3577  })
3578  .Case([&](linalg::BatchMatmulOp op) {
3579  return transposeBatchMatmul(rewriter, op, transposeLHS);
3580  })
3581  .Default([&](Operation *op) { return failure(); });
3582  if (failed(maybeTransformed))
3583  return emitSilenceableFailure(target->getLoc()) << "not supported";
3584   // Handle to the new Matmul operation with the selected input transposed.
3585  results.push_back(*maybeTransformed);
3586   return DiagnosedSilenceableFailure::success();
3587 }
3588 
3589 //===----------------------------------------------------------------------===//
3590 // InsertSliceToCopyOp
3591 //===----------------------------------------------------------------------===//
3592 template <typename OpTy>
3593 static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3594                                         transform::ApplyToEachResultList &results,
3595                                         transform::TransformState &state) {
3596  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3597  tensor::ParallelInsertSliceOp>() &&
3598  "wrong op type");
3599 
3600  if (auto copySource =
3601  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3602  results.push_back(copySource);
3603     return DiagnosedSilenceableFailure::success();
3604   }
3605 
3606  // If we are inside an InParallel region, temporarily set the insertion point
3607  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3608  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3609  rewriter.setInsertionPoint(
3610  target->template getParentOfType<scf::InParallelOp>());
3611  }
3612 
3613  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3614  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3615  target.getMixedSizes(), target.getMixedStrides());
3616  Value copied = rewriter
3617  .create<linalg::CopyOp>(target.getLoc(),
3618  target.getSource(), extracted)
3619  .getResult(0);
3620  // Reset the insertion point.
3621  rewriter.setInsertionPoint(target);
3622  rewriter.replaceOpWithNewOp<OpTy>(
3623  target, copied, target.getDest(), target.getMixedOffsets(),
3624  target.getMixedSizes(), target.getMixedStrides());
3625 
3626  results.push_back(copied.getDefiningOp());
3627   return DiagnosedSilenceableFailure::success();
3628 }
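
// Illustrative before/after for the rewrite above (types elided):
//
//   %r = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
//
// becomes
//
//   %e = tensor.extract_slice %dest[0, 0] [4, 4] [1, 1]
//   %c = linalg.copy ins(%src : ...) outs(%e : ...) -> ...
//   %r = tensor.insert_slice %c into %dest[0, 0] [4, 4] [1, 1]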
3629 
3630 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3631  transform::TransformRewriter &rewriter, Operation *targetOp,
3632     transform::ApplyToEachResultList &results,
3633     transform::TransformState &state) {
3634 
3635  rewriter.setInsertionPoint(targetOp);
3636  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3637  return doit(rewriter, target, results, state);
3638  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3639  return doit(rewriter, target, results, state);
3640 
3641   DiagnosedSilenceableFailure diag =
3642       emitSilenceableError()
3643  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3644  diag.attachNote(targetOp->getLoc()) << "target op";
3645  return diag;
3646 }
3647 
3648 //===----------------------------------------------------------------------===//
3649 // MapCopyToThreadsOp
3650 //===----------------------------------------------------------------------===//
3651 
3652 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3653  transform::TransformRewriter &rewriter, Operation *target,
3654     transform::ApplyToEachResultList &results,
3655     transform::TransformState &state) {
3656  // Check if the op is supported.
3657  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3658     DiagnosedSilenceableFailure diag =
3659         emitSilenceableError()
3660  << "only linalg.copy and tensor.pad target ops are supported";
3661  diag.attachNote(target->getLoc()) << "target op";
3662  return diag;
3663  }
3664  assert(target->getNumResults() == 1 && "expected single result");
3665  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3666  if (!resultShapedType.hasStaticShape()) {
3667     DiagnosedSilenceableFailure diag =
3668         emitSilenceableError()
3669  << "only statically sized ops of rank <= 3 are supported";
3670  diag.attachNote(target->getLoc()) << "target op";
3671  return diag;
3672  }
3673 
3674  // Conservatively set the minimum viable desired bitwidth alignment.
3675  int64_t desiredBitAlignment = getDesiredBitAlignment();
3676  int64_t eltBitwidth =
3677  resultShapedType.getElementType().getIntOrFloatBitWidth();
3678  if (desiredBitAlignment % eltBitwidth != 0) {
3679  desiredBitAlignment = eltBitwidth;
3680  }
3681 
3682  gpu::CopyMappingInfo mapping(
3683  /*ctx=*/getContext(),
3684  /*totalNumThreads=*/getTotalNumThreads(),
3685  /*alignment=*/desiredBitAlignment,
3686  /*sizes=*/resultShapedType.getShape(),
3687  /*favorPredication=*/false,
3688  /*elementalBitwidth=*/
3689  resultShapedType.getElementType().getIntOrFloatBitWidth());
3690  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3691     DiagnosedSilenceableFailure diag =
3692         emitSilenceableError()
3693  << "too few threads to map copy op to threads on the most minor "
3694  "dimension, given alignment and vector size constraints, try "
3695  "smaller tile size of mapping to more threads";
3696  diag.attachNote(target->getLoc()) << "target op";
3697  return diag;
3698  }
3699 
3700  // OpBuilder only used to compute attributes.
3701  OpBuilder b(getContext());
3702  linalg::ForallTilingResult tilingResult;
3703   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3704       /*rewriter=*/rewriter,
3705  /*state=*/state,
3706  /*transformOp=*/*this,
3707  /*target=*/target,
3708  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3709  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3710  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3711  /*tilingResult=*/tilingResult);
3712  if (!diag.succeeded())
3713  return diag;
3714 
3715  results.push_back(tilingResult.tileOp);
3716  results.push_back(tilingResult.tiledOp);
3717   return DiagnosedSilenceableFailure::success();
3718 }
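
// Minimal usage sketch (assumed assembly; attribute names follow the
// accessors above, and the exact syntax may differ across versions):
//
//   %forall, %tiled = transform.structured.gpu.map_copy_to_threads %copy
//       total_num_threads = 32 desired_bit_alignment = 128
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)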
3719 
3720 //===----------------------------------------------------------------------===//
3721 // WinogradConv2DOp
3722 //===----------------------------------------------------------------------===//
3723 
3724 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
3725  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3726     transform::ApplyToEachResultList &results,
3727     transform::TransformState &state) {
3728  rewriter.setInsertionPoint(target);
3729  FailureOr<Operation *> maybeTransformed = failure();
3730  bool supported = TypeSwitch<Operation *, bool>(target)
3731  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3732  maybeTransformed =
3733  winogradConv2D(rewriter, op, getM(), getR());
3734  return true;
3735  })
3736  .Default([&](Operation *op) { return false; });
3737 
3738   if (!supported) {
3739     return emitSilenceableError()
3740            << "this operation is not supported for conversion to Winograd Conv2D";
3741   }
3742 
3743   if (failed(maybeTransformed)) {
3744     return emitSilenceableError() << "failed to apply Winograd Conv2D";
3745   }
3746 
3747  results.push_back(*maybeTransformed);
3748   return DiagnosedSilenceableFailure::success();
3749 }
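
// Worked example for the F(m, r) parameters above: with m = 4 and r = 3,
// each 4x4 output tile is computed from a (m + r - 1) x (m + r - 1) = 6x6
// input tile, trading the inner convolution for cheaper elementwise
// multiplies in the Winograd domain. A usage sketch (assumed assembly):
//
//   %transformed = transform.structured.winograd_conv2d %op { m = 4, r = 3 }
//       : (!transform.any_op) -> !transform.any_op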
3750 
3751 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3752 
3753 #define GET_OP_CLASSES
3754 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
#define DOWNSCALE(trans)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
#define DBGS()
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static FailureOr< ForallTilingResult > tileToForallOpImpl(RewriterBase &b, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayRef< OpFoldResult >> nominalTileSizes, std::optional< ArrayAttr > mapping, bool omitTileOffsetBoundsCheck)
Rewrite a TilingInterface op to a tiled scf.forall.
Definition: Tiling.cpp:474
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:128
UnitAttr getUnitAttr()
Definition: Builders.cpp:118
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:187
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:379
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:132
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:94
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:273
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:277
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:292
IndexType getIndexType()
Definition: Builders.cpp:75
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:317
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:136
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:156
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
This class represents a saved insertion point.
Definition: Builders.h:330
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:340
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:559
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:319
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:323
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:415
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:57
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1239
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1142
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
Definition: Padding.cpp:153
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
FailureOr< ForallTilingResult > tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, std::optional< ArrayAttr > mapping)
Same as tileToForallOp, but calculate the number of threads required using the given tileSizes.
Definition: Tiling.cpp:598
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite pack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:357
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:470
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:262
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:1046
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:511
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:495
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< ForallTilingResult > tileToForallOp(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > numThreads, std::optional< ArrayAttr > mapping)
Definition: Tiling.cpp:589
void hoistRedundantVectorTransfers(Operation *root)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:202
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:486
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< GenericOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:399
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy to between src and dst.
Definition: Promotion.cpp:503
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:50
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:242
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:162
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:111
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:97
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition: Tiling.cpp:778
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:769
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:480
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:219
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:421
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:442
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
Definition: Promotion.cpp:479
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:678
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:268
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
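A minimal sketch combining tileUsingSCF with the SCFTilingOptions setters documented near the end of this index; `tilingOp` is an illustrative TilingInterface handle:

  // Tile the two outer loops by 16 and 32 and swap their order.
  MLIRContext *ctx = rewriter.getContext();
  scf::SCFTilingOptions options;
  options.setTileSizes(
      {getAsIndexOpFoldResult(ctx, 16), getAsIndexOpFoldResult(ctx, 32)});
  options.setInterchange({1, 0});
  FailureOr<scf::SCFTilingResult> tiled =
      scf::tileUsingSCF(rewriter, tilingOp, options);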
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:109
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, linalg::ForallTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant vector broadcasts.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
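A sketch of using these populate functions together with the greedy driver documented further below (applyPatternsAndFoldGreedily); `funcOp` is an illustrative func::FuncOp, and the mlir::vector namespace for the populate functions is an assumption:

  // Collect the three vector pattern sets and apply them to a fixed point.
  RewritePatternSet patterns(funcOp.getContext());
  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
  vector::populateSinkVectorBroadcastPatterns(patterns);
  vector::populateVectorReductionToContractPatterns(patterns);
  if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
                                          std::move(patterns))))
    return failure();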
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
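A two-line sketch of the OpFoldResult helpers above; `ctx` is an illustrative MLIRContext pointer:

  // Wrap a constant as an index-attribute OpFoldResult, then read it back.
  OpFoldResult ofr = getAsIndexOpFoldResult(ctx, 42);
  std::optional<int64_t> cst = getConstantIntValue(ofr); // cst == 42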
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR attribute using the given MLIR context, returning it if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:288
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:463
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:464
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:473
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1440
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1372
Rewrite a TilingInterface op to a tiled scf.forall, applying tiling by numThreads.
Definition: Transforms.h:877
Match and rewrite for the pattern:
Definition: Transforms.h:1513
Match and rewrite for the pattern:
Definition: Transforms.h:1541
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:379
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:385
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:398
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:418
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:368
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:392
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:408
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:357
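A minimal sketch chaining the setters above and feeding the result to promoteSubViews (documented earlier in this index); `b` and `linalgOp` are illustrative, and the op is assumed to already operate on subviews:

  linalg::LinalgPromotionOptions options;
  options.setOperandsToPromote({0, 1}) // promote only the two inputs
      .setAlignment(16)                // 16-byte aligned promotion buffers
      .setUseFullTileBuffersByDefault(true)
      .setUseAlloca(false);            // use alloc/dealloc, not alloca
  FailureOr<linalg::LinalgOp> promoted =
      linalg::promoteSubViews(b, linalgOp, options);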
Split Reduction options.
Definition: Transforms.h:427
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > ts)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.