1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/TypeUtilities.h"
40 #include "mlir/Support/LLVM.h"
41 #include "mlir/Support/TypeID.h"
43 #include "llvm/ADT/STLExtras.h"
44 #include "llvm/ADT/ScopeExit.h"
45 #include "llvm/ADT/TypeSwitch.h"
46 #include "llvm/Support/Debug.h"
47 #include <type_traits>
48 
49 using namespace mlir;
50 using namespace mlir::linalg;
51 using namespace mlir::transform;
52 
53 #define DEBUG_TYPE "linalg-transforms"
54 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
55 #define DBGSNL() (llvm::dbgs() << "\n")
56 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
57 
58 /// Attempts to apply the pattern specified as template argument to the given
59 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
60 /// function that returns the "main" result or failure. Returns failure if the
61 /// pattern failed to apply. Extra arguments are forwarded to the pattern
62 /// constructor.
63 template <typename PatternTy, typename... Args>
64 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
65  // Check if the given operation has the type expected by the pattern.
66  using OpTy = typename llvm::function_traits<
67  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
68  auto op = dyn_cast<OpTy>(operation);
69  if (!op)
70  return failure();
71 
72  // Apply the pattern directly to the op.
73  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
74  // We want to discourage direct use of PatternRewriter in APIs, but in this
75  // very specific case, an IRRewriter is not enough.
76  struct TrivialPatternRewriter : public PatternRewriter {
77  public:
78  explicit TrivialPatternRewriter(MLIRContext *context)
79  : PatternRewriter(context) {}
80  };
81  TrivialPatternRewriter rewriter(operation->getContext());
82  rewriter.setInsertionPoint(operation);
83  auto result = pattern.returningMatchAndRewrite(op, rewriter);
84  if (failed(result))
85  return failure();
86  return cast<LinalgOp>(result->getOperation());
87 }
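// Illustrative usage of `tryApply` (a sketch; DownscaleSizeOneWindowed2DConvolution
// is one of the patterns that DecomposeOp instantiates through it further below):
//
//   FailureOr<LinalgOp> res =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(op);
//   if (succeeded(res))
//     results.push_back(*res);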
88 
89 /// Assuming that `ofr` is an index attr or a param of index type
90 /// or a transform dialect handle mapped to exactly one op
91 /// with one index result, return that value.
92 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
93  transform::TransformState &state, TransformOpInterface transformOp,
94  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
95  for (OpFoldResult ofr : ofrs) {
96  if (ofr.is<Attribute>()) {
97  if (!isa<IntegerAttr>(ofr.get<Attribute>()))
98  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
99  result.push_back(ofr);
100  continue;
101  }
102 
103  Value transformValue = ofr.get<Value>();
104  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
105  ArrayRef<Attribute> params = state.getParams(transformValue);
106  if (params.size() != 1)
107  return transformOp.emitDefiniteFailure()
108  << "requires exactly one parameter associated";
109  result.push_back(params[0]);
110  continue;
111  }
112 
113  auto payloadOps = state.getPayloadOps(transformValue);
114  if (!llvm::hasSingleElement(payloadOps)) {
115  DiagnosedSilenceableFailure diag =
116  transformOp.emitSilenceableError()
117  << "handle must be mapped to exactly one payload op";
118  diag.attachNote(transformValue.getLoc())
119  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
120  return diag;
121  }
122 
123  Operation *op = *payloadOps.begin();
124  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
125  DiagnosedSilenceableFailure diag =
126  transformOp.emitSilenceableError()
127  << "payload op must have exactly 1 index result";
128  diag.attachNote(op->getLoc())
129  << "has " << op->getNumResults() << " results";
130  return diag;
131  }
132  result.push_back(op->getResult(0));
133  }
134 
135  return DiagnosedSilenceableFailure::success();
136 }
137 
138 // Given a list of params that are index attrs or a list of OpFoldResults
139 // that are either index attrs or op handles, return a list of OpFoldResults
140 // of index attrs or a list of OpFoldResults where all op handles are
141 // replaced with the first (and only) OpResult of that payload op.
142 // (There must be exactly one parameter associated with the AnyParamType or
143 // one mapped payload op which must have exactly one index result.)
144 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
145  transform::TransformState &state, TransformOpInterface transformOp,
146  SmallVector<OpFoldResult> &result, Value packedHandle) {
147  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
148  ArrayRef<Attribute> params = state.getParams(packedHandle);
149  for (auto param : params) {
150  if (!isa<IntegerAttr>(param))
151  return transformOp.emitDefiniteFailure()
152  << "expected the parameter to be associated with an integer "
153  "attribute";
154  result.push_back(param);
155  }
156  return DiagnosedSilenceableFailure::success();
157  }
158 
159  for (Operation *op : state.getPayloadOps(packedHandle)) {
160  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
161  DiagnosedSilenceableFailure diag =
162  transformOp.emitSilenceableError()
163  << "payload op must have exactly 1 index result";
164  diag.attachNote(op->getLoc())
165  << "has " << op->getNumResults() << " results";
166  return diag;
167  }
168  result.push_back(op->getResult(0));
169  }
170 
171  return DiagnosedSilenceableFailure::success();
172 }
173 
174 /// When possible, converts each `OpFoldResult` in `mixedResult` to
175 /// an integer if the value can be statically inferred. If a result
176 /// is a `Value` then it must be either a `ParamType` or a handle
177 /// to a constant-like op.
178 static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
179  TransformState &state, TransformOpInterface &transformOp,
180  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
181  for (OpFoldResult paramOrHandle : mixedResults) {
182  if (isa<Attribute>(paramOrHandle)) {
183  reified.push_back(
184  cast<IntegerAttr>(paramOrHandle.get<Attribute>()).getInt());
185  continue;
186  } else if (isa<ParamType>(paramOrHandle.get<Value>().getType())) {
187  ArrayRef<Attribute> params = state.getParams(paramOrHandle.get<Value>());
188  if (params.size() != 1)
189  return transformOp.emitSilenceableError() << "expected a single param";
190  reified.push_back(
191  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
192  continue;
193  }
194 
195  Value handle = paramOrHandle.get<Value>();
196  if (!isa<TransformHandleTypeInterface>(handle.getType()))
197  return transformOp.emitSilenceableError() << "unexpected value handle";
198  auto payload = state.getPayloadOps(handle);
199  if (!llvm::hasSingleElement(payload))
200  return transformOp.emitSilenceableError()
201  << "requires param or handle that is mapped to 1 payload op";
202 
203  Operation *paramOrHandlePayloadOp = *payload.begin();
204  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
205  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
206  return transformOp.emitSilenceableError()
207  << "requires param or handle to be result of op with 1 index "
208  "result";
209  }
210 
211  IntegerAttr attr;
212  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
213  return transformOp.emitSilenceableError()
214  << "requires param or handle to be the result of a constant like "
215  "op";
216 
217  reified.push_back(attr.getInt());
218  }
219  return DiagnosedSilenceableFailure::success();
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // Apply...PatternsOp
224 //===----------------------------------------------------------------------===//
225 
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
227  RewritePatternSet &patterns) {
228  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
229 }
230 
231 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
232  RewritePatternSet &patterns) {
233  linalg::ControlDropUnitDims options;
234  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
235 }
236 
237 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
238  RewritePatternSet &patterns) {
239  linalg::ControlDropUnitDims options;
240  options.rankReductionStrategy =
241  linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
242  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
243 }
244 
245 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
246  RewritePatternSet &patterns) {
247  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
248 }
249 
250 //===----------------------------------------------------------------------===//
251 // BufferizeToAllocationOp
252 //===----------------------------------------------------------------------===//
253 
254 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
255  OperationState &result,
256  Value target,
257  Attribute memorySpace) {
258  SmallVector<Type> resultTypes;
259  resultTypes.push_back(b.getType<transform::AnyValueType>());
260  resultTypes.push_back(b.getType<transform::AnyOpType>());
261  return build(b, result,
262  /*resultTypes=*/resultTypes,
263  /*target=*/target,
264  /*memorySpace=*/memorySpace);
265 }
266 
267 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
268  OperationState &result,
269  Value target,
270  int64_t memorySpace) {
271  SmallVector<Type> resultTypes;
272  resultTypes.push_back(b.getType<transform::AnyValueType>());
273  resultTypes.push_back(b.getType<transform::AnyOpType>());
274  return build(b, result,
275  /*resultTypes=*/resultTypes,
276  /*target=*/target,
277  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
278 }
279 
280 namespace {
281 class NewOpsListener : public RewriterBase::ForwardingListener {
282 public:
283  using RewriterBase::ForwardingListener::ForwardingListener;
284 
285  SmallVector<Operation *> getNewOps() const {
286  return SmallVector<Operation *>(newOps.begin(), newOps.end());
287  }
288 
289 private:
290  void notifyOperationInserted(Operation *op,
291  OpBuilder::InsertPoint previous) override {
292  ForwardingListener::notifyOperationInserted(op, previous);
293  // We only care about newly created ops.
294  if (previous.isSet())
295  return;
296  auto inserted = newOps.insert(op);
297  (void)inserted;
298  assert(inserted.second && "expected newly created op");
299  }
300 
301  void notifyOperationErased(Operation *op) override {
302  ForwardingListener::notifyOperationErased(op);
303  op->walk([&](Operation *op) { newOps.erase(op); });
304  }
305 
306  DenseSet<Operation *> newOps;
307 };
308 } // namespace
309 
310 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
311  transform::TransformRewriter &rewriter,
312  transform::TransformResults &results, transform::TransformState &state) {
313  // Attach listener to keep track of newly created ops.
314  OpBuilder::Listener *previousListener = rewriter.getListener();
315  auto resetListener =
316  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
317  NewOpsListener newOpsListener(previousListener);
318  rewriter.setListener(&newOpsListener);
319 
320  linalg::BufferizeToAllocationOptions options;
321  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
322  options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
323  MaterializeInDestination;
324  } else if (getMemcpyOp() == "memref.copy") {
325  options.memcpyOp =
326  linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
327  } else if (getMemcpyOp() == "linalg.copy") {
328  options.memcpyOp =
329  linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
330  } else {
331  llvm_unreachable("invalid memcpy op");
332  }
333  if (getAllocOp() == "memref.alloc") {
334  options.allocOp =
335  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
336  } else if (getAllocOp() == "memref.alloca") {
337  options.allocOp =
338  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
339  } else {
340  llvm_unreachable("invalid alloc op");
341  }
342  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
343  options.emitDealloc = getEmitDealloc();
344 
345  // Bufferize ops.
346  Attribute memorySpace =
347  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
348  SmallVector<Value> allocatedBuffers;
349  for (Operation *op : state.getPayloadOps(getTarget())) {
350  Value buffer =
351  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
352  if (!buffer) {
353  DiagnosedSilenceableFailure diag = emitSilenceableError()
354  << "failed to bufferize operation";
355  diag.attachNote(op->getLoc()) << "target payload op";
356  return diag;
357  }
358  allocatedBuffers.push_back(buffer);
359  }
360 
361  // Set results.
362  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
363  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
364  return DiagnosedSilenceableFailure::success();
365 }
366 
367 void transform::BufferizeToAllocationOp::getEffects(
368  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
369  if (getBufferizeDestinationOnly()) {
370  // The destination is replaced with a newly allocated buffer, but the op
371  // itself remains in place.
372  onlyReadsHandle(getTarget(), effects);
373  } else {
374  consumesHandle(getTarget(), effects);
375  }
376  producesHandle(getAllocatedBuffer(), effects);
377  producesHandle(getNewOps(), effects);
378  modifiesPayload(effects);
379 }
380 
381 LogicalResult transform::BufferizeToAllocationOp::verify() {
382  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
383  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
384  return emitOpError() << "unsupported memcpy op";
385  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
386  return emitOpError() << "unsupported alloc op";
387  return success();
388 }
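// For orientation, a transform script would typically drive this op roughly as
// follows (a sketch; the authoritative assembly format and attribute names are
// defined in the op's TableGen definition):
//
//   %buffer, %new_ops = transform.structured.bufferize_to_allocation %target
//       {memory_space = 4, memcpy_op = "linalg.copy", alloc_op = "memref.alloc"}
//       : !transform.any_op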
389 
390 //===----------------------------------------------------------------------===//
391 // DecomposeOp
392 //===----------------------------------------------------------------------===//
393 
394 DiagnosedSilenceableFailure
395 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
396  LinalgOp target,
397  transform::ApplyToEachResultList &results,
398  transform::TransformState &state) {
399 #define DOWNSCALE(trans) \
400  { \
401  FailureOr<LinalgOp> res = tryApply<trans>(target); \
402  if (succeeded(res)) { \
403  results.push_back(*res); \
404  return DiagnosedSilenceableFailure::success(); \
405  } \
406  }
407 
408 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
409 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
410 
411  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
412  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
413  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
414  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
415  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
416  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
417  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
418  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
419  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
420  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
421  DOWNSCALE(DownscaleConv2DOp)
422 #undef DOWNSCALE_NORMAL
423 #undef DOWNSCALE_CALL
424 #undef DOWNSCALE
425  return emitDefaultSilenceableFailure(target);
426 }
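// Rough usage sketch for DecomposeOp (see the TableGen definition for the
// exact syntax): it rewrites, e.g., a 2-D convolution with a unit-size window
// dimension into the corresponding 1-D op:
//
//   %decomposed = transform.structured.decompose %conv
//       : (!transform.any_op) -> !transform.any_op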
427 
428 //===----------------------------------------------------------------------===//
429 // DecomposeInterfaceOp
430 //===----------------------------------------------------------------------===//
431 
432 // Decompose the target operation if it implements the AggregatedOpInterface.
433 // Push the decomposed operations (the ones that replace the values produced
434 // by \p target) into `results`.
435 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
436  transform::TransformRewriter &rewriter, Operation *target,
437  transform::ApplyToEachResultList &results,
438  transform::TransformState &state) {
439  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
440  if (!decomposableOp) {
441  failed(rewriter.notifyMatchFailure(target,
442  "payload is not a decomposable op"));
443  return emitDefaultSilenceableFailure(target);
444  }
445 
446  FailureOr<SmallVector<Value>> maybeNewResults =
447  decomposableOp.decomposeOperation(rewriter);
448  if (failed(maybeNewResults))
449  return emitDefaultSilenceableFailure(target);
450 
451  rewriter.replaceOp(decomposableOp, *maybeNewResults);
452  for (Value val : *maybeNewResults) {
453  Operation *definition = val.getDefiningOp();
454  if (definition)
455  results.push_back(definition);
456  }
457  return DiagnosedSilenceableFailure::success();
458 }
459 
460 //===----------------------------------------------------------------------===//
461 // EliminateLinalgOpAnchoredEmptyTensorsOp
462 //===----------------------------------------------------------------------===//
463 
464 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
465  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
466  onlyReadsHandle(getTarget(), effects);
467  modifiesPayload(effects);
468 }
469 
470 DiagnosedSilenceableFailure
471 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
472  transform::TransformRewriter &rewriter, TransformResults &transformResults,
473  TransformState &state) {
474  bufferization::OneShotBufferizationOptions options;
475  options.allowReturnAllocsFromLoops = true;
476 
477  for (Operation *target : state.getPayloadOps(getTarget())) {
478  bufferization::OneShotAnalysisState state(target, options);
479  if (failed(analyzeOp(target, state)))
480  return mlir::emitSilenceableFailure(target->getLoc())
481  << "failed to analyze op";
482  if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
483  rewriter, target, state)))
484  return mlir::emitSilenceableFailure(target->getLoc())
485  << "failed to eliminate LinalgOp anchored tensor.empty ops";
486  }
487  return DiagnosedSilenceableFailure::success();
488 }
489 
490 //===----------------------------------------------------------------------===//
491 // FuseOp
492 //===----------------------------------------------------------------------===//
493 
494 /// Apply a tiling transformation to all payload ops and store both the
495 /// tiled operation as well as the created tile loops.
496 template <typename Range>
497 static LogicalResult applyTilingToAll(
498  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
499  unsigned numLoops, transform::TransformResults &transformResults,
500  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
501  applyFn) {
502  SmallVector<Operation *> tiledLinalgOps;
503  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
504 
505  for (Operation *target : payloadOps) {
506  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
507  if (!tilingInterfaceOp)
508  return transformOp->emitError("only TilingInterface ops are supported");
509 
510  rewriter.setInsertionPoint(target);
511  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
512  applyFn(tilingInterfaceOp);
513  if (failed(tiledResults))
514  return failure();
515 
516  // Perform the replacement of tiled and fused values.
517  SmallVector<Operation *> opsToReplace{target};
518  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
519  for (Operation *toReplace : opsToReplace) {
520  for (OpResult res : toReplace->getResults())
521  if (auto replacement = tiledResults->replacements.lookup(res))
522  rewriter.replaceAllUsesWith(res, replacement);
523  if (toReplace->use_empty()) {
524  rewriter.eraseOp(toReplace);
525  }
526  }
527 
528  // Report back the relevant handles to the transform op.
529  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
530  assert(tiledResults->loops.size() == numLoops &&
531  "Mismatched number of loops, tile and fuse transform should have "
532  "failed");
533  for (unsigned int i = 0; i < numLoops; ++i)
534  loopOps[i].push_back(tiledResults->loops[i]);
535  }
536 
537  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
538  for (unsigned int i = 0; i < numLoops; ++i)
539  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
540 
541  return success();
542 }
543 
544 DiagnosedSilenceableFailure
545 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
546  mlir::transform::TransformResults &transformResults,
547  mlir::transform::TransformState &state) {
548  SmallVector<int64_t> tileSizes =
549  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
550  SmallVector<int64_t> tileInterchange =
551  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
552 
553  scf::SCFTilingOptions tilingOptions;
554  tilingOptions.interchangeVector = tileInterchange;
555  SmallVector<OpFoldResult> tileSizesOfr =
556  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
557  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
558  scf::SCFTileAndFuseOptions tileAndFuseOptions;
559  tileAndFuseOptions.tilingOptions = tilingOptions;
560  LogicalResult result = applyTilingToAll(
561  rewriter, getOperation(), state.getPayloadOps(getTarget()),
562  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
563  [&](TilingInterface tilingInterfaceOp)
564  -> FailureOr<scf::SCFTileAndFuseResult> {
565  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
566  tileAndFuseOptions);
567  });
568  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
569  : DiagnosedSilenceableFailure::success();
570 }
571 
572 LogicalResult transform::FuseOp::verify() {
573  SmallVector<int64_t> permutation =
574  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
575  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
576  if (!std::is_permutation(sequence.begin(), sequence.end(),
577  permutation.begin(), permutation.end())) {
578  return emitOpError() << "expects interchange to be a permutation, found "
579  << getTileInterchange();
580  }
581 
582  SmallVector<int64_t> sizes =
583  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
584  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
585  if (numExpectedLoops != getNumResults() - 1)
586  return emitOpError() << "expects " << numExpectedLoops << " loop results";
587 
588  return success();
589 }
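// Rough usage sketch for FuseOp (attribute spellings follow the accessors used
// above and may differ from the custom assembly format); one loop handle is
// produced per non-zero tile size:
//
//   %tiled, %loop = transform.structured.fuse %matmul
//       {tile_sizes = [32, 0, 0], tile_interchange = [0, 1, 2]}
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)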
590 
591 //===----------------------------------------------------------------------===//
592 // FuseIntoContainingOp
593 //===----------------------------------------------------------------------===//
594 
595 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
596  OperationState &result,
597  Value producerOp,
598  Value containingOp) {
599  result.addOperands({producerOp, containingOp});
600  auto resultType = transform::AnyOpType::get(builder.getContext());
601  result.addTypes({resultType, resultType});
602 }
603 
604 /// Add new operands to the forall op for users of the producerOp
605 /// that are dominated by the containing scf.forall op.
606 static Operation *replaceForAllWithNewSignature(
607  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
608  Operation *containingOp, TilingResult &tileAndFuseResult,
609  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
610  SmallVector<OpFoldResult> &sizes) {
611 
612  // Count number of users not including the containing op
613  SetVector<Operation *> dominatedUsers;
614  DominanceInfo domInfo(containingOp);
615  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
616  if (!containingOp->isAncestor(user) &&
617  (domInfo.dominates(containingOp, user))) {
618  dominatedUsers.insert(user);
619  }
620  }
621  if (dominatedUsers.empty())
622  return nullptr;
623 
624  // Create new scf.forall op
625  auto forallOp = cast<scf::ForallOp>(containingOp);
626  OpBuilder::InsertionGuard g(rewriter);
627  rewriter.setInsertionPoint(forallOp);
628 
629  // Get new output
630  Location loc = forallOp.getLoc();
631  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
632  if (!genericOp)
633  return nullptr;
634  SmallVector<Value> outputs = genericOp.getOutputs();
635  SmallVector<Value> newOuts(forallOp.getOutputs());
636  newOuts.push_back(outputs[resultNumber]);
637 
638  // Create new scf.forall op
639  auto newforallOp = rewriter.create<scf::ForallOp>(
640  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
641  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
642  rewriter.eraseBlock(newforallOp.getBody());
643  newforallOp.getRegion().takeBody(forallOp.getRegion());
644 
645  // Add additional block argument for new value being returned
646  // and replaces all uses of the new output with corresponding bbArg
647  // inside the scf.forall to enable fusion into this new scf.forall.
648  newforallOp.getBody()->addArgument(newOuts.back().getType(),
649  newOuts.back().getLoc());
650  auto bbArgs = newforallOp.getBody()->getArguments();
651  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
652  [&](OpOperand &use) {
653  Operation *op = use.getOwner();
654  return newforallOp->isProperAncestor(op);
655  });
656 
657  // Fix terminator
658  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
659  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
660  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
661  Operation *firstYieldOp = yieldingOps.front();
662  rewriter.setInsertionPoint(firstYieldOp);
663  Value src = tileAndFuseResult.tiledValues[0];
664  Value dst = newforallOp.getRegionIterArgs().back();
665  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
666  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
667  dst, offsets, sizes, strides);
668 
669  for (auto result : llvm::enumerate(forallOp.getResults())) {
670  rewriter.replaceAllUsesWith(result.value(),
671  newforallOp->getResult(result.index()));
672  }
673  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
674  newforallOp->getResults().back(),
675  [&](OpOperand &use) {
676  Operation *user = use.getOwner();
677  return dominatedUsers.contains(user);
678  });
679  return newforallOp;
680 }
681 
682 /// Find the first "extract" user of `producerOp` and tile it right before its
683 /// use. The tiled op is fused under the `containingOp`.
684 /// Return this fused op on success or nullptr if anything fails.
685 /// If tiled op has uses that are dominated by `containingOp`, return
686 /// a new `containingOp` with results of the fused op appended to
687 /// results of the `containingOp` or nullptr if there are no dominated uses.
688 static std::tuple<SmallVector<Operation *>, Operation *>
689 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
690  Operation *producerOp, Operation *containingOp) {
691  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
692  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
693  if (!tileableProducer) {
694  diag.attachNote(producerOp->getLoc())
695  << "producer is not a TileableInterface: " << *producerOp;
696  return {};
697  }
698 
699  // Search the producer slices accessed within the containing operation.
700  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
701  // evolve into an interface.
702  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
703  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
704  return sliceOp && containingOp->isProperAncestor(sliceOp);
705  });
706 
707  // Find a fusion opportunity.
708  if (it == tileableProducer->getUsers().end()) {
709  diag.attachNote(tileableProducer->getLoc())
710  << "could not find fusion opportunity for: " << *tileableProducer;
711  return {};
712  }
713  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
714 
715  // Try to fuse the producer in-place.
716  OpBuilder::InsertionGuard guard(rewriter);
717  rewriter.setInsertionPoint(sliceOpToTile);
718 
719  // Tile the producer.
720  int64_t resultNumber =
721  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
722  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
723 
724  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
725  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
726 
727  FailureOr<TilingResult> tileAndFuseResult =
728  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
729  sizes);
730 
731  if (failed(tileAndFuseResult)) {
732  diag.attachNote(tileableProducer->getLoc())
733  << "failed to tile producer op: " << *tileableProducer;
734  return {};
735  }
736 
737 #ifndef NDEBUG
738  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
739  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
740  }
741 #endif
742 
743  // Replace the extract op.
744  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
745  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
746  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
747  if (failed(maybeRankReduced)) {
748  diag.attachNote(producerOp->getLoc())
749  << "shape types don't match (missing canonicalization?):\nTiledOp: "
750  << tileAndFuseResult->tiledValues[0]
751  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
752  return {};
753  }
754  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
755 
756  // Add new outputs to containing op, if required
757  Operation *newContainingOp = replaceForAllWithNewSignature(
758  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
759  resultNumber, offsets, sizes);
760 
761  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
762 }
763 
764 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
765 /// it is exactly the `containingOp`, otherwise bail.
766 /// Then, find the first "extract" user of the tied block argument and tile it
767 /// right before its "extract" use. The tiled op is fused under the
768 /// `containingOp`.
769 /// Return this fused op on success or nullptr if anything fails.
770 static SmallVector<Operation *>
771 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
772  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
773  Operation *containingOp) {
774  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
775 
776  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
777  if (!tileableProducer) {
778  diag.attachNote(producerOp->getLoc())
779  << "producer is not a TileableInterface: " << *producerOp;
780  return {};
781  }
782 
783  // Search the first use by a "scf::ForallOp" user.
784  scf::ForallOp forallOp;
785  auto itProducerUses =
786  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
787  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
788  return forallOp;
789  });
790  // If it's not from the containing op, return.
791  if (!forallOp || forallOp != containingOp) {
792  diag.attachNote(tileableProducer->getLoc())
793  << "could not find a use by the containing op: " << *tileableProducer;
794  return {};
795  }
796 
797  // Search the producer slices accessed within the containing
798  // operation.
799  // TODO: Generalize to more extract/insert/parallel_insert triples.
800  // Maybe evolve into an interface.
801  OpOperand *pUse = &(*itProducerUses);
802  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
803 
804  // Search the producer slices accessed within the containing operation.
805  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
806  // evolve into an interface.
807  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
808  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
809  return sliceOp && containingOp->isProperAncestor(sliceOp);
810  });
811 
812  // Find a fusion opportunity.
813  if (itBBArgUsers == bbArg.getUsers().end()) {
814  diag.attachNote(containingOp->getLoc())
815  << "could not find fusion opportunity for bbArg: " << bbArg;
816  return {};
817  }
818  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
819 
820  // Try to fuse the producer in-place.
821  OpBuilder::InsertionGuard guard(rewriter);
822  rewriter.setInsertionPoint(sliceOpToTile);
823 
824  // Replace the use in the tileableProducer before tiling: clone, replace and
825  // then tile.
826  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
827  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
828 
829  // Gather destination tensors.
830  SmallVector<Value> destinationTensors;
831  if (failed(tensor::getOrCreateDestinations(
832  rewriter, tileableProducer->getLoc(), tileableProducer,
833  destinationTensors))) {
834  diag.attachNote(tileableProducer->getLoc())
835  << "failed to get destination tensors for: " << *tileableProducer;
836  return {};
837  }
838 
839  IRMapping bvm;
840  bvm.map(destinationTensors[resultNumber], bbArg);
841  auto tileableProducerClone =
842  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
843  auto scopeGuard =
844  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
845 
846  // Tile the producer.
847  FailureOr<TilingResult> tileAndFuseResult =
848  tileableProducerClone.generateResultTileValue(
849  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
850  sliceOpToTile.getMixedSizes());
851  if (failed(tileAndFuseResult)) {
852  diag.attachNote(tileableProducer->getLoc())
853  << "failed to tile producer op: " << *tileableProducer;
854  return {};
855  }
856 
857  // Replace the extract op.
858  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
859  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
860  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
861  assert(succeeded(maybeRankReduced) && "unexpected shape");
862  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
863 
864  // Replace the use in containingOp.
865  rewriter.modifyOpInPlace(containingOp, [&]() {
866  containingOp->setOperand(pUse->getOperandNumber(),
867  destinationTensors.front());
868  });
869 
870  return tileAndFuseResult->tiledOps;
871 }
872 
873 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
874  Operation *producerOp,
875  Operation *containingOp) {
876  LLVM_DEBUG(DBGS() << "Try to fuse a use by cloning\n");
877 
878  // Gather all uses inside the containing op.
879  SmallVector<OpOperand *> uses;
880  for (OpResult result : producerOp->getOpResults()) {
881  for (OpOperand &use : result.getUses()) {
882  if (containingOp->isProperAncestor(use.getOwner())) {
883  uses.push_back(&use);
884  continue;
885  }
886  // Cannot clone and fuse if the use is by the containing op itself: fail
887  // immediately.
888  if (containingOp == use.getOwner()) {
889  diag.attachNote(producerOp->getLoc())
890  << "producer op use by containing op cannot be fused by cloning";
891  return nullptr;
892  }
893  }
894  }
895 
896  // Check for a non-empty list of fusion opportunities.
897  if (uses.empty()) {
898  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
899  return nullptr;
900  }
901 
902  // Clone and fuse inside the containing op.
903  Operation *fusedOp = nullptr;
904  OpOperand *use = uses.front();
905  // Parallel insert slice is not a valid clone destination.
906  // TODO: Generalize to other type of ops.
907  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
908  "Parallel insert slice is not a valid clone destination");
909  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
910  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
911 
912  OpBuilder::InsertionGuard guard(rewriter);
913  rewriter.setInsertionPoint(use->getOwner());
914  fusedOp = rewriter.clone(*producerOp);
915  rewriter.modifyOpInPlace(
916  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
917 
918  return fusedOp;
919 }
920 
921 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
922  // Allow repeated handles since we are fusing everything anyway.
923  return true;
924 }
925 
926 DiagnosedSilenceableFailure
927 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
928  transform::TransformResults &results,
929  transform::TransformState &state) {
930  SmallVector<Operation *> fusedOps;
931  auto producerOps = state.getPayloadOps(getProducerOp());
932  auto containingOps = state.getPayloadOps(getContainingOp());
933  if (!llvm::hasSingleElement(containingOps)) {
934  return emitDefiniteFailure()
935  << "requires exactly one containing_op handle (got "
936  << llvm::range_size(containingOps) << ")";
937  }
938  Operation *containingOp = *containingOps.begin();
939 
940  // If nothing to fuse, propagate success.
941  if (std::empty(producerOps)) {
942  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
943  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
944  return DiagnosedSilenceableFailure::success();
945  }
946 
947  // Helper function to find the next producer that should be fused. Take any
948  // producer that has a use inside the containing op.
949  SetVector<Operation *> remainingProducers(producerOps.begin(),
950  producerOps.end());
951  auto getNextProducer = [&]() -> FailureOr<Operation *> {
952  for (const auto &it : enumerate(remainingProducers)) {
953  Operation *producerOp = it.value();
954  // The containing op may be a user of producerOp: use isAncestor.
955  int64_t numUsesInContainingOp =
956  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
957  return containingOp->isAncestor(op);
958  });
959  // TODO: When resolving the TODO below (no duplicate ops), take an op
960  // that has no use among the remaining producers. This is a topological
961  // sorting.
962  if (numUsesInContainingOp > 0) {
963  if (numUsesInContainingOp == 1)
964  remainingProducers.erase(remainingProducers.begin() + it.index());
965  return producerOp;
966  }
967  }
968  return failure();
969  };
970 
971  while (!remainingProducers.empty()) {
972  auto nextProducer = getNextProducer();
973  if (failed(nextProducer)) {
974  auto diag = mlir::emitSilenceableFailure(getLoc())
975  << "could not find next producer to fuse into container";
976  diag.attachNote(containingOp->getLoc()) << "containing op";
977  return diag;
978  }
979 
980  Operation *producerOp = *nextProducer;
981 
982  // Default diagnostic, to be complemented with more failure information.
983  Diagnostic diag(getLoc(), DiagnosticSeverity::Remark);
984  diag << "could not fuse " << *producerOp << " into " << *containingOp;
985 
986  // TODO: If there are multiple uses of the producer in the containing op,
987  // we currently tile/clone the op multiple times (once per use). In some
988  // cases, we can tile/clone once and reuse the value for each use.
989  // Furthermore, producers should then be traversed according to a
990  // topological sorting.
991  auto [tiledOps, newContainingOp] =
992  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
993  if (!tiledOps.empty()) {
994  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
995  fusedOps.append(tiledOps);
996  if (newContainingOp) {
997  // Update handles associated with the containing op so we don't need to
998  // invalidate them. This is a hack to support better composability
999  // between tiling and fusion while a proper mechanism is being
1000  // investigated.
1001  //
1002  // DO NOT replicate this elsewhere unless you understand what you are
1003  // doing.
1004  LogicalResult replacementStatus =
1005  rewriter.notifyPayloadOperationReplaced(containingOp,
1006  newContainingOp);
1007  (void)replacementStatus;
1008  assert(succeeded(replacementStatus) &&
1009  "unable to update transform state mapping");
1010  rewriter.eraseOp(containingOp);
1011  containingOp = newContainingOp;
1012  }
1013  continue;
1014  }
1015 
1016  SmallVector<Operation *> tiledContainingOpOperand =
1017  tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1018  rewriter, diag, producerOp, containingOp);
1019  if (!tiledContainingOpOperand.empty()) {
1020  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
1021  << *containingOp);
1022  fusedOps.append(tiledContainingOpOperand);
1023  continue;
1024  }
1025 
1026  Operation *cloned =
1027  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1028  if (cloned) {
1029  LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
1030  fusedOps.push_back(cloned);
1031  continue;
1032  }
1033  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1034  }
1035 
1036  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1037  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1038  return DiagnosedSilenceableFailure::success();
1039 }
1040 
1041 void transform::FuseIntoContainingOp::getEffects(
1042  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1043  consumesHandle(getProducerOp(), effects);
1044  onlyReadsHandle(getContainingOp(), effects);
1045  producesHandle(getResults(), effects);
1046  modifiesPayload(effects);
1047 }
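// Rough usage sketch for FuseIntoContainingOp (types shown as
// !transform.any_op for brevity):
//
//   %fused, %new_containing = transform.structured.fuse_into_containing_op
//       %producer into %forall
//       : (!transform.any_op, !transform.any_op)
//       -> (!transform.any_op, !transform.any_op)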
1048 
1049 //===----------------------------------------------------------------------===//
1050 // GeneralizeOp
1051 //===----------------------------------------------------------------------===//
1052 
1053 DiagnosedSilenceableFailure
1054 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1055  LinalgOp target,
1056  transform::ApplyToEachResultList &results,
1057  transform::TransformState &state) {
1058  // Exit early if no transformation is needed.
1059  if (isa<GenericOp>(target)) {
1060  results.push_back(target);
1061  return DiagnosedSilenceableFailure::success();
1062  }
1063  rewriter.setInsertionPoint(target);
1064  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1065  if (succeeded(generic)) {
1066  results.push_back(generic->getOperation());
1067  return DiagnosedSilenceableFailure::success();
1068  }
1069  return emitDefaultSilenceableFailure(target);
1070 }
1071 
1072 //===----------------------------------------------------------------------===//
1073 // SpecializeOp
1074 //===----------------------------------------------------------------------===//
1075 
1076 DiagnosedSilenceableFailure
1077 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1078  LinalgOp target,
1079  transform::ApplyToEachResultList &results,
1080  transform::TransformState &state) {
1081  // Exit early if the operation is not a generic.
1082  if (!isa<GenericOp>(target)) {
1083  results.push_back(target);
1084  return DiagnosedSilenceableFailure::success();
1085  }
1086  rewriter.setInsertionPoint(target);
1087  FailureOr<LinalgOp> named =
1088  specializeGenericOp(rewriter, cast<GenericOp>(target));
1089  if (succeeded(named)) {
1090  results.push_back(named->getOperation());
1091  return DiagnosedSilenceableFailure::success();
1092  }
1093  return emitDefaultSilenceableFailure(target);
1094 }
1095 
1096 //===----------------------------------------------------------------------===//
1097 // InterchangeOp
1098 //===----------------------------------------------------------------------===//
1099 
1100 DiagnosedSilenceableFailure
1101 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1102  GenericOp target,
1103  transform::ApplyToEachResultList &results,
1104  transform::TransformState &state) {
1105  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1106  // Exit early if no transformation is needed.
1107  if (interchangeVector.empty()) {
1108  results.push_back(target);
1109  return DiagnosedSilenceableFailure::success();
1110  }
1111 
1112  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1113  if (interchangeVector.size() != numLoops) {
1114  return emitSilenceableError()
1115  << getIteratorInterchangeAttrName() << " has length ("
1116  << interchangeVector.size()
1117  << ") different from the number of loops in the target operation ("
1118  << numLoops << ")";
1119  }
1120  FailureOr<GenericOp> res =
1121  interchangeGenericOp(rewriter, target,
1122  SmallVector<unsigned>(interchangeVector.begin(),
1123  interchangeVector.end()));
1124  if (failed(res))
1125  return emitDefiniteFailure() << "failed to apply";
1126  results.push_back(res->getOperation());
1127  return DiagnosedSilenceableFailure::success();
1128 }
1129 
1130 LogicalResult transform::InterchangeOp::verify() {
1131  ArrayRef<int64_t> permutation = getIteratorInterchange();
1132  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1133  if (!std::is_permutation(sequence.begin(), sequence.end(),
1134  permutation.begin(), permutation.end())) {
1135  return emitOpError()
1136  << "expects iterator_interchange to be a permutation, found "
1137  << getIteratorInterchange();
1138  }
1139  return success();
1140 }
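// Rough usage sketch for InterchangeOp (the attribute mirrors the
// getIteratorInterchange accessor; see the TableGen definition for the exact
// assembly format):
//
//   %transposed = transform.structured.interchange %generic
//       iterator_interchange = [1, 0]
//       : (!transform.any_op) -> !transform.any_op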
1141 
1142 //===----------------------------------------------------------------------===//
1143 // LowerPackOp
1144 //===----------------------------------------------------------------------===//
1145 
1146 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1147  transform::TransformRewriter &rewriter, tensor::PackOp target,
1148  transform::ApplyToEachResultList &transformResults,
1149  transform::TransformState &state) {
1150  rewriter.setInsertionPoint(target);
1151  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
1152  if (failed(res)) {
1153  return mlir::emitSilenceableFailure(target->getLoc())
1154  << "cannot lower to pad + expand + transpose";
1155  }
1156  transformResults.push_back(res->padOp);
1157  transformResults.push_back(res->expandShapeOp);
1158  transformResults.push_back(res->transposeOp);
1159  return DiagnosedSilenceableFailure::success();
1160 }
1161 
1162 //===----------------------------------------------------------------------===//
1163 // LowerUnPackOp
1164 //===----------------------------------------------------------------------===//
1165 
1166 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1167  transform::TransformRewriter &rewriter, tensor::UnPackOp target,
1168  transform::ApplyToEachResultList &transformResults,
1169  transform::TransformState &state) {
1170  rewriter.setInsertionPoint(target);
1171  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
1172  if (failed(res)) {
1173  DiagnosedSilenceableFailure diag =
1174  emitSilenceableError()
1175  << "cannot lower to transpose + collapse + extract";
1176  diag.attachNote(target->getLoc()) << "target payload op";
1177  return diag;
1178  }
1179  transformResults.push_back(res->emptyOp);
1180  transformResults.push_back(res->transposeOp);
1181  transformResults.push_back(res->collapseShapeOp);
1182  transformResults.push_back(res->extractSliceOp);
1183  return DiagnosedSilenceableFailure::success();
1184 }
1185 
1186 //===---------------------------------------------------------------------===//
1187 // MatchOp
1188 //===---------------------------------------------------------------------===//
1189 
1190 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1191  Value target, ArrayRef<StringRef> opNames) {
1192  result.addOperands(target);
1193  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1194  builder.getStrArrayAttr(opNames));
1195  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1196 }
1197 
1198 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1199  TypeRange resultTypes, Value target,
1200  ArrayRef<StringRef> opNames) {
1201  result.addOperands(target);
1202  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1203  builder.getStrArrayAttr(opNames));
1204  result.addTypes(resultTypes);
1205 }
1206 
1207 DiagnosedSilenceableFailure
1208 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1209  transform::TransformResults &results,
1210  transform::TransformState &state) {
1211  llvm::StringSet<> strs;
1212  if (getOps().has_value())
1213  strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1214  getOps()->getAsValueRange<StringAttr>().end());
1215 
1216  auto payloadOps = state.getPayloadOps(getTarget());
1217  if (!llvm::hasSingleElement(payloadOps)) {
1218  return emitDefiniteFailure("requires exactly one target handle");
1219  }
1220 
1221  SmallVector<Operation *> res;
1222  bool incorrectNumOperandTypes = false;
1223  auto matchFun = [&](Operation *op) {
1224  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1225  return;
1226 
1227  // Interfaces cannot be matched by name, just by ID.
1228  // So we specifically encode the interfaces we care about for this op.
1229  if (getInterface().has_value()) {
1230  auto iface = getInterface().value();
1231  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1232  !isa<LinalgOp>(op))
1233  return;
1234  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1235  !isa<TilingInterface>(op))
1236  return;
1237  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1238  !isa<LoopLikeOpInterface>(op))
1239  return;
1240  }
1241 
1242  // Check if all specified attributes match.
1243  if (getOpAttrs().has_value()) {
1244  DictionaryAttr opAttrs = getOpAttrs().value();
1245  for (NamedAttribute attr : opAttrs) {
1246  if (attr.getName() == getInterfaceAttrName() ||
1247  attr.getName() == getOpsAttrName())
1248  continue;
1249  if (!op->hasAttr(attr.getName()))
1250  return;
1251  if (op->getAttr(attr.getName()) != attr.getValue())
1252  return;
1253  }
1254  }
1255 
1256  if (getFilterResultType().has_value()) {
1257  Type t = getFilterResultType().value();
1258  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1259  return;
1260  }
1261 
1262  if (getFilterOperandTypes().has_value()) {
1263  mlir::ArrayAttr types = getFilterOperandTypes().value();
1264  auto operandTypes = op->getOperandTypes();
1265 
1266  if (types.size() == 1) {
1267  // All the operands must be equal to the specified type.
1268  auto typeattr =
1269  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1270  Type t = cast<::mlir::Type>(typeattr.getValue());
1271  if (!llvm::all_of(op->getOperandTypes(),
1272  [&](Type operandType) { return operandType == t; }))
1273  return;
1274  } else {
1275  // The operand types must match all the types in the list (in the same
1276  // order in which they are specified).
1277  if (types.size() != operandTypes.size()) {
1278  incorrectNumOperandTypes = true;
1279  return;
1280  }
1281 
1282  for (auto [attr, operandType] :
1283  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1284  auto typeattr = cast<mlir::TypeAttr>(attr);
1285  Type type = cast<::mlir::Type>(typeattr.getValue());
1286 
1287  if (type != operandType)
1288  return;
1289  }
1290  }
1291  }
1292 
1293  // All constraints are satisfied.
1294  res.push_back(op);
1295  return;
1296  };
1297 
1298  (*payloadOps.begin())->walk(matchFun);
1299  if (incorrectNumOperandTypes)
1300  return emitDefiniteFailure("if filter_operand_types contains more than "
1301  "one type, then it must contain as many types "
1302  "as the number of operands in the target ops");
1303  results.set(cast<OpResult>(getResult()), res);
1304  return DiagnosedSilenceableFailure::success();
1305 }
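// Rough usage sketch for MatchOp: collect all linalg.matmul payload ops nested
// under the single op mapped to %func (optional interface, attribute, and type
// filters narrow the match further):
//
//   %matmuls = transform.structured.match ops{["linalg.matmul"]} in %func
//       : (!transform.any_op) -> !transform.any_op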
1306 
1307 //===---------------------------------------------------------------------===//
1308 // MultiTileSizesOp
1309 //===---------------------------------------------------------------------===//
1310 
1311 static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1312  Type targetType, Type lowSizeType, Type,
1313  Type) {
1314  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1315 }
1316 
1317 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1318  Type &targetType, Type &lowSizeType,
1319  Type &highSizeType,
1320  Type &splitPointType) {
1321  FunctionType funcType;
1322  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1323  if (failed(parser.parseType<FunctionType>(funcType)))
1324  return failure();
1325 
1326  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1327  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1328  "argument and one result";
1329  }
1330  targetType = funcType.getInput(0);
1331  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1332 
1333  return success();
1334 }
1335 
1336 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1337  transform::TransformRewriter &rewriter, LinalgOp target,
1338  ApplyToEachResultList &results, TransformState &state) {
1339  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1340  if (target.hasDynamicShape()) {
1341  auto diag = emitSilenceableError()
1342  << "cannot compute parametric tile sizes for dynamically "
1343  "shaped payload op";
1344  diag.attachNote(target->getLoc()) << "payload op";
1345  return diag;
1346  }
1347 
1348  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1349  target, getDimension(), getTargetSize(), getDivisor());
1350  if (failed(spec)) {
1351  return emitSilenceableError()
1352  << "failed to compute multi-size tiling sizes";
1353  }
1354 
1355  Builder builder(target.getContext());
1356  results.assign(llvm::map_range(
1357  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1358  spec->lowTileSize * spec->lowTripCount}),
1359  [&builder, this](int64_t value) {
1360  return builder.getIntegerAttr(
1361  cast<ParamType>(getLowSize().getType()).getType(), value);
1362  }));
1363  return DiagnosedSilenceableFailure::success();
1364  }
1365 
1366  OpBuilder builder(target.getContext());
1367  builder.setInsertionPoint(target);
1368  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1369  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1370  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1371  builder, target, getDimension(), targetSize, divisor);
1372  if (failed(spec)) {
1373  return emitSilenceableError() << "could not generate tile size computation";
1374  }
1375 
1376  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1377  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1378  Operation *splitPoint =
1379  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1380  {spec->lowTileSize, spec->lowTripCount});
1381  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1382  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1383  assert(lowTileSize && highTileSize && splitPoint &&
1384  "tile sizes are not produced by operations");
1385  results.reserve(results.size() + 3);
1386  results.push_back(lowTileSize);
1387  results.push_back(highTileSize);
1388  results.push_back(splitPoint);
1389  return DiagnosedSilenceableFailure::success();
1390 }
1391 
1392 void transform::MultiTileSizesOp::getEffects(
1393  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1394  onlyReadsHandle(getTarget(), effects);
1395  producesHandle(getResults(), effects);
1396  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1397  onlyReadsPayload(effects);
1398  else
1399  modifiesPayload(effects);
1400 }
1401 
1402 LogicalResult transform::MultiTileSizesOp::verify() {
1403  if (getLowSize().getType() != getHighSize().getType() ||
1404  getLowSize().getType() != getSplitPoint().getType()) {
1405  return emitOpError() << "expects all results type to be the same";
1406  }
1407  return success();
1408 }
1409 
1410 //===---------------------------------------------------------------------===//
1411 // PackOp
1412 //===---------------------------------------------------------------------===//
1413 
1414 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1415  Value target,
1416  ArrayRef<OpFoldResult> mixedPackedSizes) {
1417  SmallVector<int64_t> staticPackedSizes;
1418  SmallVector<Value> dynamicPackedSizes;
1419  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1420  staticPackedSizes);
1421  // Call the default builder which sets up the proper operands segment sizes
1422  // attributes for multiple variadic operands. In the absence of this, horrible
1423  // bugs ensue.
1424  Type linalgOpHType = transform::OperationType::get(
1425  builder.getContext(), GenericOp::getOperationName());
1426  build(builder, result,
1427  /*resultType=*/linalgOpHType,
1428  /*target=*/target,
1429  /*dynamic_sizes=*/dynamicPackedSizes,
1430  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1431 }
1432 
1433 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1434  Builder b(getContext());
1435  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1436 }
1437 
1438 DiagnosedSilenceableFailure
1439 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1440  transform::TransformResults &transformResults,
1441  transform::TransformState &state) {
1442  auto targetOps = state.getPayloadOps(getTarget());
1443  // If nothing to pack, propagate success.
1444  if (std::empty(targetOps)) {
1445  transformResults.set(cast<OpResult>(getPackedOp()),
1446  ArrayRef<Operation *>({}));
1447  return DiagnosedSilenceableFailure::success();
1448  }
1449  // Fail on multi-op handles.
1450  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1451  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1452  return emitSilenceableError()
1453  << "requires target to map to exactly 1 LinalgOp (got "
1454  << llvm::range_size(targetOps) << ")";
1455  }
1456  // Fail on mismatched number of pack sizes.
1457  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1458  return emitSilenceableError()
1459  << "requires number of packed sizes match the number of loops ("
1460  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1461  << ")";
1462  }
1463 
1464  // Unpack handles to constants or actual SSA index values.
1465  SmallVector<OpFoldResult> packedSizes;
1466  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
1467  state, *this, packedSizes, getMixedPackedSizes());
1468 
1469  rewriter.setInsertionPoint(linalgOp);
1470  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1471  if (failed(maybeResult))
1472  return emitDefiniteFailure("data tiling failed");
1473 
1474  transformResults.set(cast<OpResult>(getPackedOp()),
1475  {maybeResult->packedLinalgOp.getOperation()});
1476  return DiagnosedSilenceableFailure::success();
1477 }
1478 
1479 void transform::PackOp::getEffects(
1480  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1481  transform::consumesHandle(getTarget(), effects);
1482  transform::onlyReadsHandle(getPackedSizes(), effects);
1483  transform::producesHandle(getPackedOp(), effects);
1484  transform::modifiesPayload(effects);
1485 }
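// Rough usage sketch for PackOp (packed_sizes mirrors the mixed static/dynamic
// sizes handled above; the exact assembly format is defined in TableGen):
//
//   %packed = transform.structured.pack %matmul packed_sizes = [8, 16, 32]
//       : (!transform.any_op) -> !transform.op<"linalg.generic">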
1486 
1487 //===---------------------------------------------------------------------===//
1488 // PackGreedilyOp.
1489 //===---------------------------------------------------------------------===//
1490 
1491 LogicalResult transform::PackGreedilyOp::verify() {
1492  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1493  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1494  << " is not a valid permutation";
1495  }
1496  // TODO: relax to allow empty once we have another strategy than just matmul.
1497  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1498  for (auto [s, nmo] :
1499  llvm::zip_equal(getMixedMatmulPackedSizes(),
1500  getMatmulPaddedSizesNextMultipleOf())) {
1501  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1502  if (nmo != 0 &&
1503  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1504  return emitOpError() << "at most one of the packed_size and the "
1505  "padded_sizes_next_multiple_of can be nonzero "
1506  "for the matmul strategy";
1507  }
1508  }
1509  }
1510  return success();
1511 }
1512 
1513 DiagnosedSilenceableFailure
1514 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1515  transform::TransformResults &transformResults,
1516  transform::TransformState &state) {
1517  SmallVector<Operation *> results;
1518  for (Operation *op : state.getPayloadOps(getTarget())) {
1519  auto linalgOp = dyn_cast<LinalgOp>(op);
1520  if (!linalgOp)
1521  continue;
1522  // linalgOp will be replaced and the insertion point may be invalidated if
1523  // we set it before -> set it after.
1524  rewriter.setInsertionPointAfter(linalgOp);
1525  // Failing to pack greedily is perfectly fine.
1526  // In the future we will want to order packings according to some metric.
1527  FailureOr<PackResult> packResult = packMatmulGreedily(
1528  /*rewriter=*/rewriter,
1529  /*linalgOp=*/linalgOp,
1530  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1531  /*mnkPaddedSizesNextMultipleOf=*/
1532  getMatmulPaddedSizesNextMultipleOf(),
1533  /*mnkOrder=*/getMatmulInnerDimsOrder());
1534  if (succeeded(packResult)) {
1535  results.push_back(packResult->packedLinalgOp);
1536  continue;
1537  }
1538  results.push_back(linalgOp);
1539  }
1540  transformResults.set(cast<OpResult>(getPackedOp()), results);
1541  return DiagnosedSilenceableFailure::success();
1542 }
1543 
1544 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1545  Builder b(getContext());
1546  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1547  b);
1548 }
1549 
1550 void transform::PackGreedilyOp::getEffects(
 1551  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 1552  transform::consumesHandle(getTarget(), effects);
1553  transform::onlyReadsHandle(getMatmulPackedSizes(), effects);
1554  transform::producesHandle(getPackedOp(), effects);
1555  transform::modifiesPayload(effects);
1556 }
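// Illustrative usage sketch (assumed syntax; values are examples only):
// greedily pack the matmul-like ops mapped to %func into 8x16x32 blocks,
// keeping the default m, n, k inner-dimension order.
//
//   %packed = transform.structured.pack_greedily %func
//       matmul_packed_sizes = [8, 16, 32] matmul_inner_dims_order = [0, 1, 2]
//       : (!transform.any_op) -> !transform.any_op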
1557 
1558 //===---------------------------------------------------------------------===//
1559 // PackTransposeOp
1560 //===---------------------------------------------------------------------===//
1561 
 1562 LogicalResult transform::PackTransposeOp::verify() {
 1563  if (!isPermutationVector(getInnerPerm())) {
1564  return emitOpError() << getInnerPermAttrName()
1565  << " is not a valid permutation";
1566  }
1567  if (!isPermutationVector(getOuterPerm())) {
1568  return emitOpError() << getOuterPermAttrName()
1569  << " is not a valid permutation";
1570  }
1571  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1572  return emitOpError() << " at least one of " << getInnerPermAttrName()
1573  << " or " << getOuterPermAttrName()
1574  << " must be specified";
1575  }
1576  return success();
1577 }
1578 
1579 namespace {
1580 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1581 } // namespace
1582 
1583 /// Return true if `permutation` is a valid permutation of the
1584 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
 1585 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
1586 /// This is the case when the `permutation` rank matches the rank expected by
1587 /// `op` and `permutation` is itself a permutation vector.
 1588 /// Return true if either `op` or `permutation` is empty, to allow a simpler
1589 /// polymorphic implementation.
1590 template <typename RelayoutOpTy>
 1591 static bool isValidPackingPermutation(
 1592  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1593  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1594  static_assert(
1595  llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1596  "applies to only pack or unpack operations");
1597  if (!op || permutation.empty())
1598  return true;
1599  size_t innerRank = op.getInnerDimsPos().size();
1600  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1601  return permutation.size() == innerRank && isPermutationVector(permutation);
1602  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1603  // Don't rely on it.
1604  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1605  return permutation.size() == op.getSourceRank() &&
1606  isPermutationVector(permutation);
1607  }
1608  return permutation.size() == op.getDestRank() &&
1609  isPermutationVector(permutation);
1610 }
1611 
 1612 DiagnosedSilenceableFailure
 1613 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1614  transform::TransformResults &transformResults,
1615  transform::TransformState &state) {
1616  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1617  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1618  // Step 1. If nothing to pack, propagate success.
1619  if (std::empty(packOrUnpackOps)) {
1620  transformResults.set(cast<OpResult>(getPackedOp()), {});
1621  transformResults.set(cast<OpResult>(getPackOp()), {});
1622  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1624  }
1625 
 1626  // Step 2. A bunch of runtime sanity checks and error messages.
1627  // Step 2.1. Fail on multi-op handles.
1628  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1629  !llvm::hasSingleElement(linalgOps)) {
1630  return emitSilenceableError()
1631  << "requires target to map to exactly 1 "
1632  "packing op and 1 packed op ("
1633  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1634  << llvm::range_size(linalgOps) << ")";
1635  }
1636 
1637  // Step 2.2. Fail on wrong type.
1638  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1639  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1640  if ((!packOp && !unPackOp)) {
1641  return emitSilenceableError() << "requires target to map to a "
1642  "tensor.pack or tensor.unpack";
1643  }
1644  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1645  if (!linalgOpTarget)
1646  return emitSilenceableError() << "requires a LinalgOp target";
1647 
1648  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1649  LinalgOp linalgOp;
1650  if (packOp && packOp.getResult().hasOneUse())
1651  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1652  else if (unPackOp)
1653  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1654  if (linalgOp != linalgOpTarget) {
1655  auto errorMsg =
1656  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1657  : StringLiteral{"not produced by the LinalgOp target"};
1658  return emitSilenceableError() << errorMsg;
1659  }
1660 
1661  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1662  // PackOp.
1663  if (unPackOp) {
1664  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1665  OpOperand *packUse = linalgOp.getDpsInitOperand(
1666  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1667  packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
1668  if (!packOp || !packOp.getResult().hasOneUse())
1669  return emitSilenceableError() << "could not find matching pack op";
1670  }
1671 
1672  // Step 2.5. Fail if any permutation does not validate.
1673  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1674  ArrayRef<int64_t> perm =
1675  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1676  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1677  ? StringLiteral{"invalid outer_perm"}
1678  : StringLiteral{"invalid inner_perm"};
1679  if (!isValidPackingPermutation(packOp, perm, permType) ||
1680  !isValidPackingPermutation(unPackOp, perm, permType)) {
1681  Operation *packOrUnpackOp =
1682  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1683  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1684  }
1685  }
1686 
1687  // From here on, packOp and linalgOp are always present, unPackOp may or may
1688  // not be present.
1689  assert(packOp && linalgOp && "unexpected null op");
1690 
1691  // Step 3. Actually transpose the ops.
 1692  FailureOr<PackTransposeResult> res = packTranspose(
 1693  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1694  // Preconditions have been checked, it is an error to fail here.
1695  assert(succeeded(res) && "unexpected packTranspose failure");
1696 
1697  // Step 4. Return results.
1698  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1699  transformResults.set(cast<OpResult>(getPackedOp()),
1700  {res->transposedLinalgOp});
1701  if (unPackOp) {
1702  transformResults.set(cast<OpResult>(getUnPackOp()),
1703  {res->transposedUnPackOp});
1704  } else {
1705  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1706  }
1707 
1709 }
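// Illustrative usage sketch (assumed syntax; handle names are hypothetical):
// transpose a tensor.pack / packed-LinalgOp pair produced by an earlier
// packing step, swapping the two innermost tile dimensions.
//
//   %res:3 = transform.structured.pack_transpose %pack with_compute_op(%packed)
//       inner_perm = [1, 0]
//       : (!transform.any_op, !transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op)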
1710 
1711 //===---------------------------------------------------------------------===//
1712 // PadOp
1713 //===---------------------------------------------------------------------===//
1714 
1715 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1716  ArrayRef<int64_t> paddingDimensions,
1717  ArrayRef<int64_t> padToMultipleOf,
1718  ArrayRef<int64_t> packPaddings,
1719  ArrayRef<Attribute> transposePaddings,
1720  StringRef copyBackOp) {
1721  auto resultType = transform::AnyOpType::get(b.getContext());
1722  return build(/*builder=*/b,
1723  /*result=*/result,
1724  /*types=*/TypeRange{resultType, resultType},
1725  /*target=*/target,
1726  /*paddingValues=*/ArrayAttr(), // let inference handle this
1727  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1728  /*padToMultipleOf=*/ValueRange{},
1729  /*padToMultipleOf=*/
1730  (padToMultipleOf.empty()
1731  ? DenseI64ArrayAttr()
1732  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1733  /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
1734  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1735  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1736 }
1737 
1738 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1739  ArrayRef<int64_t> paddingDimensions,
1740  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1741  ArrayRef<int64_t> packPaddings,
1742  ArrayRef<Attribute> transposePaddings,
1743  StringRef copyBackOp) {
1744  auto resultType = transform::AnyOpType::get(b.getContext());
1745  SmallVector<int64_t> staticPadToMultipleOf;
1746  SmallVector<Value> dynamicPadToMultipleOf;
1747  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1748  staticPadToMultipleOf);
1749  return build(/*builder=*/b,
1750  /*result=*/result,
1751  /*types=*/TypeRange{resultType, resultType},
1752  /*target=*/target,
1753  /*paddingValues=*/ArrayAttr(), // let inference handle this
1754  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1755  /*padToMultipleOf=*/dynamicPadToMultipleOf,
1756  /*padToMultipleOf=*/staticPadToMultipleOf,
1757  /*packPaddings=*/b.getI64ArrayAttr(packPaddings),
1758  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1759  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1760 }
1761 
1762 void PadOp::getEffects(
 1763  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 1764  consumesHandle(getTarget(), effects);
1765  onlyReadsHandle(getPadToMultipleOf(), effects);
1766  producesHandle(getPadded(), effects);
1767  producesHandle(getPad(), effects);
1768  producesHandle(getCopy(), effects);
1769  modifiesPayload(effects);
1770 }
1771 
1772 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1773  Builder b(getContext());
1774  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1775 }
1776 
 1777 DiagnosedSilenceableFailure
 1778 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1779  transform::TransformResults &results,
1780  transform::TransformState &state) {
1781  auto transformOp = cast<TransformOpInterface>(getOperation());
1782  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1783 
1784  for (Operation *target : state.getPayloadOps(getTarget())) {
1785  auto linalgTarget = dyn_cast<LinalgOp>(target);
1786  if (!linalgTarget) {
1787  auto diag = emitSilenceableError() << "expected LinalgOp target";
1788  diag.attachNote(target->getLoc()) << "target op";
1789  return diag;
1790  }
1791 
1792  // Convert the integer packing flags to booleans.
1793  SmallVector<bool> packPaddings;
1794  for (int64_t packPadding :
1795  extractFromIntegerArrayAttr<int64_t>(getPackPaddings()))
1796  packPaddings.push_back(static_cast<bool>(packPadding));
1797 
1798  // Convert the padding values to attributes.
1799  SmallVector<Attribute> paddingValues;
1800  for (auto const &it :
1801  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1802  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1803  if (!attr) {
1804  emitOpError("expects padding values to be typed attributes");
 1805  return DiagnosedSilenceableFailure::definiteFailure();
 1806  }
1807  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1808  // Try to parse string attributes to obtain an attribute of element type.
1809  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1810  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1811  stringAttr, getContext(), elementType,
1812  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
1813  if (!parsedAttr || parsedAttr.getType() != elementType) {
1814  auto diag = this->emitOpError("expects a padding that parses to ")
1815  << elementType << ", got " << std::get<0>(it);
1816  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
 1817  return DiagnosedSilenceableFailure::definiteFailure();
 1818  }
1819  paddingValues.push_back(parsedAttr);
1820  continue;
1821  }
1822  // Otherwise, add the attribute directly.
1823  if (attr.getType() != elementType) {
1824  auto diag = this->emitOpError("expects a padding value of type ")
1825  << elementType << ", got " << attr;
1826  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
 1827  return DiagnosedSilenceableFailure::definiteFailure();
 1828  }
1829  paddingValues.push_back(attr);
1830  }
1831 
1832  // Extract the transpose vectors.
1833  SmallVector<SmallVector<int64_t>> transposePaddings;
1834  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1835  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1836  cast<ArrayAttr>(transposeVector)));
1837 
1838  LinalgOp paddedOp;
 1839  LinalgPaddingOptions options;
 1840  options.paddingDimensions =
1841  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1842 
1843  SmallVector<int64_t> padToMultipleOf;
 1844  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
 1845  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1846  if (!status.succeeded())
1847  return status;
1848  if (padToMultipleOf.empty())
1849  padToMultipleOf =
1850  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
1851 
1852  options.padToMultipleOf = padToMultipleOf;
1853  options.paddingValues = paddingValues;
1854  options.packPaddings = packPaddings;
1855  if (getCopyBackOp() ==
 1856  bufferization::MaterializeInDestinationOp::getOperationName()) {
 1857  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
 1858  BufferizationMaterializeInDestination;
 1859  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
 1860  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
 1861  } else if (getCopyBackOp() == kCopyOpNone) {
 1862  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
 1863  } else {
1864  llvm_unreachable("unsupported copy_back op");
1865  }
1866 
1867  SmallVector<Value> replacements;
1868  SmallVector<tensor::PadOp> newPadOps;
1869  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
1870  replacements, newPadOps))) {
1871  auto diag = emitSilenceableError() << "failed to pad op";
1872  diag.attachNote(target->getLoc()) << "target op";
1873  return diag;
1874  }
1875 
1876  // We need to perform our own replacement here because this API is still
1877  // used in patterns that "pad and hoist", for which the replacement values
1878  // need to be different.
1879  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1880  // that we have more composable abstractions.
1881  rewriter.replaceOp(linalgTarget, replacements);
1882  paddedOps.push_back(paddedOp);
1883  padOps.append(newPadOps.begin(), newPadOps.end());
1884  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
1885  for (Value v : replacements) {
1886  Operation *copyBackOp = v.getDefiningOp();
1887  if (!llvm::is_contained(copyBackOps, copyBackOp))
1888  copyBackOps.push_back(copyBackOp);
1889  }
1890  }
1891  }
1892 
1893  results.set(cast<OpResult>(getPadded()), paddedOps);
1894  results.set(cast<OpResult>(getPad()), padOps);
1895  results.set(cast<OpResult>(getCopy()), copyBackOps);
1897 }
1898 
 1899 LogicalResult transform::PadOp::verify() {
 1900  SmallVector<int64_t> packPaddings =
1901  extractFromIntegerArrayAttr<int64_t>(getPackPaddings());
1902  if (any_of(packPaddings, [](int64_t packPadding) {
1903  return packPadding != 0 && packPadding != 1;
1904  })) {
1905  return emitOpError()
1906  << "expects pack_paddings to contain booleans (0/1), found "
1907  << getPackPaddings();
1908  }
1909 
1910  SmallVector<int64_t> paddingDimensions =
1911  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1912  if (any_of(paddingDimensions,
1913  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
 1914  return emitOpError() << "expects padding_dimensions to contain "
 1915  "non-negative integers, found "
1916  << getPaddingDimensions();
1917  }
1918  if (!getMixedPadToMultipleOf().empty()) {
1919  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
1920  return emitOpError() << "expects as many multiples as padding_dimensions";
1921  }
1922  }
1923  ArrayAttr transposes = getTransposePaddings();
1924  for (Attribute attr : transposes) {
1925  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
1926  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1927  if (!std::is_permutation(sequence.begin(), sequence.end(),
1928  transpose.begin(), transpose.end())) {
1929  return emitOpError()
1930  << "expects transpose_paddings to be a permutation, found "
1931  << attr;
1932  }
1933  }
1934  if (getCopyBackOp() !=
1935  bufferization::MaterializeInDestinationOp::getOperationName() &&
1936  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1937  getCopyBackOp() != kCopyOpNone)
1938  return emitOpError() << "invalid copy_back_op";
1939  return success();
1940 }
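// Illustrative usage sketch (assumed syntax; attribute values are examples):
// pad all three matmul dimensions with zeros and mark the two inputs as
// nofold (pack_paddings = 1) so their pads are kept for later hoisting.
//
//   %padded, %pad, %copy = transform.structured.pad %matmul {
//       padding_values = [0.0 : f32, 0.0 : f32, 0.0 : f32],
//       padding_dimensions = [0, 1, 2],
//       pack_paddings = [1, 1, 0]
//   } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)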
1941 
1942 //===---------------------------------------------------------------------===//
1943 // HoistPadOp
1944 //===---------------------------------------------------------------------===//
1945 
1946 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
1947  transform::TransformRewriter &rewriter,
1948  transform::TransformResults &transformResults,
1949  transform::TransformState &state) {
1950  auto targetOps = state.getPayloadOps(getTarget());
1951  auto loopOps = state.getPayloadOps(getLoop());
1952  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1953  return emitDefiniteFailure()
1954  << "requires exactly one target and one loop handle (got "
1955  << llvm::range_size(targetOps) << " and "
1956  << llvm::range_size(loopOps) << ")";
1957  }
1958 
1959  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1960  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1961  if (!padOp || !loopOp)
1962  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
1963 
 1964  FailureOr<linalg::detail::PackingResult> result =
 1965  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
1966  getTranspose());
1967  if (failed(result))
1968  return emitDefiniteFailure() << "could not build packing loop nest";
1969 
1970  if (result->clonedLoopIvs.empty()) {
1971  transformResults.set(cast<OpResult>(getPackingLoop()),
1972  {result->hoistedPadOp.getOperation()});
1974  }
1975  auto outerPackedLoop =
1976  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
1977  transformResults.set(cast<OpResult>(getPackingLoop()),
1978  {outerPackedLoop.getOperation()});
1980 }
1981 
 1982 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
 1983  ArrayRef<int64_t> transpose = getTranspose();
1984  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1985  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
1986  transpose.end())) {
1987  return emitOpError() << "expects transpose to be a permutation, found "
1988  << getTranspose();
1989  }
1990  return success();
1991 }
1992 
1993 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
 1994  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 1995  transform::onlyReadsHandle(getTarget(), effects);
1996  transform::onlyReadsHandle(getLoop(), effects);
1997  transform::producesHandle(getPackingLoop(), effects);
1998  transform::modifiesPayload(effects);
1999 }
2000 
 2001 DiagnosedSilenceableFailure
 2002 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
 2003  tensor::PadOp target,
 2004  transform::ApplyToEachResultList &results,
 2005  transform::TransformState &state) {
2006  tensor::PadOp hoistedPadOp;
2007  SmallVector<GenericOp> transposeOps;
2008  FailureOr<Value> result =
2009  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2010  hoistedPadOp, transposeOps);
2011  if (succeeded(result)) {
2012  // We need to perform our own replacement here because this API is still
2013  // used in patterns that "pad and hoist", for which the replacement values
2014  // need to be different.
2015  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2016  // that we have more composable abstractions.
2017  rewriter.replaceOp(target, *result);
2018  results.push_back(hoistedPadOp);
2020  }
2021  return emitDefaultSilenceableFailure(target);
2022 }
2023 
 2024 LogicalResult transform::HoistPadOp::verify() {
 2025  ArrayRef<int64_t> transpose = getTranspose();
2026  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2027  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2028  transpose.end())) {
2029  return emitOpError() << "expects transpose to be a permutation, found "
2030  << getTranspose();
2031  }
2032  return success();
2033 }
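// Illustrative usage sketch (assumed syntax): hoist the padding of the
// tensor.pad mapped to %pad out of one enclosing loop.
//
//   %hoisted = transform.structured.hoist_pad %pad by 1 loops
//       : (!transform.any_op) -> !transform.any_op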
2034 
2035 //===----------------------------------------------------------------------===//
2036 // PromoteOp
2037 //===----------------------------------------------------------------------===//
2038 
 2039 DiagnosedSilenceableFailure
 2040 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
 2041  LinalgOp target,
 2042  transform::ApplyToEachResultList &results,
 2043  transform::TransformState &state) {
2044  LinalgPromotionOptions promotionOptions;
2045  if (!getOperandsToPromote().empty())
2046  promotionOptions = promotionOptions.setOperandsToPromote(
2047  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2048  if (getUseFullTilesByDefault())
2049  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2050  getUseFullTilesByDefault());
2051  if (getUseAlloca())
2052  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2053  if (!getUseFullTileBuffers().empty())
2054  promotionOptions = promotionOptions.setUseFullTileBuffers(
2055  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2056  if (getAlignment().has_value())
2057  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2058  if (getMemorySpace().has_value())
2059  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2060 
2061  if (getMapping().has_value()) {
 2062  // The mapping should contain at most one element.
2063  auto mapping = *getMapping();
2064  if (mapping.size() > 1)
2065  return emitDefaultDefiniteFailure(target);
2066 
2067  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2068 
2069  if (addressSpace.getAddressSpace() ==
2070  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2071  promotionOptions =
2072  promotionOptions
2076  .setUseFullTileBuffers({false, false});
2077  } else if (addressSpace.getAddressSpace() ==
2078  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2079  promotionOptions =
2080  promotionOptions
2084  .setUseFullTileBuffers({false, false});
2085  } else {
2086  return emitDefaultDefiniteFailure(target);
2087  }
2088  }
2089 
2090  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2091  return emitDefaultDefiniteFailure(target);
2092 
2093  rewriter.setInsertionPoint(target);
2094  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2095  if (failed(res))
2096  return emitDefaultDefiniteFailure(target);
2097  results.push_back(target);
2099 }
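// Illustrative usage sketch (assumed syntax; attribute names mirror the
// getters used above):
//
//   %promoted = transform.structured.promote %tiled_matmul {
//       operands_to_promote = [0, 1], use_full_tile_buffers = [false, false]
//   } : (!transform.any_op) -> !transform.any_op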
2100 
2101 //===----------------------------------------------------------------------===//
2102 // ReplaceOp
2103 //===----------------------------------------------------------------------===//
2104 
 2105 DiagnosedSilenceableFailure
 2106 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2107  TransformResults &transformResults,
2108  TransformState &state) {
2109  auto payload = state.getPayloadOps(getTarget());
2110 
2111  // Check for invalid targets.
2112  for (Operation *target : payload) {
2113  if (target->getNumOperands() > 0)
2114  return emitDefiniteFailure() << "expected target without operands";
2115  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2116  target->getNumRegions() > 0)
2117  return emitDefiniteFailure()
2118  << "expected target that is isolated from above";
2119  }
2120 
2121  // Clone and replace.
2122  Operation *pattern = &getBodyRegion().front().front();
2123  SmallVector<Operation *> replacements;
2124  for (Operation *target : payload) {
2125  if (getOperation()->isAncestor(target))
2126  continue;
2127  rewriter.setInsertionPoint(target);
2128  Operation *replacement = rewriter.clone(*pattern);
2129  rewriter.replaceOp(target, replacement->getResults());
2130  replacements.push_back(replacement);
2131  }
2132  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2134 }
2135 
2136 void transform::ReplaceOp::getEffects(
 2137  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 2138  consumesHandle(getTarget(), effects);
2139  producesHandle(getReplacement(), effects);
2140  modifiesPayload(effects);
2141 }
2142 
 2143 LogicalResult transform::ReplaceOp::verify() {
 2144  if (!getBodyRegion().hasOneBlock())
2145  return emitOpError() << "expected one block";
2146  if (std::distance(getBodyRegion().front().begin(),
2147  getBodyRegion().front().end()) != 1)
2148  return emitOpError() << "expected one operation in block";
2149  Operation *replacement = &getBodyRegion().front().front();
2150  if (replacement->getNumOperands() > 0)
2151  return replacement->emitOpError()
2152  << "expected replacement without operands";
2153  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2154  replacement->getNumRegions() > 0)
2155  return replacement->emitOpError()
2156  << "expect op that is isolated from above";
2157  return success();
2158 }
2159 
2160 //===----------------------------------------------------------------------===//
2161 // ScalarizeOp
2162 //===----------------------------------------------------------------------===//
2163 
 2164 DiagnosedSilenceableFailure
 2165 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
 2166  LinalgOp target,
 2167  transform::ApplyToEachResultList &results,
 2168  transform::TransformState &state) {
2169  scf::SCFTilingOptions tilingOptions;
2170  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2171  SmallVector<OpFoldResult> tileSizes;
2172  Location loc = target.getLoc();
2173  SmallVector<OpFoldResult> allShapeSizes =
2174  target.createFlatListOfOperandDims(b, loc);
2175  AffineMap map = target.getShapesToLoopsMap();
2176  if (!map)
2177  return tileSizes;
2178  SmallVector<OpFoldResult> shapeSizes =
2180  allShapeSizes);
2181  // If the shape size is dynamic, tile by 1.
2182  // Otherwise, do not tile (i.e. tile size 0).
2183  for (OpFoldResult shapeSize : shapeSizes) {
2184  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2185  : b.getIndexAttr(1));
2186  }
2187  return tileSizes;
2188  });
2189  SmallVector<int64_t> emptyTileSizes;
2190  rewriter.setInsertionPoint(target);
2191  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2192  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2193  if (failed(maybeTilingResult))
2194  return emitDefaultDefiniteFailure(target);
2195 
2196  if (target->getNumResults())
2197  rewriter.replaceOp(target, maybeTilingResult->replacements);
2198  else
2199  rewriter.eraseOp(target);
2200 
2201  results.reserve(maybeTilingResult->tiledOps.size());
2202  for (Operation *tiled : maybeTilingResult->tiledOps)
2203  results.push_back(tiled);
2205 }
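// Illustrative usage sketch (assumed syntax): tile every dynamic dimension of
// the target by 1, leaving static dimensions untiled, as computed by the
// tile-size callback above.
//
//   %scalarized = transform.structured.scalarize %target
//       : (!transform.any_op) -> !transform.any_op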
2206 
2207 //===----------------------------------------------------------------------===//
2208 // ConvertToLoopsOp
2209 //===----------------------------------------------------------------------===//
2210 
 2211 DiagnosedSilenceableFailure
 2212 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2213  transform::TransformResults &results,
2214  transform::TransformState &state) {
 2215  SmallVector<Operation *> loops;
 2216  for (Operation *target : state.getPayloadOps(getTarget())) {
 2217  auto tilingOp = dyn_cast<TilingInterface>(*target);
 2218  if (!tilingOp) {
 2219  DiagnosedSilenceableFailure diag =
 2220  emitSilenceableError()
2221  << "expected the payload to implement TilingInterface";
2222  diag.attachNote(target->getLoc()) << "payload op";
2223  return diag;
2224  }
2225  rewriter.setInsertionPoint(target);
2226  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2227  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2228  if (failed(generatedLoops))
2229  return emitDefaultDefiniteFailure(target);
2230  for (scf::ForOp &loop : *generatedLoops) {
2231  loops.push_back(loop.getOperation());
2232  }
2233  rewriter.eraseOp(target);
2234  }
2235  results.set(cast<OpResult>(getResult()), loops);
2237 }
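// Illustrative usage sketch (assumed syntax): lower every TilingInterface op
// mapped to %targets into an explicit scf.for loop nest.
//
//   %loops = transform.structured.convert_to_loops %targets
//       : (!transform.any_op) -> !transform.any_op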
2238 
2239 //===----------------------------------------------------------------------===//
2240 // RewriteInDestinationPassingStyleOp
2241 //===----------------------------------------------------------------------===//
2242 
 2243 DiagnosedSilenceableFailure
 2244 transform::RewriteInDestinationPassingStyleOp::applyToOne(
 2245  transform::TransformRewriter &rewriter, Operation *target,
 2246  transform::ApplyToEachResultList &results,
 2247  transform::TransformState &state) {
2249  rewriter.setInsertionPoint(target);
2250  FailureOr<Operation *> maybeResult =
 2251  llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(target)
 2252  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2253  [&rewriter](auto op) {
2254  return rewriteInDestinationPassingStyle(rewriter, op);
2255  });
2256  if (failed(maybeResult))
2257  return emitDefaultSilenceableFailure(target);
2258  results.push_back(*maybeResult);
2260 }
2261 
2262 //===----------------------------------------------------------------------===//
2263 // SplitOp
2264 //===----------------------------------------------------------------------===//
2265 
 2266 DiagnosedSilenceableFailure
 2267 SplitOp::apply(transform::TransformRewriter &rewriter,
2268  TransformResults &results, TransformState &state) {
2269  // Collect the dynamic split points if provided.
2270  SmallVector<Operation *> payload =
2271  llvm::to_vector(state.getPayloadOps(getTarget()));
2272  SmallVector<OpFoldResult> splitPoints;
2273  splitPoints.reserve(payload.size());
2274  if (getDynamicSplitPoint()) {
 2275  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
 2276  if (isa<TransformHandleTypeInterface>(getDynamicSplitPoint().getType())) {
2277  splitPoints = llvm::to_vector(llvm::map_range(
2278  state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
2279  if (op->getNumResults() != 1 ||
2280  !op->getResult(0).getType().isIndex()) {
2281  diag = emitSilenceableError()
2282  << "expected dynamic split point handle to point to a "
2283  "single-result index-typed op";
2284  diag.attachNote(op->getLoc()) << "dynamic split point";
2285  }
2286  return OpFoldResult(op->getResult(0));
2287  }));
2288  } else {
2289  splitPoints = llvm::to_vector(
2290  llvm::map_range(state.getParams(getDynamicSplitPoint()),
2291  [](Attribute attr) { return OpFoldResult(attr); }));
2292  }
2293  if (diag.isSilenceableFailure())
2294  return diag;
2295 
2296  if (splitPoints.size() != payload.size()) {
2297  return emitDefiniteFailure()
2298  << "expected the dynamic split point handle to point to as "
2299  "many operations ("
2300  << splitPoints.size() << ") as the target handle ("
2301  << payload.size() << ")";
2302  }
2303  } else {
2304  splitPoints.resize(payload.size(),
2305  rewriter.getIndexAttr(getStaticSplitPoint()));
2306  }
2307 
2308  // Split each target operation.
2309  SmallVector<Operation *> first, second;
2310  Operation *noSecondPart = nullptr;
2311  for (const auto &pair : llvm::zip(payload, splitPoints)) {
2312  Operation *target = std::get<0>(pair);
2313  auto linalgOp = dyn_cast<LinalgOp>(target);
2314  if (!linalgOp) {
2315  auto diag = emitSilenceableError() << "only applies to structured ops";
2316  diag.attachNote(target->getLoc()) << "target op";
2317  return diag;
2318  }
2319 
2320  if (getDimension() >= linalgOp.getNumLoops()) {
2321  auto diag = emitSilenceableError() << "dimension " << getDimension()
2322  << " does not exist in target op";
2323  diag.attachNote(target->getLoc()) << "target op";
2324  return diag;
2325  }
2326 
2327  rewriter.setInsertionPoint(linalgOp);
2328  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2329  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2330  getDimension(), std::get<1>(pair));
2331 
2332  // Propagate errors.
2333  if (!first.back() && !second.back()) {
2334  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2335  diag.attachNote(target->getLoc()) << "target op";
2336  return diag;
2337  }
2338 
2339  // Do not add null second parts.
2340  if (!second.back()) {
2341  noSecondPart = target;
2342  second.pop_back();
2343  }
2344  }
2345 
2346  if (second.size() != first.size() && !second.empty()) {
2347  auto diag = emitSilenceableError()
2348  << "splitting does not produce the second part for a subset "
2349  "of targets";
2350  diag.attachNote() << "expected splitting to produce the second part of all "
2351  "or none of the targets";
2352  diag.attachNote(noSecondPart->getLoc())
2353  << "first target with no second part";
2354  return diag;
2355  }
2356 
2357  results.set(cast<OpResult>(getFirst()), first);
2358  results.set(cast<OpResult>(getSecond()), second);
2360 }
2361 
2362 void SplitOp::getEffects(
 2363  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 2364  consumesHandle(getTarget(), effects);
2365  if (getDynamicSplitPoint())
2366  onlyReadsHandle(getDynamicSplitPoint(), effects);
2367  producesHandle(getResults(), effects);
2368  modifiesPayload(effects);
2369 }
2370 
 2371 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
 2372  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
2373  IntegerAttr staticSplitPoint;
2374  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2375  return failure();
2376 
2377  OptionalParseResult dynamicPointParseResult =
2378  parser.parseOptionalOperand(dynamicSplitPoint);
2379  if (!dynamicPointParseResult.has_value()) {
2380  int64_t staticSplitPointValue;
2381  if (failed(parser.parseInteger(staticSplitPointValue)))
2382  return failure();
2383 
2384  staticSplitPoint =
2385  parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
2386  }
2387 
2388  Type targetType;
2389  if (parser.parseOptionalAttrDict(result.attributes) ||
2390  parser.parseColonType(targetType) ||
2391  parser.resolveOperand(target, targetType, result.operands)) {
2392  return failure();
2393  }
2394  if (dynamicPointParseResult.has_value()) {
2395  Type splitPointType;
2396  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2397  parser.parseType(splitPointType) ||
2398  parser.resolveOperand(dynamicSplitPoint, splitPointType,
2399  result.operands)) {
2400  return failure();
2401  }
2402 
2403  staticSplitPoint =
2404  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2405  }
2406 
2407  result.addAttribute(
2408  SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
2409  staticSplitPoint);
2410  result.addTypes({targetType, targetType});
2411  return success();
2412 }
2413 
2414 void SplitOp::print(OpAsmPrinter &printer) {
2415  printer << " " << getTarget() << " after ";
2416  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
2417  if (staticSplitSize != ShapedType::kDynamic)
2418  printer << staticSplitSize;
2419  else
2420  printer << getDynamicSplitPoint();
2421  printer << " ";
2422  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2423  {getStaticSplitPointAttrName()});
2424  printer << " : " << getTarget().getType();
2425  if (staticSplitSize == ShapedType::kDynamic)
2426  printer << ", " << getDynamicSplitPoint().getType();
2427 }
2428 
 2429 LogicalResult SplitOp::verify() {
 2430  if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
2431  (getDynamicSplitPoint() == nullptr)) {
2432  return emitOpError() << "expects either a dynamic or a static split "
2433  "point to be provided";
2434  }
2435  return success();
2436 }
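// Illustrative usage sketch following the parse/print format above (handle
// names are hypothetical):
//
//   // Static split point.
//   %first, %second = transform.structured.split %target after 16
//       { dimension = 0 } : !transform.any_op
//   // Dynamic split point provided through another handle or param.
//   %f, %s = transform.structured.split %target after %split_point
//       { dimension = 1 } : !transform.any_op, !transform.any_op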
2437 
2438 //===----------------------------------------------------------------------===//
2439 // SplitReductionOp
2440 //===----------------------------------------------------------------------===//
2441 
2442 void transform::SplitReductionOp::build(
2443  OpBuilder &builder, OperationState &result, Value target,
2444  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2445  bool useScalingAlgorithm, bool useAlloc) {
2446  MLIRContext *ctx = builder.getContext();
2447  result.addOperands(target);
2448  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2449  builder.getI64IntegerAttr(splitFactor));
2450  result.addAttribute(
2451  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2452  builder.getI64IntegerAttr(insertSplitDimension));
2453  if (innerParallel) {
2454  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2455  builder.getUnitAttr());
2456  }
2457  if (useScalingAlgorithm) {
2458  result.addAttribute(
2459  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2460  builder.getUnitAttr());
2461  }
2462  if (useAlloc) {
2463  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2464  builder.getUnitAttr());
2465  }
2466  auto resultType = transform::AnyOpType::get(ctx);
2467  result.addTypes({resultType, resultType, resultType, resultType});
2468 }
2469 
2470 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2471  transform::TransformRewriter &rewriter, LinalgOp target,
 2472  transform::ApplyToEachResultList &results,
 2473  transform::TransformState &state) {
2474  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2475  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2476  unsigned(getInsertSplitDimension()),
2477  bool(getInnerParallel())};
2478  };
2479  rewriter.setInsertionPoint(target);
2480  FailureOr<SplitReductionResult> splitResult =
2481  (getUseScalingAlgorithm())
2482  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2483  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2484  if (failed(splitResult))
2485  return emitDefaultDefiniteFailure(target);
2486 
2487  results.push_back(splitResult->initOrAlloc);
2488  results.push_back(splitResult->fillOp);
2489  results.push_back(splitResult->splitLinalgOp);
2490  results.push_back(splitResult->resultCombiningLinalgOp);
2492 }
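// Illustrative usage sketch (assumed syntax; attribute names follow the
// build() method above). The four results are the init/alloc op, the fill,
// the split LinalgOp and the combining LinalgOp, in the order pushed above.
//
//   %init, %fill, %split, %combine = transform.structured.split_reduction %red
//       { split_factor = 4, insert_split_dimension = 2 }
//       : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)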
2493 
2494 //===----------------------------------------------------------------------===//
2495 // TileReductionUsingForOp
2496 //===----------------------------------------------------------------------===//
2497 
2498 void transform::TileReductionUsingForOp::build(
2499  OpBuilder &builder, OperationState &result, Value target,
2500  ArrayRef<int64_t> staticTileSizes) {
2501  // Call the default builder.
2502  // This is future-proof re mixed static-dynamic and setting up the proper
2503  // operands segment sizes attributes for multiple variadic operands.
2504  // In the absence of this, horrible bugs ensue.
2505  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2506  MLIRContext *ctx = builder.getContext();
2507  auto opTy = transform::AnyOpType::get(ctx);
2508  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2509  build(builder, result,
2510  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2511  /*target=*/target,
2512  /*tile_sizes=*/staticTileSizesAttr);
2513 }
2514 
2515 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2516  transform::TransformRewriter &rewriter, LinalgOp target,
 2517  transform::ApplyToEachResultList &results,
 2518  transform::TransformState &state) {
2519  rewriter.setInsertionPoint(target);
2521  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2522  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2523 
2524  if (failed(result))
2525  return emitDefaultSilenceableFailure(target);
2526  for (Value initValue : result->initialValues)
2527  results.push_back(initValue.getDefiningOp());
2528  results.push_back(result->parallelTiledOp);
2529  results.push_back(result->mergeOp);
2530  results.push_back(result->loops.front());
2532 }
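// Illustrative usage sketch (assumed syntax): tile the reduction dimension of
// a 2-loop reduction by 128, producing a partial-reduction op, a merge op and
// the generated scf.for loop in the order pushed above.
//
//   %init, %partial, %merge, %loop =
//       transform.structured.tile_reduction_using_for %red by tile_sizes = [0, 128]
//       : (!transform.any_op)
//       -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)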
2533 
2534 //===----------------------------------------------------------------------===//
2535 // TileReductionUsingForallOp
2536 //===----------------------------------------------------------------------===//
2537 
2538 void transform::TileReductionUsingForallOp::build(
2539  OpBuilder &builder, OperationState &result, Value target,
2540  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2541  ArrayAttr mapping) {
2542  // Call the default builder.
2543  // This is future-proof re mixed static-dynamic and setting up the proper
2544  // operands segment sizes attributes for multiple variadic operands.
2545  // In the absence of this, horrible bugs ensue.
2546  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2547  MLIRContext *ctx = builder.getContext();
2548  auto opTy = transform::AnyOpType::get(ctx);
2549  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2550  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2551  build(builder, result,
2552  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2553  /*target=*/target,
2554  /*num_threads=*/staticNumThreadsAttr,
2555  /*tile_sizes=*/staticTileSizesAttr,
2556  /*mapping=*/mapping);
2557 }
2558 
2559 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2560  transform::TransformRewriter &rewriter, LinalgOp target,
 2561  transform::ApplyToEachResultList &results,
 2562  transform::TransformState &state) {
2563  rewriter.setInsertionPoint(target);
2564  SmallVector<OpFoldResult> numThreads =
2565  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2566  SmallVector<OpFoldResult> tileSizes =
2567  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2570  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2571  numThreads, tileSizes, getMapping());
2572 
2573  if (failed(result)) {
2574  auto diag = emitSilenceableError() << "could not tile reduction";
2575  diag.attachNote(target.getLoc()) << "target operation";
2576  return diag;
2577  }
2578  for (Value initValue : result->initialValues)
2579  results.push_back(initValue.getDefiningOp());
2580  results.push_back(result->parallelTiledOp);
2581  results.push_back(result->mergeOp);
2582  results.push_back(result->loops);
2584 }
2585 
2586 //===----------------------------------------------------------------------===//
2587 // TileUsingForOp
2588 //===----------------------------------------------------------------------===//
2589 
2590 void transform::TileUsingForOp::build(
2591  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2592  Value target, ArrayRef<int64_t> staticTileSizes,
2593  ArrayRef<int64_t> interchange,
2594  std::optional<ArrayRef<bool>> scalableSizes) {
2595  return build(builder, result, loopTypes,
2596  /*target=*/target,
2597  /*mixedTileSizes=*/
2598  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2599  interchange, scalableSizes);
2600 }
2601 
2602 void transform::TileUsingForOp::build(
2603  OpBuilder &builder, OperationState &result, Value target,
2604  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2605  std::optional<ArrayRef<bool>> scalableSizes) {
2606  build(builder, result, target,
2607  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2608  interchange, scalableSizes);
2609 }
2610 
2611 void transform::TileUsingForOp::build(
2612  OpBuilder &builder, OperationState &result, Value target,
2613  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2614  std::optional<ArrayRef<bool>> scalableSizes) {
 2615  // Loop types are automatically splat by the callee, so setting up one is
2616  // enough.
2617  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2618  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2619  scalableSizes);
2620 }
2621 
2622 void transform::TileUsingForOp::build(
2623  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2624  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2625  ArrayRef<int64_t> interchange,
2626  std::optional<ArrayRef<bool>> scalableSizes) {
2627  SmallVector<int64_t> staticTileSizes;
2628  SmallVector<Value> dynamicTileSizes;
2629  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2630  // Call the default builder which sets up the proper operands segment sizes
2631  // attributes for multiple variadic operands. In the absence of this,
2632  // horrible bugs ensue.
2633  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2634  unsigned numExpectedLoops =
2635  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2636  SmallVector<Type> resultTypes;
2637  resultTypes.reserve(numExpectedLoops);
2638  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2639  "expected one loop type or as many as loops");
2640  if (loopTypes.size() == 1)
2641  resultTypes.append(numExpectedLoops, loopTypes[0]);
2642  else
2643  llvm::append_range(resultTypes, loopTypes);
2644  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2645  if (scalableSizes.has_value())
2646  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2647  build(builder, result, /*tiled_linalg_op=*/target.getType(),
2648  /*loops=*/resultTypes,
2649  /*target=*/target,
2650  /*dynamic_sizes=*/dynamicTileSizes,
2651  /*static_sizes=*/staticTileSizesAttr,
2652  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2653  /*scalable_sizes=*/expandedScalableSizes);
2654 }
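// Illustrative C++ builder usage (a sketch; `b`, `loc` and `targetHandle` are
// hypothetical values owned by the caller). This exercises the static
// tile-size overload defined above; zero sizes mean "do not tile that loop"
// and therefore produce no loop handle, per numExpectedLoops.
//
//   auto tileOp = b.create<transform::TileUsingForOp>(
//       loc, /*target=*/targetHandle,
//       /*staticTileSizes=*/ArrayRef<int64_t>{32, 32, 0},
//       /*interchange=*/ArrayRef<int64_t>{},
//       /*scalableSizes=*/std::nullopt);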
2655 
 2656 LogicalResult transform::TileUsingForOp::verify() {
 2657  if (getMixedSizes().size() != getScalableSizes().size())
 2658  return emitOpError("expected same number of sizes (")
 2659  << getMixedSizes().size() << ") and scalable sizes ("
2660  << getScalableSizes().size() << ")";
2661  return success();
2662 }
2663 
 2664 DiagnosedSilenceableFailure
 2665 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2666  TransformResults &transformResults,
2667  TransformState &state) {
2668  ArrayRef<int64_t> tileSizes = getStaticSizes();
2669 
2670  SmallVector<Operation *> targets =
2671  llvm::to_vector(state.getPayloadOps(getTarget()));
2672  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
 2673  SmallVector<SmallVector<int64_t>> paramSizes;
 2674  dynamicSizeProducers.reserve(getDynamicSizes().size());
2675  paramSizes.reserve(getDynamicSizes().size());
2676  for (Value transformValue : getDynamicSizes()) {
2677  if (isa<ParamType>(transformValue.getType())) {
2678  dynamicSizeProducers.push_back({});
2679  ArrayRef<Attribute> params = state.getParams(transformValue);
2680  paramSizes.push_back(
2681  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
2682  return cast<IntegerAttr>(attr).getValue().getSExtValue();
2683  })));
2684 
2685  if (paramSizes.back().size() != targets.size()) {
2687  emitSilenceableError()
2688  << "expected as many parameter values ("
2689  << dynamicSizeProducers.back().size() << ") as target ops ("
2690  << targets.size() << ")";
2691  diag.attachNote(transformValue.getLoc()) << "for this parameter";
2692  return diag;
2693  }
2694 
2695  continue;
2696  }
2697  paramSizes.push_back({});
2698  dynamicSizeProducers.push_back(
2699  llvm::to_vector(state.getPayloadOps(transformValue)));
2700 
2701  if (dynamicSizeProducers.back().size() != targets.size()) {
2703  emitSilenceableError()
2704  << "expected as many dynamic size-producing operations ("
2705  << dynamicSizeProducers.back().size() << ") as target ops ("
2706  << targets.size() << ")";
2707  diag.attachNote(transformValue.getLoc()) << "for this handle";
2708  return diag;
2709  }
2710 
2711  for (Operation *op : dynamicSizeProducers.back()) {
2712  if (op->getNumResults() == 1 &&
2713  isa<IndexType>(op->getResult(0).getType())) {
2714  continue;
2715  }
2716 
2718  emitSilenceableError() << "expected sizes to be produced by ops "
2719  "with a single index-type result";
2720  diag.attachNote(op->getLoc()) << "size producer op";
2721  diag.attachNote(transformValue.getLoc()) << "for this handle";
2722  return diag;
2723  }
2724  }
2725 
 2726  SmallVector<Operation *> tiled;
 2727  SmallVector<SmallVector<Operation *>> loops;
 2728  loops.resize(getLoops().size());
2729  auto scalableSizes = getScalableSizes();
2730  for (auto [i, op] : llvm::enumerate(targets)) {
2731  auto tilingInterface = dyn_cast<TilingInterface>(op);
2732  if (!tilingInterface) {
2734  emitSilenceableError()
2735  << "only ops implementing TilingInterface are supported";
2736  diag.attachNote(op->getLoc()) << "target op";
2737  return diag;
2738  }
2739  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
2741  emitSilenceableError()
2742  << "too many tiles provided, expected at most "
2743  << tilingInterface.getLoopIteratorTypes().size() << " found "
2744  << tileSizes.size();
2745  diag.attachNote(op->getLoc()) << "target op";
2746  return diag;
2747  }
2748 
2749  scf::SCFTilingOptions tilingOptions;
2750  if (tileSizes.empty()) {
2751  tilingOptions.setTileSizeComputationFunction(
2753  return {};
2754  });
2755  } else {
2756  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
2757  Operation *) {
 2758  SmallVector<OpFoldResult> sizes;
 2759  sizes.reserve(tileSizes.size());
2760  unsigned dynamicIdx = 0;
2761 
2762  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
2763  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
2764  if (scalableSizes[ofrIdx]) {
2765  auto val = b.create<arith::ConstantIndexOp>(
2766  getLoc(), cast<IntegerAttr>(attr).getInt());
2767  Value vscale =
2768  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
2769  sizes.push_back(
2770  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
2771  } else {
2772  sizes.push_back(attr);
2773  }
2774  continue;
2775  }
2776  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
2777  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
2778  ++dynamicIdx;
2779  assert((dynamicSizes.empty() ^ params.empty()) &&
2780  "expected either dynamic sizes or parameters");
2781  if (!params.empty()) {
2782  sizes.push_back(b.getIndexAttr(params[index]));
2783  } else {
2784  sizes.push_back(dynamicSizes[index]->getResult(0));
2785  }
2786  }
2787  return sizes;
2788  });
2789  }
2790 
2791  tilingOptions.setInterchange(getInterchange());
2792  FailureOr<scf::SCFTilingResult> maybeTilingResult =
2793  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
2794  if (failed(maybeTilingResult))
 2795  return DiagnosedSilenceableFailure::definiteFailure();
 2796 
2797  rewriter.replaceOp(op, maybeTilingResult->replacements);
2798 
2799  tiled.append(maybeTilingResult->tiledOps);
2800  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
2801  loops[en2.index()].push_back(en2.value());
2802  }
2803 
2804  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
2805  for (const auto &en : llvm::enumerate(loops))
2806  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
2807 
2809 }
2810 
 2811 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
 2812  ValueRange dynamic = getDynamicSizes();
2813  ArrayRef<int64_t> tileSizes = getStaticSizes();
2814  SmallVector<OpFoldResult> results;
2815  results.reserve(tileSizes.size());
2816  unsigned dynamicPos = 0;
2817  Builder builder(getContext());
2818  for (int64_t size : tileSizes) {
2819  if (size == ShapedType::kDynamic) {
2820  results.push_back(dynamic[dynamicPos++]);
2821  } else {
2822  results.push_back(builder.getIndexAttr(size));
2823  }
2824  }
2825  return results;
2826 }
2827 
2828 void transform::TileUsingForOp::getEffects(
 2829  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 2830  consumesHandle(getTarget(), effects);
2831  onlyReadsHandle(getDynamicSizes(), effects);
2832  producesHandle(getTiledLinalgOp(), effects);
2833  producesHandle(getLoops(), effects);
2834  modifiesPayload(effects);
2835 }
2836 
2837 //===----------------------------------------------------------------------===//
2838 // TileUsingForallOp
2839 //===----------------------------------------------------------------------===//
2840 
2841 void transform::TileUsingForallOp::build(OpBuilder &builder,
2842  OperationState &result, Value target,
2843  ArrayRef<int64_t> staticTileSizes,
2845  ArrayAttr mapping) {
2846  return build(builder, result,
2847  /*target=*/target,
2848  /*mixedTileSizes=*/
2849  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2850  /*_=*/TileSizesSpec(),
2851  /*mapping=*/mapping);
2852 }
2853 
2854 void transform::TileUsingForallOp::build(OpBuilder &builder,
2855  OperationState &result, Value target,
2856  ArrayRef<OpFoldResult> mixedTileSizes,
2858  ArrayAttr mapping) {
2859  SmallVector<int64_t> staticTileSizes;
2860  SmallVector<Value> dynamicTileSizes;
2861  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2862  // Call the default builder which sets up the proper operands segment sizes
2863  // attributes for multiple variadic operands. In the absence of this,
2864  // horrible bugs ensue.
2865  MLIRContext *ctx = builder.getContext();
2866  auto operationType = transform::AnyOpType::get(ctx);
2867  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2868  build(builder, result,
2869  /*resultTypes=*/TypeRange{operationType, operationType},
2870  /*target=*/target,
2871  /*num_threads=*/ValueRange{},
2872  /*tile_sizes=*/dynamicTileSizes,
2873  /*packed_num_threads=*/Value(),
2874  /*packed_tile_sizes=*/Value(),
2875  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
2876  /*static_tile_sizes=*/staticTileSizesAttr,
2877  /*mapping=*/mapping);
2878 }
2879 
2880 void transform::TileUsingForallOp::build(OpBuilder &builder,
2881  OperationState &result, Value target,
2882  ArrayRef<int64_t> staticNumThreads,
2884  ArrayAttr mapping) {
2885  return build(builder, result, target,
2886  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
2887  NumThreadsSpec(), mapping);
2888 }
2889 
2890 void transform::TileUsingForallOp::build(OpBuilder &builder,
2891  OperationState &result, Value target,
2892  ArrayRef<OpFoldResult> mixedNumThreads,
2894  ArrayAttr mapping) {
2895  SmallVector<int64_t> staticNumThreads;
2896  SmallVector<Value> dynamicNumThreads;
2897  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
2898  staticNumThreads);
2899  // Call the default builder which sets up the proper operands segment sizes
2900  // attributes for multiple variadic operands. In the absence of this,
2901  // horrible bugs ensue.
2902  MLIRContext *ctx = builder.getContext();
2903  auto operationType = transform::AnyOpType::get(ctx);
2904  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2905  build(builder, result,
2906  /*resultTypes=*/TypeRange{operationType, operationType},
2907  /*target=*/target,
2908  /*num_threads=*/dynamicNumThreads,
2909  /*tile_sizes=*/ValueRange{},
2910  /*packed_num_threads=*/Value(),
2911  /*packed_tile_sizes=*/Value(),
2912  /*static_num_threads=*/staticNumThreadsAttr,
2913  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
2914  /*mapping=*/mapping);
2915 }
2916 
 2917 static DiagnosedSilenceableFailure tileToForallOpImpl(
 2918  RewriterBase &rewriter, transform::TransformState &state,
2919  TransformOpInterface transformOp, Operation *target,
2920  ArrayRef<OpFoldResult> mixedNumThreads,
2921  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
2922  linalg::ForallTilingResult &tilingResult) {
2923  // Transform all targets one by one.
2924  auto tileableOp = dyn_cast<TilingInterface>(target);
2925  if (!tileableOp) {
2927  transformOp.emitSilenceableError()
2928  << "only TilingInterface ops are supported";
2929  diag.attachNote(target->getLoc()) << "target op";
2930  return diag;
2931  }
2932  rewriter.setInsertionPoint(tileableOp);
2933  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
2934  if (!mixedNumThreads.empty()) {
2935  maybeTilingResult =
2936  linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
2937  } else {
2938  maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
2939  rewriter, tileableOp, mixedTileSizes, mapping);
2940  }
2941 
2942  if (failed(maybeTilingResult))
2943  return transformOp.emitDefaultSilenceableFailure(tileableOp);
2944  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
2945 
2946  tilingResult = *maybeTilingResult;
2948 }
2949 
2950 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
2951  transform::TransformRewriter &rewriter,
2952  transform::TransformResults &transformResults,
2953  transform::TransformState &state) {
2954  auto transformOp = cast<TransformOpInterface>(getOperation());
2955 
2956  // Result payload ops.
2957  SmallVector<Operation *> tileOps;
2958  SmallVector<Operation *> tiledOps;
2959 
2960  // Unpack handles.
2961  SmallVector<OpFoldResult> mixedNumThreads;
 2962  DiagnosedSilenceableFailure status =
 2963  getPackedNumThreads()
 2964  ? unpackSingleIndexResultPayloadOperations(
 2965  state, transformOp, mixedNumThreads, getPackedNumThreads())
 2966  : unpackSingleIndexResultPayloadOperations(
 2967  state, transformOp, mixedNumThreads, getMixedNumThreads());
2968  if (!status.succeeded())
2969  return status;
2970  SmallVector<OpFoldResult> mixedTileSizes;
2971  status = getPackedTileSizes()
 2972  ? unpackSingleIndexResultPayloadOperations(
 2973  state, transformOp, mixedTileSizes, getPackedTileSizes())
 2974  : unpackSingleIndexResultPayloadOperations(
 2975  state, transformOp, mixedTileSizes, getMixedTileSizes());
2976  if (!status.succeeded())
2977  return status;
2978 
2979  for (Operation *target : state.getPayloadOps(getTarget())) {
2980  linalg::ForallTilingResult tilingResult;
 2981  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
 2982  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
2983  getMapping(), tilingResult);
2984  if (!diag.succeeded())
2985  return diag;
2986  tileOps.push_back(tilingResult.tileOp);
2987  tiledOps.push_back(tilingResult.tiledOp);
2988  }
2989 
2990  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
2991  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
2992 
2994 }
2995 
2996 void transform::TileUsingForallOp::getEffects(
 2997  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 2998  consumesHandle(getTarget(), effects);
2999  onlyReadsHandle(getTileSizes(), effects);
3000  onlyReadsHandle(getNumThreads(), effects);
3001  onlyReadsHandle(getPackedNumThreads(), effects);
3002  onlyReadsHandle(getPackedTileSizes(), effects);
3003  producesHandle(getResults(), effects);
3004  modifiesPayload(effects);
3005 }
3006 
3007 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3008  Builder b(getContext());
3009  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3010 }
3011 
3012 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3013  Builder b(getContext());
3014  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3015 }
3016 
 3017 LogicalResult TileUsingForallOp::verify() {
 3018  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3019  static_cast<int>(getPackedNumThreads() != Value());
3020  if (numThreadsSpec > 1)
3021  return emitOpError(
3022  "num_threads and packed_num_threads are mutually exclusive");
3023  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3024  static_cast<int>(getPackedTileSizes() != Value());
3025  if (tileSizesSpec > 1)
3026  return emitOpError(
3027  "tile_sizes and packed_tile_sizes are mutually exclusive");
3028  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3029  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3030  "must be specified");
3031  return success();
3032 }
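// A hedged, illustrative transform-IR usage sketch (the assembly below is an
// assumption for illustration, not copied from a test): exactly one of
// (packed_)num_threads or (packed_)tile_sizes is given, as enforced by the
// verifier above, and the op returns handles to the forall op and the tiled op.
//
//   %forall, %tiled = transform.structured.tile_using_forall %matmul
//       num_threads [8, 16] (mapping = [#gpu.thread<y>, #gpu.thread<x>])
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)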
3033 
3034 //===----------------------------------------------------------------------===//
3035 // VectorizeChildrenAndApplyPatternsOp
3036 //===----------------------------------------------------------------------===//
3037 
3038 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3039  OpBuilder &builder, OperationState &result, Value target,
3040  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3041  result.addOperands(target);
3042  if (vectorizePadding) {
3043  result.addAttribute(
3044  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3045  result.name),
3046  builder.getUnitAttr());
3047  }
3048  if (vectorizeExtract) {
3049  result.addAttribute(
3050  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3051  result.name),
3052  builder.getUnitAttr());
3053  }
3054  if (flatten1DDepthwiseConv) {
3055  result.addAttribute(
3056  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3057  result.name),
3058  builder.getUnitAttr());
3059  }
3060  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3061 }
3062 
3063 namespace {
3064 /// This is a helper used only to call vectorize via a pattern inside of
3065 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3066 struct VectorizationPattern : public RewritePattern {
3067  explicit VectorizationPattern(MLIRContext *context,
3068  bool vectorizeExtract = false,
3069  bool flattenConv = false)
3070  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3071  vectorizeNDExtract(vectorizeExtract),
3072  flatten1DDepthwiseConv(flattenConv) {}
3073  LogicalResult matchAndRewrite(Operation *op,
3074  PatternRewriter &rewriter) const override {
3075  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
3076  if (!linalgOp)
3077  return rewriter.notifyMatchFailure(op, "expected Linalg Op");
3078  return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
3079  /*scalableVecDims=*/{}, vectorizeNDExtract,
3080  flatten1DDepthwiseConv);
3081  }
3082 
3083 private:
3084  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3085  /// rank >= 2.
3086  bool vectorizeNDExtract = false;
3087  /// Controls whether to "flatten" the channel dimension when vectorising 1D
3088  /// depthwise convolutions. This should lead to better vectorization for
3089  /// tensors with a low number of channel dimensions.
3090  bool flatten1DDepthwiseConv = false;
3091 };
3092 } // namespace
3093 
3094 DiagnosedSilenceableFailure
3095 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3096  transform::TransformRewriter &rewriter, Operation *target,
3097  transform::ApplyToEachResultList &results,
3098  transform::TransformState &state) {
3099  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3100  auto diag = this->emitOpError("requires isolated-from-above targets");
3101  diag.attachNote(target->getLoc()) << "non-isolated target";
3102  return DiagnosedSilenceableFailure::definiteFailure();
3103  }
3104 
3105  MLIRContext *ctx = getContext();
3106  RewritePatternSet patterns(ctx);
3107  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3108  getFlatten_1dDepthwiseConv());
3109 
3110  if (!getDisableTransferPermutationMapLoweringPatterns())
3111  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3112 
3113  if (!getDisableMultiReductionToContractPatterns())
3114  vector::populateVectorReductionToContractPatterns(patterns);
3115 
3116  vector::populateSinkVectorBroadcastPatterns(patterns);
3117 
3118  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3119  linalg::LinalgCopyVTWForwardingPattern>(ctx,
3120  /*benefit=*/2);
3121  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3122  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3123  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
3124 
3125  patterns.add<CopyVectorizationPattern>(ctx);
3126 
3127  if (getVectorizePadding())
3128  linalg::populatePadOpVectorizationPatterns(patterns);
3129 
3130  TrackingListener listener(state, *this);
3131  GreedyRewriteConfig config;
3132  config.listener = &listener;
3133  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
3134  return emitDefaultDefiniteFailure(target);
3135 
3136  results.push_back(target);
3137  return DiagnosedSilenceableFailure::success();
3138 }
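// A hedged, illustrative transform-IR usage sketch (assumed assembly): the op
// consumes a handle to an isolated-from-above payload op, vectorizes its
// Linalg children with the VectorizationPattern above, then greedily applies
// the cleanup patterns populated in this function.
//
//   %vectorized = transform.structured.vectorize_children_and_apply_patterns
//       %func {vectorize_padding} : (!transform.any_op) -> !transform.any_op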
3139 
3140 //===----------------------------------------------------------------------===//
3141 // VectorizeOp
3142 //===----------------------------------------------------------------------===//
3143 
3144 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3145  transform::TransformRewriter &rewriter,
3146  mlir::transform::TransformResults &transformResults,
3147  mlir::transform::TransformState &state) {
3148  auto targets = state.getPayloadOps(getTarget());
3149  if (std::empty(targets))
3150  return DiagnosedSilenceableFailure::success();
3151  auto transformOp = cast<TransformOpInterface>(getOperation());
3152  SmallVector<int64_t> vectorSizes;
3153  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3154  state, transformOp, getMixedVectorSizes(), vectorSizes);
3155  if (!status.succeeded())
3156  return status;
3157 
3158  // TODO: Check that the correct number of vectorSizes was provided.
3159  for (Operation *target : targets) {
3160  if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
3161  target)) {
3162  return mlir::emitSilenceableFailure(target->getLoc())
3163  << "Unsupported Op, cannot vectorize";
3164  }
3165 
3166  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3167  getScalableSizes(),
3168  getVectorizeNdExtract().has_value()
3169  ? getVectorizeNdExtract().value()
3170  : false))) {
3171  return mlir::emitSilenceableFailure(target->getLoc())
3172  << "Attempted to vectorize, but failed";
3173  }
3174  }
3175 
3176  return DiagnosedSilenceableFailure::success();
3177 }
3178 
3179 void transform::VectorizeOp::getEffects(
3180  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3181  consumesHandle(getTarget(), effects);
3182  onlyReadsHandle(getVectorSizes(), effects);
3183  modifiesPayload(effects);
3184 }
3185 
3186 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3187  OpBuilder b(getContext());
3188  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3189 }
3190 
3191 LogicalResult transform::VectorizeOp::verify() {
3192  if (getStaticVectorSizes().size() != getScalableSizes().size())
3193  return emitOpError("expected same number of vector sizes (")
3194  << getStaticVectorSizes().size() << ") and scalable sizes ("
3195  << getScalableSizes().size() << ")";
3196  return success();
3197 }
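// A hedged, illustrative transform-IR usage sketch (assumed assembly): static
// and scalable vector sizes are given positionally, which is the invariant
// checked by the verifier above; a scalable size would be written as [4].
//
//   transform.structured.vectorize %target vector_sizes [8, 16]
//       : !transform.any_op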
3198 
3199 //===----------------------------------------------------------------------===//
3200 // HoistRedundantVectorTransfersOp
3201 //===----------------------------------------------------------------------===//
3202 
3203 DiagnosedSilenceableFailure
3204 transform::HoistRedundantVectorTransfersOp::applyToOne(
3205  transform::TransformRewriter &rewriter, func::FuncOp target,
3206  transform::ApplyToEachResultList &results,
3207  transform::TransformState &state) {
3208  // WARNING: This hoisting does not model parallelism and is generally
3209  // incorrect when used on distributed loops with memref semantics!
3210  // TODO: obsolete and should be retired.
3211  linalg::hoistRedundantVectorTransfers(target);
3212  results.push_back(target);
3213  return DiagnosedSilenceableFailure::success();
3214 }
3215 
3216 //===----------------------------------------------------------------------===//
3217 // HoistRedundantVectorBroadcastsOp
3218 //===----------------------------------------------------------------------===//
3219 
3220 DiagnosedSilenceableFailure
3221 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3222  transform::TransformRewriter &rewriter, mlir::Operation *target,
3223  transform::ApplyToEachResultList &results,
3224  transform::TransformState &state) {
3225  rewriter.setInsertionPoint(target);
3226  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3227  results.push_back(target);
3228  return DiagnosedSilenceableFailure::success();
3229 }
3230 
3231 //===----------------------------------------------------------------------===//
3232 // ConvertConv2DToImg2ColOp.
3233 //===----------------------------------------------------------------------===//
3234 
3235 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3236  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3237  transform::ApplyToEachResultList &results,
3238  transform::TransformState &state) {
3239  rewriter.setInsertionPoint(target);
3240  auto maybeTransformed =
3241  TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3242  target)
3243  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3244  return rewriteInIm2Col(rewriter, op);
3245  })
3246  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3247  return rewriteInIm2Col(rewriter, op);
3248  })
3249  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3250  return rewriteInIm2Col(rewriter, op);
3251  })
3252  .Case([&](linalg::Conv2DNchwFchwOp op) {
3253  return rewriteInIm2Col(rewriter, op);
3254  })
3255  .Default([&](Operation *op) {
3256  return rewriter.notifyMatchFailure(op, "not supported");
3257  });
3258  if (failed(maybeTransformed))
3259  return emitDefaultSilenceableFailure(target);
3260  // Handle to the operation producing the img2col tensor.
3261  results.push_back(maybeTransformed->first);
3262  // Handle to the operation that replaces the original convolution.
3263  results.push_back(maybeTransformed->second);
3264  return DiagnosedSilenceableFailure::success();
3265 }
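// A hedged, illustrative transform-IR usage sketch (assumed assembly): the two
// results are, in order, the op producing the img2col tensor and the op that
// replaces the original convolution, matching the push_back calls above.
//
//   %img2col, %matmul = transform.structured.convert_conv2d_to_img2col %conv
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)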
3266 
3267 //===----------------------------------------------------------------------===//
3268 // FlattenElementwiseLinalgOp.
3269 //===----------------------------------------------------------------------===//
3270 
3271 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3272  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3273  transform::ApplyToEachResultList &results,
3274  transform::TransformState &state) {
3275  rewriter.setInsertionPoint(target);
3276  if (!isElementwise(target))
3277  return mlir::emitSilenceableFailure(target->getLoc())
3278  << "only elementwise flattening is supported";
3279 
3280  // If rank <= 1, do nothing
3281  if (target.getNumLoops() <= 1) {
3282  results.push_back(target);
3283  return DiagnosedSilenceableFailure::success();
3284  }
3285 
3286  // Attempt to flatten all dims to one.
3287  ReassociationIndices reassociation(target.getNumLoops());
3288  std::iota(reassociation.begin(), reassociation.end(), 0);
3289  auto maybeFlattened =
3290  collapseOpIterationDims(target, reassociation, rewriter);
3291  if (failed(maybeFlattened))
3292  return mlir::emitSilenceableFailure(target->getLoc())
3293  << "attempted to flatten, but failed";
3294  results.push_back(maybeFlattened->collapsedOp);
3295  rewriter.replaceOp(target, maybeFlattened->results);
3296  return DiagnosedSilenceableFailure::success();
3297 }
3298 
3299 //===----------------------------------------------------------------------===//
3300 // TransposeConv2DOp
3301 //===----------------------------------------------------------------------===//
3302 
3303 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3304  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3305  transform::ApplyToEachResultList &results,
3306  transform::TransformState &state) {
3307  rewriter.setInsertionPoint(target);
3308  auto maybeTransformed =
3309  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3310  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3311  return transposeConv2D(rewriter, op);
3312  })
3313  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3314  return transposeConv2D(rewriter, op);
3315  })
3316  .Default([&](Operation *op) {
3317  return rewriter.notifyMatchFailure(op, "not supported");
3318  });
3319  if (failed(maybeTransformed))
3320  return emitDefaultSilenceableFailure(target);
3321  // Handle to the new Conv2D operation with transposed filters
3322  results.push_back(*maybeTransformed);
3323  return DiagnosedSilenceableFailure::success();
3324 }
3325 
3326 //===----------------------------------------------------------------------===//
3327 // TransposeMatmulOp
3328 //===----------------------------------------------------------------------===//
3329 
3330 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3331  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3332  transform::ApplyToEachResultList &results,
3333  transform::TransformState &state) {
3334  rewriter.setInsertionPoint(target);
3335  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3336  auto maybeTransformed =
3337  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3338  .Case([&](linalg::MatmulOp op) {
3339  return transposeMatmul(rewriter, op, transposeLHS);
3340  })
3341  .Case([&](linalg::BatchMatmulOp op) {
3342  return transposeBatchMatmul(rewriter, op, transposeLHS);
3343  })
3344  .Default([&](Operation *op) { return failure(); });
3345  if (failed(maybeTransformed))
3346  return emitSilenceableFailure(target->getLoc()) << "not supported";
3347  // Handle to the new Matmul operation with the transposed input.
3348  results.push_back(*maybeTransformed);
3349  return DiagnosedSilenceableFailure::success();
3350 }
3351 
3352 //===----------------------------------------------------------------------===//
3353 // InsertSliceToCopyOp
3354 //===----------------------------------------------------------------------===//
3355 template <typename OpTy>
3356 static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3357  transform::ApplyToEachResultList &results,
3358  transform::TransformState &state) {
3359  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3360  tensor::ParallelInsertSliceOp>() &&
3361  "wrong op type");
3362 
3363  if (auto copySource =
3364  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3365  results.push_back(copySource);
3366  return DiagnosedSilenceableFailure::success();
3367  }
3368 
3369  // If we are inside an InParallel region, temporarily set the insertion point
3370  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3371  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3372  rewriter.setInsertionPoint(
3373  target->template getParentOfType<scf::InParallelOp>());
3374  }
3375 
3376  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3377  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3378  target.getMixedSizes(), target.getMixedStrides());
3379  Value copied = rewriter
3380  .create<linalg::CopyOp>(target.getLoc(),
3381  target.getSource(), extracted)
3382  .getResult(0);
3383  // Reset the insertion point.
3384  rewriter.setInsertionPoint(target);
3385  rewriter.replaceOpWithNewOp<OpTy>(
3386  target, copied, target.getDest(), target.getMixedOffsets(),
3387  target.getMixedSizes(), target.getMixedStrides());
3388 
3389  results.push_back(copied.getDefiningOp());
3390  return DiagnosedSilenceableFailure::success();
3391 }
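// A hedged, illustrative payload rewrite sketch (simplified IR, types and
// attributes omitted): the (parallel_)insert_slice is rewritten so that the
// data movement becomes an explicit linalg.copy that later transforms (e.g.
// MapCopyToThreadsOp below) can target.
//
//   Before:
//     %r = tensor.insert_slice %source into %dest[0, 0] [4, 4] [1, 1]
//   After:
//     %slice = tensor.extract_slice %dest[0, 0] [4, 4] [1, 1]
//     %copy  = linalg.copy ins(%source) outs(%slice)
//     %r     = tensor.insert_slice %copy into %dest[0, 0] [4, 4] [1, 1]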
3392 
3393 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3394  transform::TransformRewriter &rewriter, Operation *targetOp,
3395  transform::ApplyToEachResultList &results,
3396  transform::TransformState &state) {
3397 
3398  rewriter.setInsertionPoint(targetOp);
3399  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3400  return doit(rewriter, target, results, state);
3401  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3402  return doit(rewriter, target, results, state);
3403 
3404  DiagnosedSilenceableFailure diag =
3405  emitSilenceableError()
3406  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3407  diag.attachNote(targetOp->getLoc()) << "target op";
3408  return diag;
3409 }
3410 
3411 //===----------------------------------------------------------------------===//
3412 // MapCopyToThreadsOp
3413 //===----------------------------------------------------------------------===//
3414 
3415 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3416  transform::TransformRewriter &rewriter, Operation *target,
3417  transform::ApplyToEachResultList &results,
3418  transform::TransformState &state) {
3419  // Check if the op is supported.
3420  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3421  DiagnosedSilenceableFailure diag =
3422  emitSilenceableError()
3423  << "only linalg.copy and tensor.pad target ops are supported";
3424  diag.attachNote(target->getLoc()) << "target op";
3425  return diag;
3426  }
3427  assert(target->getNumResults() == 1 && "expected single result");
3428  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3429  if (!resultShapedType.hasStaticShape()) {
3430  DiagnosedSilenceableFailure diag =
3431  emitSilenceableError()
3432  << "only statically sized ops of rank <= 3 are supported";
3433  diag.attachNote(target->getLoc()) << "target op";
3434  return diag;
3435  }
3436 
3437  // Conservatively set the minimum viable desired bitwidth alignment.
3438  int64_t desiredBitAlignment = getDesiredBitAlignment();
3439  int64_t eltBitwidth =
3440  resultShapedType.getElementType().getIntOrFloatBitWidth();
3441  if (desiredBitAlignment % eltBitwidth != 0) {
3442  desiredBitAlignment = eltBitwidth;
3443  }
3444 
3445  gpu::CopyMappingInfo mapping(
3446  /*ctx=*/getContext(),
3447  /*totalNumThreads=*/getTotalNumThreads(),
3448  /*alignment=*/desiredBitAlignment,
3449  /*sizes=*/resultShapedType.getShape(),
3450  /*favorPredication=*/false,
3451  /*elementalBitwidth=*/
3452  resultShapedType.getElementType().getIntOrFloatBitWidth());
3453  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3454  DiagnosedSilenceableFailure diag =
3455  emitSilenceableError()
3456  << "too few threads to map copy op to threads on the most minor "
3457  "dimension, given alignment and vector size constraints, try "
3458  "smaller tile size of mapping to more threads";
3459  diag.attachNote(target->getLoc()) << "target op";
3460  return diag;
3461  }
3462 
3463  // OpBuilder only used to compute attributes.
3464  OpBuilder b(getContext());
3465  linalg::ForallTilingResult tilingResult;
3466  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3467  /*rewriter=*/rewriter,
3468  /*state=*/state,
3469  /*transformOp=*/*this,
3470  /*target=*/target,
3471  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3472  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3473  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3474  /*tilingResult=*/tilingResult);
3475  if (!diag.succeeded())
3476  return diag;
3477 
3478  results.push_back(tilingResult.tileOp);
3479  results.push_back(tilingResult.tiledOp);
3480  return DiagnosedSilenceableFailure::success();
3481 }
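// A hedged, illustrative transform-IR usage sketch (assumed assembly): the
// copy (or pad) is tiled to an scf.forall whose thread counts and thread
// mapping come from the gpu::CopyMappingInfo computed above.
//
//   %forall, %tiled = transform.structured.gpu.map_copy_to_threads %copy
//       total_num_threads = 32 desired_bit_alignment = 128
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)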
3482 
3483 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3484 
3485 #define GET_OP_CLASSES
3486 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"