1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
41 #include "mlir/Support/LLVM.h"
42 #include "mlir/Support/TypeID.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
49 
50 using namespace mlir;
51 using namespace mlir::linalg;
52 using namespace mlir::transform;
53 
54 #define DEBUG_TYPE "linalg-transforms"
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56 #define DBGSNL() (llvm::dbgs() << "\n")
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
58 
59 /// Attempts to apply the pattern specified as template argument to the given
60 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
61 /// function that returns the "main" result or failure. Returns failure if the
62 /// pattern failed to apply. Extra arguments are forwarded to the pattern
63 /// constructor.
64 template <typename PatternTy, typename... Args>
65 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
66  // Check if the given operation has the type expected by the pattern.
67  using OpTy = typename llvm::function_traits<
68  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69  auto op = dyn_cast<OpTy>(operation);
70  if (!op)
71  return failure();
72 
73  // Apply the pattern directly to the op.
74  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
75  // We want to discourage direct use of PatternRewriter in APIs but in this
76  // very specific case, an IRRewriter is not enough.
77  struct TrivialPatternRewriter : public PatternRewriter {
78  public:
79  explicit TrivialPatternRewriter(MLIRContext *context)
80  : PatternRewriter(context) {}
81  };
82  TrivialPatternRewriter rewriter(operation->getContext());
83  rewriter.setInsertionPoint(operation);
84  auto result = pattern.returningMatchAndRewrite(op, rewriter);
85  if (failed(result))
86  return failure();
87  return cast<LinalgOp>(result->getOperation());
88 }
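// Illustrative usage sketch (not part of the upstream file): `tryApply` can be
// instantiated with any pattern that exposes `returningMatchAndRewrite`, e.g.
// one of the downscaling patterns used by DecomposeOp further below. Here `op`
// is an assumed `Operation *` pointing at a linalg.conv_2d_nhwc_hwcf payload
// op.
//
//   FailureOr<LinalgOp> conv1d =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(op);
//   if (succeeded(conv1d)) {
//     // *conv1d is the rewritten 1-D convolution.
//   }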
89 
90 /// Assuming that `ofr` is an index attr or a param of index type
91 /// or a transform dialect handle mapped to exactly one op
92 /// with one index result, return that value.
93 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
94  transform::TransformState &state, TransformOpInterface transformOp,
95  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
96  for (OpFoldResult ofr : ofrs) {
97  if (auto attr = dyn_cast<Attribute>(ofr)) {
98  if (!isa<IntegerAttr>(attr))
99  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
100  result.push_back(ofr);
101  continue;
102  }
103 
104  Value transformValue = cast<Value>(ofr);
105  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
106  ArrayRef<Attribute> params = state.getParams(transformValue);
107  if (params.size() != 1)
108  return transformOp.emitDefiniteFailure()
109  << "requires exactly one parameter associated";
110  result.push_back(params[0]);
111  continue;
112  }
113 
114  auto payloadOps = state.getPayloadOps(transformValue);
115  if (!llvm::hasSingleElement(payloadOps)) {
116  DiagnosedSilenceableFailure diag =
117  transformOp.emitSilenceableError()
118  << "handle must be mapped to exactly one payload op";
119  diag.attachNote(transformValue.getLoc())
120  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
121  return diag;
122  }
123 
124  Operation *op = *payloadOps.begin();
125  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
126  DiagnosedSilenceableFailure diag =
127  transformOp.emitSilenceableError()
128  << "payload op must have exactly 1 index result";
129  diag.attachNote(op->getLoc())
130  << "has " << op->getNumResults() << " results";
131  return diag;
132  }
133  result.push_back(op->getResult(0));
134  }
135 
136  return DiagnosedSilenceableFailure::success();
137 }
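// For example (illustrative, not taken from the upstream file): an entry of
// `ofrs` may be a plain index attribute such as one built with
// `rewriter.getIndexAttr(4)`, a !transform.param value carrying exactly one
// IntegerAttr, or a handle mapped to a single payload op with one index
// result (e.g. an `arith.constant 4 : index`); each case contributes exactly
// one OpFoldResult to `result`.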
138 
139 // Given a list of params that are index attrs or a list of OpFoldResults
140 // that are either index attrs or op handles, return a list of OpFoldResults
141 // of index attrs or a list of OpFoldResults where all op handles are
142 // replaced with the first (and only) OpResult of that payload op.
143 // (There must be exactly one parameter associated with the AnyParamType or
144 // one mapped payload op which must have exactly one index result.)
145 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
146  transform::TransformState &state, TransformOpInterface transformOp,
147  SmallVector<OpFoldResult> &result, Value packedHandle) {
148  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
149  ArrayRef<Attribute> params = state.getParams(packedHandle);
150  for (auto param : params) {
151  if (!isa<IntegerAttr>(param))
152  return transformOp.emitDefiniteFailure()
153  << "expected the parameter to be associated with an integer "
154  "attribute";
155  result.push_back(param);
156  }
157  return DiagnosedSilenceableFailure::success();
158  }
159 
160  for (Operation *op : state.getPayloadOps(packedHandle)) {
161  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
162  DiagnosedSilenceableFailure diag =
163  transformOp.emitSilenceableError()
164  << "payload op must have exactly 1 index result";
165  diag.attachNote(op->getLoc())
166  << "has " << op->getNumResults() << " results";
167  return diag;
168  }
169  result.push_back(op->getResult(0));
170  }
171 
172  return DiagnosedSilenceableFailure::success();
173 }
174 
175 /// When possible, converts each `OpFoldResult` in `mixedResults` to
176 /// an integer if the value can be statically inferred. If a result
177 /// is a `Value`, then it must be either a `ParamType` or a handle
178 /// to a constant-like op.
179 static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
180  TransformState &state, TransformOpInterface &transformOp,
181  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
182  for (OpFoldResult paramOrHandle : mixedResults) {
183  if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
184  reified.push_back(cast<IntegerAttr>(attr).getInt());
185  continue;
186  } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
187  ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
188  if (params.size() != 1)
189  return transformOp.emitSilenceableError() << "expected a single param";
190  reified.push_back(
191  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
192  continue;
193  }
194 
195  Value handle = cast<Value>(paramOrHandle);
196  if (!isa<TransformHandleTypeInterface>(handle.getType()))
197  return transformOp.emitSilenceableError() << "unexpected value handle";
198  auto payload = state.getPayloadOps(handle);
199  if (!llvm::hasSingleElement(payload))
200  return transformOp.emitSilenceableError()
201  << "requires param or handle that is mapped to 1 payload op";
202 
203  Operation *paramOrHandlePayloadOp = *payload.begin();
204  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
205  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
206  return transformOp.emitSilenceableError()
207  << "requires param or handle to be result of op with 1 index "
208  "result";
209  }
210 
211  IntegerAttr attr;
212  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
213  return transformOp.emitSilenceableError()
214  << "requires param or handle to be the result of a constant like "
215  "op";
216 
217  reified.push_back(attr.getInt());
218  }
219  return DiagnosedSilenceableFailure::success();
220 }
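// Illustrative usage sketch (assumption, not taken from this file): a
// transform op implementation could use the reification helper above to turn
// a mixed list of index attributes, !transform.param values and
// constant-producing handles into plain integers before configuring a
// transformation. `getMixedSizes()` is a hypothetical accessor on such an op.
//
//   SmallVector<int64_t> concreteSizes;
//   DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
//       state, transformOp, getMixedSizes(), concreteSizes);
//   if (!status.succeeded())
//     return status;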
221 
222 //===----------------------------------------------------------------------===//
223 // Apply...PatternsOp
224 //===----------------------------------------------------------------------===//
225 
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
229 }
230 
231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
234 }
235 
236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
239 }
240 
241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
245 }
246 
247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
250  options.rankReductionStrategy =
253 }
254 
255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
258 }
259 
260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
263 }
264 
265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
268 }
269 
270 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
273 }
274 
275 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // BufferizeToAllocationOp
282 //===----------------------------------------------------------------------===//
283 
284 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
285  OperationState &result,
286  Value target,
287  Attribute memorySpace) {
288  SmallVector<Type> resultTypes;
289  resultTypes.push_back(b.getType<transform::AnyValueType>());
290  resultTypes.push_back(b.getType<transform::AnyOpType>());
291  return build(b, result,
292  /*resultTypes=*/resultTypes,
293  /*target=*/target,
294  /*memorySpace=*/memorySpace);
295 }
296 
297 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
298  OperationState &result,
299  Value target,
300  int64_t memorySpace) {
301  SmallVector<Type> resultTypes;
302  resultTypes.push_back(b.getType<transform::AnyValueType>());
303  resultTypes.push_back(b.getType<transform::AnyOpType>());
304  return build(b, result,
305  /*resultTypes=*/resultTypes,
306  /*target=*/target,
307  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
308 }
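// Illustrative usage sketch (not part of the upstream file): constructing the
// op from C++ with the integer-memory-space builder above. `b`, `loc` and
// `targetHandle` are assumed to be an OpBuilder, a Location and a
// !transform.any_op value already in scope.
//
//   auto bufferize = b.create<transform::BufferizeToAllocationOp>(
//       loc, /*target=*/targetHandle, /*memorySpace=*/3);
//   Value allocatedBuffer = bufferize.getAllocatedBuffer();
//   Value newOpsHandle = bufferize.getNewOps();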
309 
310 namespace {
311 class NewOpsListener : public RewriterBase::ForwardingListener {
312 public:
314 
315  SmallVector<Operation *> getNewOps() const {
316  return SmallVector<Operation *>(newOps.begin(), newOps.end());
317  }
318 
319 private:
320  void notifyOperationInserted(Operation *op,
321  OpBuilder::InsertPoint previous) override {
322  ForwardingListener::notifyOperationInserted(op, previous);
323  // We only care about newly created ops.
324  if (previous.isSet())
325  return;
326  auto inserted = newOps.insert(op);
327  (void)inserted;
328  assert(inserted.second && "expected newly created op");
329  }
330 
331  void notifyOperationErased(Operation *op) override {
332  ForwardingListener::notifyOperationErased(op);
333  op->walk([&](Operation *op) { newOps.erase(op); });
334  }
335 
336  DenseSet<Operation *> newOps;
337 };
338 } // namespace
339 
340 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
343  // Attach listener to keep track of newly created ops.
344  OpBuilder::Listener *previousListener = rewriter.getListener();
345  auto resetListener =
346  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
347  NewOpsListener newOpsListener(previousListener);
348  rewriter.setListener(&newOpsListener);
349 
351  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
354  } else if (getMemcpyOp() == "memref.copy") {
355  options.memcpyOp =
357  } else if (getMemcpyOp() == "linalg.copy") {
358  options.memcpyOp =
360  } else {
361  llvm_unreachable("invalid memcpy op");
362  }
363  if (getAllocOp() == "memref.alloc") {
364  options.allocOp =
366  } else if (getAllocOp() == "memref.alloca") {
367  options.allocOp =
369  } else {
370  llvm_unreachable("invalid alloc op");
371  }
372  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
373  options.emitDealloc = getEmitDealloc();
374 
375  // Bufferize ops.
376  Attribute memorySpace =
377  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
378  SmallVector<Value> allocatedBuffers;
379  for (Operation *op : state.getPayloadOps(getTarget())) {
380  Value buffer =
381  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
382  if (!buffer) {
383  DiagnosedSilenceableFailure diag = emitSilenceableError()
384  << "failed to bufferize operation";
385  diag.attachNote(op->getLoc()) << "target payload op";
386  return diag;
387  }
388  allocatedBuffers.push_back(buffer);
389  }
390 
391  // Set results.
392  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
393  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
395 }
396 
397 void transform::BufferizeToAllocationOp::getEffects(
399  if (getBufferizeDestinationOnly()) {
400  // The destination is replaced with a newly allocated buffer, but the op
401  // itself remains in place.
402  onlyReadsHandle(getTargetMutable(), effects);
403  } else {
404  consumesHandle(getTargetMutable(), effects);
405  }
406  producesHandle(getOperation()->getOpResults(), effects);
407  modifiesPayload(effects);
408 }
409 
411  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
412  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
413  return emitOpError() << "unsupported memcpy op";
414  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
415  return emitOpError() << "unsupported alloc op";
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // DecomposeOp
421 //===----------------------------------------------------------------------===//
422 
424 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
425  LinalgOp target,
427  transform::TransformState &state) {
428 #define DOWNSCALE(trans) \
429  { \
430  FailureOr<LinalgOp> res = tryApply<trans>(target); \
431  if (succeeded(res)) { \
432  results.push_back(*res); \
433  return DiagnosedSilenceableFailure::success(); \
434  } \
435  }
436 
437 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
438 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
439 
440  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
441  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
442  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
443  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
444  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
445  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
446  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
447  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
448  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
451 #undef DOWNSCALE_NORMAL
452 #undef DOWNSCALE_CALL
453 #undef DOWNSCALE
454  return emitDefaultSilenceableFailure(target);
455 }
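// For reference (illustrative expansion, not additional code): a line such as
// DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp) above expands to roughly
// the following block, returning early from applyToOne as soon as one pattern
// applies:
//
//   {
//     FailureOr<LinalgOp> res =
//         tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                        Conv1DNwcWcfOp>>(
//             target);
//     if (succeeded(res)) {
//       results.push_back(*res);
//       return DiagnosedSilenceableFailure::success();
//     }
//   }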
456 
457 //===----------------------------------------------------------------------===//
458 // DecomposeInterfaceOp
459 //===----------------------------------------------------------------------===//
460 
461 // Decompose the target operation if it implements the AggregatedOpInterface.
462 // Push the decomposed operations (the ones that replace the values produced by
463 // \p target) into `results`.
464 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
465  transform::TransformRewriter &rewriter, Operation *target,
467  transform::TransformState &state) {
468  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
469  if (!decomposableOp) {
470  failed(rewriter.notifyMatchFailure(target,
471  "payload is not a decomposable op"));
472  return emitDefaultSilenceableFailure(target);
473  }
474 
475  FailureOr<SmallVector<Value>> maybeNewResults =
476  decomposableOp.decomposeOperation(rewriter);
477  if (failed(maybeNewResults))
478  return emitDefaultSilenceableFailure(target);
479 
480  rewriter.replaceOp(decomposableOp, *maybeNewResults);
481  for (Value val : *maybeNewResults) {
482  Operation *definition = val.getDefiningOp();
483  if (definition)
484  results.push_back(definition);
485  }
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // EliminateLinalgOpAnchoredEmptyTensorsOp
491 //===----------------------------------------------------------------------===//
492 
493 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
495  onlyReadsHandle(getTargetMutable(), effects);
496  modifiesPayload(effects);
497 }
498 
500 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
501  transform::TransformRewriter &rewriter, TransformResults &transformResults,
502  TransformState &state) {
504  options.allowReturnAllocsFromLoops = true;
505 
506  for (Operation *target : state.getPayloadOps(getTarget())) {
508  if (failed(analyzeOp(target, state)))
509  return mlir::emitSilenceableFailure(target->getLoc())
510  << "failed to analyze op";
512  rewriter, target, state)))
513  return mlir::emitSilenceableFailure(target->getLoc())
514  << "failed to eliminate LinalgOp anchored tensor.empty ops";
515  }
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // FuseOp
521 //===----------------------------------------------------------------------===//
522 
523 /// Apply a tiling transformation to all payload ops and store both the
524 /// tiled operation as well as the created tile loops.
525 template <typename Range>
526 static LogicalResult applyTilingToAll(
527  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
528  unsigned numLoops, transform::TransformResults &transformResults,
529  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
530  applyFn) {
531  SmallVector<Operation *> tiledLinalgOps;
532  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
533 
534  for (Operation *target : payloadOps) {
535  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
536  if (!tilingInterfaceOp)
537  return transformOp->emitError("only TilingInterface ops are supported");
538 
539  rewriter.setInsertionPoint(target);
540  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
541  applyFn(tilingInterfaceOp);
542  if (failed(tiledResults))
543  return failure();
544 
545  // Perform the replacement of tiled and fused values.
546  SmallVector<Operation *> opsToReplace{target};
547  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
548  for (Operation *toReplace : opsToReplace) {
549  for (OpResult res : toReplace->getResults())
550  if (auto replacement = tiledResults->replacements.lookup(res))
551  rewriter.replaceAllUsesWith(res, replacement);
552  if (toReplace->use_empty()) {
553  rewriter.eraseOp(toReplace);
554  }
555  }
556 
557  // Report back the relevant handles to the transform op.
558  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
559  assert(tiledResults->loops.size() == numLoops &&
560  "Mismatched number of loops, tile and fuse transform should have "
561  "failed");
562  for (unsigned int i = 0; i < numLoops; ++i)
563  loopOps[i].push_back(tiledResults->loops[i]);
564  }
565 
566  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
567  for (unsigned int i = 0; i < numLoops; ++i)
568  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
569 
570  return success();
571 }
572 
574 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
575  mlir::transform::TransformResults &transformResults,
577  SmallVector<int64_t> tileSizes =
578  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
579  SmallVector<int64_t> tileInterchange =
580  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
581 
582  scf::SCFTilingOptions tilingOptions;
583  tilingOptions.interchangeVector = tileInterchange;
584  SmallVector<OpFoldResult> tileSizesOfr =
585  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
586  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
587  scf::SCFTileAndFuseOptions tileAndFuseOptions;
588  tileAndFuseOptions.tilingOptions = tilingOptions;
589 
590  if (getApplyCleanup()) {
591  MLIRContext *context = rewriter.getContext();
592  RewritePatternSet patterns(context);
593  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
596  tileAndFuseOptions.cleanupPatterns = std::move(patterns);
597  }
598 
599  LogicalResult result = applyTilingToAll(
600  rewriter, getOperation(), state.getPayloadOps(getTarget()),
601  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
602  [&](TilingInterface tilingInterfaceOp)
603  -> FailureOr<scf::SCFTileAndFuseResult> {
604  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
605  tileAndFuseOptions);
606  });
607  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
608  : DiagnosedSilenceableFailure::success();
609 }
610 
611 LogicalResult transform::FuseOp::verify() {
612  SmallVector<int64_t> permutation =
613  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
614  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
615  if (!std::is_permutation(sequence.begin(), sequence.end(),
616  permutation.begin(), permutation.end())) {
617  return emitOpError() << "expects interchange to be a permutation, found "
618  << getTileInterchange();
619  }
620 
621  SmallVector<int64_t> sizes =
622  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
623  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
624  if (numExpectedLoops != getNumResults() - 1)
625  return emitOpError() << "expects " << numExpectedLoops << " loop results";
626 
627  return success();
628 }
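// Worked example (illustrative): with tile_sizes = [4, 0, 8], the zero entry
// requests no tiling of the second loop, so numExpectedLoops is 2 and the op
// must declare exactly three results: the fused-op handle plus one handle per
// generated loop.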
629 
630 //===----------------------------------------------------------------------===//
631 // FuseIntoContainingOp
632 //===----------------------------------------------------------------------===//
633 
634 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
635  OperationState &result,
636  Value producerOp,
637  Value containingOp) {
638  result.addOperands({producerOp, containingOp});
639  auto resultType = transform::AnyOpType::get(builder.getContext());
640  result.addTypes({resultType, resultType});
641 }
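// Illustrative usage sketch (not part of the upstream file): creating the op
// from C++ with the builder above. `b`, `loc`, `producerHandle` and
// `forallHandle` are assumed to be in scope; both results are of type
// !transform.any_op.
//
//   auto fuse = b.create<transform::FuseIntoContainingOp>(
//       loc, /*producerOp=*/producerHandle, /*containingOp=*/forallHandle);
//   Value fusedOps = fuse.getFusedOp();
//   Value newContaining = fuse.getNewContainingOp();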
642 
643 /// Add new operands to the forall op for users of the producerOp
644 /// that are dominated by the containing scf.forall op.
645 static Operation *replaceForAllWithNewSignature(
646  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
647  Operation *containingOp, TilingResult &tileAndFuseResult,
648  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
649  SmallVector<OpFoldResult> &sizes) {
650 
651  // Count number of users not including the containing op
652  SetVector<Operation *> dominatedUsers;
653  DominanceInfo domInfo(containingOp);
654  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
655  if (!containingOp->isAncestor(user) &&
656  (domInfo.dominates(containingOp, user))) {
657  dominatedUsers.insert(user);
658  }
659  }
660  if (dominatedUsers.empty())
661  return nullptr;
662 
663  // Create new scf.forall op
664  auto forallOp = cast<scf::ForallOp>(containingOp);
665  OpBuilder::InsertionGuard g(rewriter);
666  rewriter.setInsertionPoint(forallOp);
667 
668  // Get new output
669  Location loc = forallOp.getLoc();
670  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
671  if (!genericOp)
672  return nullptr;
673  SmallVector<Value> outputs = genericOp.getOutputs();
674  SmallVector<Value> newOuts(forallOp.getOutputs());
675  newOuts.push_back(outputs[resultNumber]);
676 
677  // Create new scf.forall op
678  auto newforallOp = rewriter.create<scf::ForallOp>(
679  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
680  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
681  rewriter.eraseBlock(newforallOp.getBody());
682  newforallOp.getRegion().takeBody(forallOp.getRegion());
683 
684  // Add an additional block argument for the new value being returned
685  // and replace all uses of the new output with the corresponding bbArg
686  // inside the scf.forall to enable fusion into this new scf.forall.
687  newforallOp.getBody()->addArgument(newOuts.back().getType(),
688  newOuts.back().getLoc());
689  auto bbArgs = newforallOp.getBody()->getArguments();
690  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
691  [&](OpOperand &use) {
692  Operation *op = use.getOwner();
693  return newforallOp->isProperAncestor(op);
694  });
695 
696  // Fix terminator
697  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
698  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
699  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
700  Operation *firstYieldOp = yieldingOps.front();
701  rewriter.setInsertionPoint(firstYieldOp);
702  Value src = tileAndFuseResult.tiledValues[0];
703  Value dst = newforallOp.getRegionIterArgs().back();
704  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
705  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
706  dst, offsets, sizes, strides);
707 
708  for (auto result : llvm::enumerate(forallOp.getResults())) {
709  rewriter.replaceAllUsesWith(result.value(),
710  newforallOp->getResult(result.index()));
711  }
712  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
713  newforallOp->getResults().back(),
714  [&](OpOperand &use) {
715  Operation *user = use.getOwner();
716  return dominatedUsers.contains(user);
717  });
718  return newforallOp;
719 }
720 
721 /// Find the first "extract" user of `producerOp` and tile it right before its
722 /// use. The tiled op is fused under the `containingOp`.
723 /// Return this fused op on success or nullptr if anything fails.
724 /// If tiled op has uses that are dominated by `containingOp`, return
725 /// a new `containingOp` with results of the fused op appended to
726 /// results of the `containingOp` or nullptr if there are no dominated uses.
727 static std::tuple<SmallVector<Operation *>, Operation *>
728 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
729  Operation *producerOp, Operation *containingOp) {
730  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
731  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
732  if (!tileableProducer) {
733  diag.attachNote(producerOp->getLoc())
734  << "producer is not a TileableInterface: " << *producerOp;
735  return {};
736  }
737 
738  // Search the producer slices accessed within the containing operation.
739  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
740  // evolve into an interface.
741  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
742  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
743  return sliceOp && containingOp->isProperAncestor(sliceOp);
744  });
745 
746  // Find a fusion opportunity.
747  if (it == tileableProducer->getUsers().end()) {
748  diag.attachNote(tileableProducer->getLoc())
749  << "could not find fusion opportunity for: " << *tileableProducer;
750  return {};
751  }
752  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
753 
754  // Try to fuse the producer in-place.
755  OpBuilder::InsertionGuard guard(rewriter);
756  rewriter.setInsertionPoint(sliceOpToTile);
757 
758  // Tile the producer.
759  int64_t resultNumber =
760  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
761  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
762 
763  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
764  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
765 
766  FailureOr<TilingResult> tileAndFuseResult =
767  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
768  sizes);
769 
770  if (failed(tileAndFuseResult)) {
771  diag.attachNote(tileableProducer->getLoc())
772  << "failed to tile producer op: " << *tileableProducer;
773  return {};
774  }
775 
776 #ifndef NDEBUG
777  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
778  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
779  }
780 #endif
781 
782  // Replace the extract op.
783  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
784  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
785  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
786  if (failed(maybeRankReduced)) {
787  diag.attachNote(producerOp->getLoc())
788  << "shape types don't match (missing canonicalization?):\nTiledOp: "
789  << tileAndFuseResult->tiledValues[0]
790  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
791  return {};
792  }
793  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
794 
795  // Add new outputs to containing op, if required
796  Operation *newContainingOp = replaceForAllWithNewSignature(
797  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
798  resultNumber, offsets, sizes);
799 
800  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
801 }
802 
803 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
804 /// it is exactly the `containingOp`, otherwise bail.
805 /// Then, find the first "extract" user of the tied block argument and tile it
806 /// right before its "extract" use. The tiled op is fused under the
807 /// `containingOp`.
808 /// Return this fused op on success or nullptr if anything fails.
811  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
812  Operation *containingOp) {
813  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
814 
815  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
816  if (!tileableProducer) {
817  diag.attachNote(producerOp->getLoc())
818  << "producer is not a TileableInterface: " << *producerOp;
819  return {};
820  }
821 
822  // Search the first use by a "scf::ForallOp" user.
823  scf::ForallOp forallOp;
824  auto itProducerUses =
825  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
826  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
827  return forallOp;
828  });
829  // If it's not from the containing op, return.
830  if (!forallOp || forallOp != containingOp) {
831  diag.attachNote(tileableProducer->getLoc())
832  << "could not find a use by the containing op: " << *tileableProducer;
833  return {};
834  }
835 
836  // Search the producer slices accessed within the containing
837  // operation.
838  // TODO: Generalize to more extract/insert/parallel_insert triples.
839  // Maybe evolve into an interface.
840  OpOperand *pUse = &(*itProducerUses);
841  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
842 
843  // Search the producer slices accessed within the containing operation.
844  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
845  // evolve into an interface.
846  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
847  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
848  return sliceOp && containingOp->isProperAncestor(sliceOp);
849  });
850 
851  // Find a fusion opportunity.
852  if (itBBArgUsers == bbArg.getUsers().end()) {
853  diag.attachNote(containingOp->getLoc())
854  << "could not find fusion opportunity for bbArg: " << bbArg;
855  return {};
856  }
857  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
858 
859  // Try to fuse the producer in-place.
860  OpBuilder::InsertionGuard guard(rewriter);
861  rewriter.setInsertionPoint(sliceOpToTile);
862 
863  // Replace the use in the tileableProducer before tiling: clone, replace and
864  // then tile.
865  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
866  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
867 
868  // Gather destination tensors.
869  SmallVector<Value> destinationTensors;
871  rewriter, tileableProducer->getLoc(), tileableProducer,
872  destinationTensors))) {
873  diag.attachNote(tileableProducer->getLoc())
874  << "failed to get destination tensors for: " << *tileableProducer;
875  return {};
876  }
877 
878  IRMapping bvm;
879  bvm.map(destinationTensors[resultNumber], bbArg);
880  auto tileableProducerClone =
881  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
882  auto scopeGuard =
883  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
884 
885  // Tile the producer.
886  FailureOr<TilingResult> tileAndFuseResult =
887  tileableProducerClone.generateResultTileValue(
888  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
889  sliceOpToTile.getMixedSizes());
890  if (failed(tileAndFuseResult)) {
891  diag.attachNote(tileableProducer->getLoc())
892  << "failed to tile producer op: " << *tileableProducer;
893  return {};
894  }
895 
896  // Replace the extract op.
897  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
898  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
899  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
900  assert(succeeded(maybeRankReduced) && "unexpected shape");
901  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
902 
903  // Replace the use in containingOp.
904  rewriter.modifyOpInPlace(containingOp, [&]() {
905  containingOp->setOperand(pUse->getOperandNumber(),
906  destinationTensors.front());
907  });
908 
909  return tileAndFuseResult->tiledOps;
910 }
911 
912 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
913  Operation *producerOp,
914  Operation *containingOp) {
915  LLVM_DEBUG(DBGS() << "Try to fuse a use by cloning\n");
916 
917  // Gather all uses inside the containing op.
919  for (OpResult result : producerOp->getOpResults()) {
920  for (OpOperand &use : result.getUses()) {
921  if (containingOp->isProperAncestor(use.getOwner())) {
922  uses.push_back(&use);
923  continue;
924  }
925  // Cannot clone and fuse if the use is by the containing op itself: fail
926  // immediately.
927  if (containingOp == use.getOwner()) {
928  diag.attachNote(producerOp->getLoc())
929  << "producer op use by containing op cannot be fused by cloning";
930  return nullptr;
931  }
932  }
933  }
934 
935  // Check for a non-empty list of fusion opportunities.
936  if (uses.empty()) {
937  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
938  return nullptr;
939  }
940 
941  // Clone and fuse inside the containing op.
942  Operation *fusedOp = nullptr;
943  OpOperand *use = uses.front();
944  // Parallel insert slice is not a valid clone destination.
945  // TODO: Generalize to other type of ops.
946  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
947  "Parallel insert slice is not a valid clone destination");
948  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
949  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
950 
951  OpBuilder::InsertionGuard guard(rewriter);
952  rewriter.setInsertionPoint(use->getOwner());
953  fusedOp = rewriter.clone(*producerOp);
954  rewriter.modifyOpInPlace(
955  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
956 
957  return fusedOp;
958 }
959 
960 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
961  // Allow repeated handles since we are fusing everything anyway.
962  return true;
963 }
964 
966 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
968  transform::TransformState &state) {
969  SmallVector<Operation *> fusedOps;
970  auto producerOps = state.getPayloadOps(getProducerOp());
971  auto containingOps = state.getPayloadOps(getContainingOp());
972  if (!llvm::hasSingleElement(containingOps)) {
973  return emitDefiniteFailure()
974  << "requires exactly one containing_op handle (got "
975  << llvm::range_size(containingOps) << ")";
976  }
977  Operation *containingOp = *containingOps.begin();
978 
979  // If nothing to fuse, propagate success.
980  if (std::empty(producerOps)) {
981  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
982  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
984  }
985 
986  // Helper function to find the next producer that should be fused. Take any
987  // producer that has a use inside the containing op.
988  SetVector<Operation *> remainingProducers(producerOps.begin(),
989  producerOps.end());
990  auto getNextProducer = [&]() -> FailureOr<Operation *> {
991  for (const auto &it : enumerate(remainingProducers)) {
992  Operation *producerOp = it.value();
993  // The containing op may be a user of producerOp: use isAncestor.
994  int64_t numUsesInContainingOp =
995  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
996  return containingOp->isAncestor(op);
997  });
998  // TODO: When resolving the TODO below (no duplicate ops), take an op
999  // that has no use among the remaining producers. This is a topological
1000  // sorting.
1001  if (numUsesInContainingOp > 0) {
1002  if (numUsesInContainingOp == 1)
1003  remainingProducers.erase(remainingProducers.begin() + it.index());
1004  return producerOp;
1005  }
1006  }
1007  return failure();
1008  };
1009 
1010  while (!remainingProducers.empty()) {
1011  auto nextProducer = getNextProducer();
1012  if (failed(nextProducer)) {
1013  auto diag = mlir::emitSilenceableFailure(getLoc())
1014  << "could not find next producer to fuse into container";
1015  diag.attachNote(containingOp->getLoc()) << "containing op";
1016  return diag;
1017  }
1018 
1019  Operation *producerOp = *nextProducer;
1020 
1021  // Default diagnostic, to be complemented with more failure information.
1023  diag << "could not fuse " << *producerOp << " into " << *containingOp;
1024 
1025  // TODO: If there are multiple uses of the producer in the containing op,
1026  // we currently tile/clone the op multiple times (once per use). In some
1027  // cases, we can tile/clone once and reuse the value for each use.
1028  // Furthermore, producers should then be traversed according to a
1029  // topological sorting.
1030  auto [tiledOps, newContainingOp] =
1031  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1032  if (!tiledOps.empty()) {
1033  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
1034  fusedOps.append(tiledOps);
1035  if (newContainingOp) {
1036  // Update handles associated with the containing op so we don't need to
1037  // invalidate them. This is a hack to support better composability
1038  // between tiling and fusion while a proper mechanism is being
1039  // investigated.
1040  //
1041  // DO NOT replicate this elsewhere unless you understand what you are
1042  // doing.
1043  LogicalResult replacementStatus =
1044  rewriter.notifyPayloadOperationReplaced(containingOp,
1045  newContainingOp);
1046  (void)replacementStatus;
1047  assert(succeeded(replacementStatus) &&
1048  "unable to update transform state mapping");
1049  rewriter.eraseOp(containingOp);
1050  containingOp = newContainingOp;
1051  }
1052  continue;
1053  }
1054 
1055  SmallVector<Operation *> tiledContainingOpOperand =
1057  rewriter, diag, producerOp, containingOp);
1058  if (!tiledContainingOpOperand.empty()) {
1059  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
1060  << *containingOp);
1061  fusedOps.append(tiledContainingOpOperand);
1062  continue;
1063  }
1064 
1065  Operation *cloned =
1066  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1067  if (cloned) {
1068  LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
1069  fusedOps.push_back(cloned);
1070  continue;
1071  }
1073  }
1074 
1075  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1076  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1078 }
1079 
1080 void transform::FuseIntoContainingOp::getEffects(
1082  consumesHandle(getProducerOpMutable(), effects);
1083  onlyReadsHandle(getContainingOpMutable(), effects);
1084  producesHandle(getOperation()->getOpResults(), effects);
1085  modifiesPayload(effects);
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // GeneralizeOp
1090 //===----------------------------------------------------------------------===//
1091 
1093 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1094  LinalgOp target,
1096  transform::TransformState &state) {
1097  // Exit early if no transformation is needed.
1098  if (isa<GenericOp>(target)) {
1099  results.push_back(target);
1101  }
1102  rewriter.setInsertionPoint(target);
1103  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1104  if (succeeded(generic)) {
1105  results.push_back(generic->getOperation());
1107  }
1108  return emitDefaultSilenceableFailure(target);
1109 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // SpecializeOp
1113 //===----------------------------------------------------------------------===//
1114 
1116 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1117  LinalgOp target,
1119  transform::TransformState &state) {
1120  // Exit early if the operation is not a generic.
1121  if (!isa<GenericOp>(target)) {
1122  results.push_back(target);
1124  }
1125  rewriter.setInsertionPoint(target);
1126  FailureOr<LinalgOp> named =
1127  specializeGenericOp(rewriter, cast<GenericOp>(target));
1128  if (succeeded(named)) {
1129  results.push_back(named->getOperation());
1131  }
1132  return emitDefaultSilenceableFailure(target);
1133 }
1134 
1135 //===----------------------------------------------------------------------===//
1136 // InterchangeOp
1137 //===----------------------------------------------------------------------===//
1138 
1140 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1141  GenericOp target,
1143  transform::TransformState &state) {
1144  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1145  // Exit early if no transformation is needed.
1146  if (interchangeVector.empty()) {
1147  results.push_back(target);
1149  }
1150 
1151  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1152  if (interchangeVector.size() != numLoops) {
1153  return emitSilenceableError()
1154  << getIteratorInterchangeAttrName() << " has length ("
1155  << interchangeVector.size()
1156  << ") different from the number of loops in the target operation ("
1157  << numLoops << ")";
1158  }
1159  FailureOr<GenericOp> res = interchangeGenericOp(
1160  rewriter, target, SmallVector<unsigned>(interchangeVector));
1161  if (failed(res))
1162  return emitDefiniteFailure() << "failed to apply";
1163  results.push_back(res->getOperation());
1165 }
1166 
1167 LogicalResult transform::InterchangeOp::verify() {
1168  ArrayRef<int64_t> permutation = getIteratorInterchange();
1169  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1170  if (!std::is_permutation(sequence.begin(), sequence.end(),
1171  permutation.begin(), permutation.end())) {
1172  return emitOpError()
1173  << "expects iterator_interchange to be a permutation, found "
1174  << getIteratorInterchange();
1175  }
1176  return success();
1177 }
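// Worked example (illustrative): for a generic op with three loops,
// iterator_interchange = [2, 0, 1] is a valid permutation of [0, 1, 2] and is
// accepted, whereas [0, 0, 2] is rejected by the verifier above because it is
// not a permutation.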
1178 
1179 //===----------------------------------------------------------------------===//
1180 // LowerPackOp
1181 //===----------------------------------------------------------------------===//
1182 
1183 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1184  transform::TransformRewriter &rewriter, linalg::PackOp target,
1185  transform::ApplyToEachResultList &transformResults,
1186  transform::TransformState &state) {
1187  rewriter.setInsertionPoint(target);
1188  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1189  FailureOr<LowerPackResult> res =
1190  lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1191  if (failed(res)) {
1192  return mlir::emitSilenceableFailure(target->getLoc())
1193  << "cannot lower to pad + expand + transpose";
1194  }
1195  transformResults.push_back(res->padOp);
1196  transformResults.push_back(res->expandShapeOp);
1197  transformResults.push_back(res->transposeOp);
1199 }
1200 
1201 //===----------------------------------------------------------------------===//
1202 // LowerUnPackOp
1203 //===----------------------------------------------------------------------===//
1204 
1205 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1206  transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1207  transform::ApplyToEachResultList &transformResults,
1208  transform::TransformState &state) {
1209  rewriter.setInsertionPoint(target);
1210  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1211  FailureOr<LowerUnPackOpResult> res =
1212  lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1213  if (failed(res)) {
1214  DiagnosedSilenceableFailure diag =
1215  emitSilenceableError()
1216  << "cannot lower to transpose + collapse + extract";
1217  diag.attachNote(target->getLoc()) << "target payload op";
1218  return diag;
1219  }
1220  transformResults.push_back(res->emptyOp);
1221  transformResults.push_back(res->transposeOp);
1222  transformResults.push_back(res->collapseShapeOp);
1223  transformResults.push_back(res->extractSliceOp);
1225 }
1226 
1227 //===---------------------------------------------------------------------===//
1228 // MatchOp
1229 //===---------------------------------------------------------------------===//
1230 
1231 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1232  Value target, ArrayRef<StringRef> opNames) {
1233  result.addOperands(target);
1234  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1235  builder.getStrArrayAttr(opNames));
1236  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1237 }
1238 
1239 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1240  TypeRange resultTypes, Value target,
1241  ArrayRef<StringRef> opNames) {
1242  result.addOperands(target);
1243  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1244  builder.getStrArrayAttr(opNames));
1245  result.addTypes(resultTypes);
1246 }
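// Illustrative usage sketch (not part of the upstream file): matching all
// linalg.matmul payload ops nested under the payload of `targetHandle` via
// the first builder above. `b`, `loc` and `targetHandle` are assumed to be in
// scope.
//
//   auto matmuls = b.create<transform::MatchOp>(
//       loc, /*target=*/targetHandle,
//       /*opNames=*/ArrayRef<StringRef>{"linalg.matmul"});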
1247 
1249 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1250  transform::TransformResults &results,
1251  transform::TransformState &state) {
1252  llvm::StringSet<> strs;
1253  if (getOps().has_value())
1254  strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1255  getOps()->getAsValueRange<StringAttr>().end());
1256 
1257  auto payloadOps = state.getPayloadOps(getTarget());
1258  if (!llvm::hasSingleElement(payloadOps)) {
1259  return emitDefiniteFailure("requires exactly one target handle");
1260  }
1261 
1262  SmallVector<Operation *> res;
1263  bool incorrectNumOperandTypes = false;
1264  auto matchFun = [&](Operation *op) {
1265  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1266  return;
1267 
1268  // Interfaces cannot be matched by name, just by ID.
1269  // So we specifically encode the interfaces we care about for this op.
1270  if (getInterface().has_value()) {
1271  auto iface = getInterface().value();
1272  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1273  !isa<LinalgOp>(op))
1274  return;
1275  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1276  !isa<TilingInterface>(op))
1277  return;
1278  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1279  !isa<LoopLikeOpInterface>(op))
1280  return;
1281  }
1282 
1283  // Check if all specified attributes match.
1284  if (getOpAttrs().has_value()) {
1285  DictionaryAttr opAttrs = getOpAttrs().value();
1286  for (NamedAttribute attr : opAttrs) {
1287  if (attr.getName() == getInterfaceAttrName() ||
1288  attr.getName() == getOpsAttrName())
1289  continue;
1290  if (!op->hasAttr(attr.getName()))
1291  return;
1292  if (op->getAttr(attr.getName()) != attr.getValue())
1293  return;
1294  }
1295  }
1296 
1297  if (getFilterResultType().has_value()) {
1298  Type t = getFilterResultType().value();
1299  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1300  return;
1301  }
1302 
1303  if (getFilterOperandTypes().has_value()) {
1304  mlir::ArrayAttr types = getFilterOperandTypes().value();
1305  auto operandTypes = op->getOperandTypes();
1306 
1307  if (types.size() == 1) {
1308  // All the operands must be equal to the specified type
1309  auto typeattr =
1310  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1311  Type t = cast<::mlir::Type>(typeattr.getValue());
1312  if (!llvm::all_of(op->getOperandTypes(),
1313  [&](Type operandType) { return operandType == t; }))
1314  return;
1315  } else {
1316  // The operand types must match all the types in the list (in the same
1317  // order in which they are specified)
1318  if (types.size() != operandTypes.size()) {
1319  incorrectNumOperandTypes = true;
1320  return;
1321  }
1322 
1323  for (auto [attr, operandType] :
1324  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1325  auto typeattr = cast<mlir::TypeAttr>(attr);
1326  Type type = cast<::mlir::Type>(typeattr.getValue());
1327 
1328  if (type != operandType)
1329  return;
1330  }
1331  }
1332  }
1333 
1334  // All constraints are satisfied.
1335  res.push_back(op);
1336  return;
1337  };
1338 
1339  (*payloadOps.begin())->walk(matchFun);
1340  if (incorrectNumOperandTypes)
1341  return emitDefiniteFailure("If filter_operand_types contains more than one "
1342  "type, then it must contain as many types as "
1343  "the number of operands in the target ops");
1344  results.set(cast<OpResult>(getResult()), res);
1346 }
1347 
1348 //===---------------------------------------------------------------------===//
1349 // MultiTileSizesOp
1350 //===---------------------------------------------------------------------===//
1351 
1353  Type targetType, Type lowSizeType, Type,
1354  Type) {
1355  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1356 }
1357 
1358 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1359  Type &targetType, Type &lowSizeType,
1360  Type &highSizeType,
1361  Type &splitPointType) {
1362  FunctionType funcType;
1363  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1364  if (failed(parser.parseType<FunctionType>(funcType)))
1365  return failure();
1366 
1367  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1368  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1369  "argument and one result";
1370  }
1371  targetType = funcType.getInput(0);
1372  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1373 
1374  return success();
1375 }
1376 
1377 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1378  transform::TransformRewriter &rewriter, LinalgOp target,
1380  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1381  if (target.hasDynamicShape()) {
1382  auto diag = emitSilenceableError()
1383  << "cannot compute parametric tile sizes for dynamically "
1384  "shaped payload op";
1385  diag.attachNote(target->getLoc()) << "payload op";
1386  return diag;
1387  }
1388 
1389  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1390  target, getDimension(), getTargetSize(), getDivisor());
1391  if (failed(spec)) {
1392  return emitSilenceableError()
1393  << "failed to compute multi-size tiling sizes";
1394  }
1395 
1396  Builder builder(target.getContext());
1397  results.assign(llvm::map_range(
1398  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1399  spec->lowTileSize * spec->lowTripCount}),
1400  [&builder, this](int64_t value) {
1401  return builder.getIntegerAttr(
1402  cast<ParamType>(getLowSize().getType()).getType(), value);
1403  }));
1405  }
1406 
1407  OpBuilder builder(target.getContext());
1408  builder.setInsertionPoint(target);
1409  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1410  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1411  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1412  builder, target, getDimension(), targetSize, divisor);
1413  if (failed(spec)) {
1414  return emitSilenceableError() << "could not generate tile size computation";
1415  }
1416 
1417  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1418  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1419  Operation *splitPoint =
1420  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1421  {spec->lowTileSize, spec->lowTripCount});
1422  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1423  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1424  assert(lowTileSize && highTileSize && splitPoint &&
1425  "tile sizes are not produced by operations");
1426  results.reserve(results.size() + 3);
1427  results.push_back(lowTileSize);
1428  results.push_back(highTileSize);
1429  results.push_back(splitPoint);
1431 }
1432 
1433 void transform::MultiTileSizesOp::getEffects(
1435  onlyReadsHandle(getTargetMutable(), effects);
1436  producesHandle(getOperation()->getOpResults(), effects);
1437  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1438  onlyReadsPayload(effects);
1439  else
1440  modifiesPayload(effects);
1441 }
1442 
1443 LogicalResult transform::MultiTileSizesOp::verify() {
1444  if (getLowSize().getType() != getHighSize().getType() ||
1445  getLowSize().getType() != getSplitPoint().getType()) {
1446  return emitOpError() << "expects all results type to be the same";
1447  }
1448  return success();
1449 }
1450 
1451 //===---------------------------------------------------------------------===//
1452 // PackOp
1453 //===---------------------------------------------------------------------===//
1454 
1455 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1456  Value target,
1457  ArrayRef<OpFoldResult> mixedPackedSizes) {
1458  SmallVector<int64_t> staticPackedSizes;
1459  SmallVector<Value> dynamicPackedSizes;
1460  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1461  staticPackedSizes);
1462  // Call the default builder which sets up the proper operands segment sizes
1463  // attributes for multiple variadic operands. In the absence of this, horrible
1464  // bugs ensue.
1465  Type linalgOpHType = transform::OperationType::get(
1466  builder.getContext(), GenericOp::getOperationName());
1467  build(builder, result,
1468  /*resultType=*/linalgOpHType,
1469  /*target=*/target,
1470  /*dynamic_sizes=*/dynamicPackedSizes,
1471  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1472 }
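// Illustrative usage sketch (not part of the upstream file): requesting data
// tiling of a matmul handle with static packed sizes 8x16x32. `b`, `loc` and
// `matmulHandle` are assumed to be in scope.
//
//   SmallVector<OpFoldResult> packedSizes = {
//       b.getIndexAttr(8), b.getIndexAttr(16), b.getIndexAttr(32)};
//   auto packOp = b.create<transform::PackOp>(loc, matmulHandle, packedSizes);
//   Value packedLinalgOp = packOp.getPackedOp();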
1473 
1474 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1475  Builder b(getContext());
1476  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1477 }
1478 
1480 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1481  transform::TransformResults &transformResults,
1482  transform::TransformState &state) {
1483  auto targetOps = state.getPayloadOps(getTarget());
1484  // If nothing to pack, propagate success.
1485  if (std::empty(targetOps)) {
1486  transformResults.set(cast<OpResult>(getPackedOp()),
1487  ArrayRef<Operation *>({}));
1489  }
1490  // Fail on multi-op handles.
1491  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1492  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1493  return emitSilenceableError()
1494  << "requires target to map to exactly 1 LinalgOp (got "
1495  << llvm::range_size(targetOps) << ")";
1496  }
1497  // Fail on mismatched number of pack sizes.
1498  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1499  return emitSilenceableError()
1500  << "requires number of packed sizes match the number of loops ("
1501  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1502  << ")";
1503  }
1504 
1505  // Unpack handles to constants or actual SSA index values.
1506  SmallVector<OpFoldResult> packedSizes;
1508  state, *this, packedSizes, getMixedPackedSizes());
1509 
1510  rewriter.setInsertionPoint(linalgOp);
1511  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1512  if (failed(maybeResult))
1513  return emitDefiniteFailure("data tiling failed");
1514 
1515  transformResults.set(cast<OpResult>(getPackedOp()),
1516  {maybeResult->packedLinalgOp.getOperation()});
1518 }
1519 
1520 void transform::PackOp::getEffects(
1522  transform::consumesHandle(getTargetMutable(), effects);
1523  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1524  transform::producesHandle(getOperation()->getOpResults(), effects);
1525  transform::modifiesPayload(effects);
1526 }
1527 
1528 //===---------------------------------------------------------------------===//
1529 // PackGreedilyOp.
1530 //===---------------------------------------------------------------------===//
1531 
1532 LogicalResult transform::PackGreedilyOp::verify() {
1533  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1534  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1535  << " is not a valid permutation";
1536  }
1537  // TODO: relax to allow empty once we have another strategy than just matmul.
1538  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1539  for (auto [s, nmo] :
1540  llvm::zip_equal(getMixedMatmulPackedSizes(),
1541  getMatmulPaddedSizesNextMultipleOf())) {
1542  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1543  if (nmo != 0 &&
1544  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1545  return emitOpError() << "at most one of the packed_size and the "
1546  "padded_sizes_next_multiple_of can be nonzero "
1547  "for the matmul strategy";
1548  }
1549  }
1550  }
1551  return success();
1552 }
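// Worked example of the constraint above (attribute values are illustrative):
// matmul_packed_sizes = [8, 0, 32] together with
// matmul_padded_sizes_next_multiple_of = [0, 16, 0] verifies, because the only
// nonzero next-multiple-of entry lines up with a packed size that is the
// constant 0; pairing [8, 16, 32] with [0, 16, 0] is rejected because
// position 1 would set both knobs to nonzero values.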
1553 
1555 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1556  transform::TransformResults &transformResults,
1557  transform::TransformState &state) {
1558  SmallVector<Operation *> results;
1559  for (Operation *op : state.getPayloadOps(getTarget())) {
1560  auto linalgOp = dyn_cast<LinalgOp>(op);
1561  if (!linalgOp)
1562  continue;
1563  // linalgOp will be replaced and the insertion point may be invalidated if
1564  // we set it before the op, so set it after instead.
1565  rewriter.setInsertionPointAfter(linalgOp);
1566  // Failing to pack greedily is perfectly fine.
1567  // In the future we will want to order packings according to some metric.
1568  FailureOr<PackResult> packResult = packMatmulGreedily(
1569  /*rewriter=*/rewriter,
1570  /*linalgOp=*/linalgOp,
1571  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1572  /*mnkPaddedSizesNextMultipleOf=*/
1573  getMatmulPaddedSizesNextMultipleOf(),
1574  /*mnkOrder=*/getMatmulInnerDimsOrder());
1575  if (succeeded(packResult)) {
1576  results.push_back(packResult->packedLinalgOp);
1577  continue;
1578  }
1579  results.push_back(linalgOp);
1580  }
1581  transformResults.set(cast<OpResult>(getPackedOp()), results);
1583 }
1584 
1585 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1586  Builder b(getContext());
1587  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1588  b);
1589 }
1590 
1591 void transform::PackGreedilyOp::getEffects(
1593  transform::consumesHandle(getTargetMutable(), effects);
1594  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1595  transform::producesHandle(getOperation()->getOpResults(), effects);
1596  transform::modifiesPayload(effects);
1597 }
1598 
1599 //===---------------------------------------------------------------------===//
1600 // PackTransposeOp
1601 //===---------------------------------------------------------------------===//
1602 
1603 LogicalResult transform::PackTransposeOp::verify() {
1604  if (!isPermutationVector(getInnerPerm())) {
1605  return emitOpError() << getInnerPermAttrName()
1606  << " is not a valid permutation";
1607  }
1608  if (!isPermutationVector(getOuterPerm())) {
1609  return emitOpError() << getOuterPermAttrName()
1610  << " is not a valid permutation";
1611  }
1612  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1613  return emitOpError() << " at least one of " << getInnerPermAttrName()
1614  << " or " << getOuterPermAttrName()
1615  << " must be specified";
1616  }
1617  return success();
1618 }
1619 
1620 namespace {
1621 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1622 } // namespace
1623 
1624 /// Return true if `permutation` is a valid permutation of the
1625 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1626 /// (OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack` `op`.
1627 /// This is the case when the `permutation` rank matches the rank expected by
1628 /// `op` and `permutation` is itself a permutation vector.
1629 /// Return true if either `op` or `permutation` are empty to allow a simpler
1630 /// polymorphic implementation.
1631 template <typename RelayoutOpTy>
1632 static bool isValidPackingPermutation(
1633  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1634  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1635  static_assert(
1636  llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1637  "applies to only pack or unpack operations");
1638  if (!op || permutation.empty())
1639  return true;
1640  size_t innerRank = op.getInnerDimsPos().size();
1641  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1642  return permutation.size() == innerRank && isPermutationVector(permutation);
1643  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1644  // Don't rely on it.
1645  if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1646  return permutation.size() == op.getSourceRank() &&
1647  isPermutationVector(permutation);
1648  }
1649  return permutation.size() == op.getDestRank() &&
1650  isPermutationVector(permutation);
1651 }
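// For example (ranks are illustrative): for a linalg.pack with a rank-3 source
// and two inner tiles, a valid outer permutation has 3 entries, e.g. {2, 0, 1},
// and a valid inner permutation has 2 entries, e.g. {1, 0}. For a linalg.unpack
// the outer permutation is checked against the destination rank instead, per
// the branch above.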
1652 
1654 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1655  transform::TransformResults &transformResults,
1656  transform::TransformState &state) {
1657  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1658  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1659  // Step 1. If nothing to pack, propagate success.
1660  if (std::empty(packOrUnpackOps)) {
1661  transformResults.set(cast<OpResult>(getPackedOp()), {});
1662  transformResults.set(cast<OpResult>(getPackOp()), {});
1663  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1665  }
1666 
1667  // Step 2. A bunch of runtime sanity checks and error messages.
1668  // Step 2.1. Fail on multi-op handles.
1669  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1670  !llvm::hasSingleElement(linalgOps)) {
1671  return emitSilenceableError()
1672  << "requires target to map to exactly 1 "
1673  "packing op and 1 packed op ("
1674  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1675  << llvm::range_size(linalgOps) << ")";
1676  }
1677 
1678  // Step 2.2. Fail on wrong type.
1679  auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1680  auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1681  if ((!packOp && !unPackOp)) {
1682  return emitSilenceableError() << "requires target to map to a "
1683  "linalg.pack or linalg.unpack";
1684  }
1685  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1686  if (!linalgOpTarget)
1687  return emitSilenceableError() << "requires a LinalgOp target";
1688 
1689  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1690  LinalgOp linalgOp;
1691  if (packOp && packOp.getResult().hasOneUse())
1692  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1693  else if (unPackOp)
1694  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1695  if (linalgOp != linalgOpTarget) {
1696  auto errorMsg =
1697  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1698  : StringLiteral{"not produced by the LinalgOp target"};
1699  return emitSilenceableError() << errorMsg;
1700  }
1701 
1702  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1703  // PackOp.
1704  if (unPackOp) {
1705  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1706  OpOperand *packUse = linalgOp.getDpsInitOperand(
1707  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1708  packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
1709  if (!packOp || !packOp.getResult().hasOneUse())
1710  return emitSilenceableError() << "could not find matching pack op";
1711  }
1712 
1713  // Step 2.5. Fail if any permutation does not validate.
1714  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1715  ArrayRef<int64_t> perm =
1716  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1717  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1718  ? StringLiteral{"invalid outer_perm"}
1719  : StringLiteral{"invalid inner_perm"};
1720  if (!isValidPackingPermutation(packOp, perm, permType) ||
1721  !isValidPackingPermutation(unPackOp, perm, permType)) {
1722  Operation *packOrUnpackOp =
1723  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1724  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1725  }
1726  }
1727 
1728  // From here on, packOp and linalgOp are always present, unPackOp may or may
1729  // not be present.
1730  assert(packOp && linalgOp && "unexpected null op");
1731 
1732  // Step 3. Actually transpose the ops.
1733  FailureOr<PackTransposeResult> res = packTranspose(
1734  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1735  // Preconditions have been checked; it is an error to fail here.
1736  assert(succeeded(res) && "unexpected packTranspose failure");
1737 
1738  // Step 4. Return results.
1739  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1740  transformResults.set(cast<OpResult>(getPackedOp()),
1741  {res->transposedLinalgOp});
1742  if (unPackOp) {
1743  transformResults.set(cast<OpResult>(getUnPackOp()),
1744  {res->transposedUnPackOp});
1745  } else {
1746  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1747  }
1748 
1750 }
1751 
1752 //===---------------------------------------------------------------------===//
1753 // PadOp
1754 //===---------------------------------------------------------------------===//
1755 
1756 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1757  ArrayRef<int64_t> paddingDimensions,
1758  ArrayRef<int64_t> padToMultipleOf,
1759  ArrayRef<int64_t> nofoldFlags,
1760  ArrayRef<Attribute> transposePaddings,
1761  StringRef copyBackOp) {
1762  auto resultType = transform::AnyOpType::get(b.getContext());
1763  return build(/*builder=*/b,
1764  /*result=*/result,
1765  /*types=*/TypeRange{resultType, resultType},
1766  /*target=*/target,
1767  /*paddingValues=*/ArrayAttr(), // let inference handle this
1768  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1769  /*padToMultipleOf=*/ValueRange{},
1770  /*padToMultipleOf=*/
1771  (padToMultipleOf.empty()
1772  ? DenseI64ArrayAttr()
1773  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1774  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1775  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1776  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1777 }
1778 
1779 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1780  ArrayRef<int64_t> paddingDimensions,
1781  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1782  ArrayRef<int64_t> nofoldFlags,
1783  ArrayRef<Attribute> transposePaddings,
1784  StringRef copyBackOp) {
1785  auto resultType = transform::AnyOpType::get(b.getContext());
1786  SmallVector<int64_t> staticPadToMultipleOf;
1787  SmallVector<Value> dynamicPadToMultipleOf;
1788  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1789  staticPadToMultipleOf);
1790  return build(/*builder=*/b,
1791  /*result=*/result,
1792  /*types=*/TypeRange{resultType, resultType},
1793  /*target=*/target,
1794  /*paddingValues=*/ArrayAttr(), // let inference handle this
1795  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1796  /*padToMultipleOf=*/dynamicPadToMultipleOf,
1797  /*padToMultipleOf=*/staticPadToMultipleOf,
1798  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1799  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1800  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1801 }
1802 
1803 void PadOp::getEffects(
1805  consumesHandle(getTargetMutable(), effects);
1806  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1807  producesHandle(getOperation()->getOpResults(), effects);
1808  modifiesPayload(effects);
1809 }
1810 
1811 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1812  Builder b(getContext());
1813  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1814 }
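// Minimal usage sketch of the static-size build API above; `b`, `opState`, and
// `targetHandle` are hypothetical and the values are illustrative:
//
//   transform::PadOp::build(b, opState, /*target=*/targetHandle,
//                           /*paddingDimensions=*/{0, 1},
//                           /*padToMultipleOf=*/{16, 16},
//                           /*nofoldFlags=*/{1, 1},
//                           /*transposePaddings=*/{},
//                           /*copyBackOp=*/
//                           linalg::CopyOp::getOperationName());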
1815 
1817 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1818  transform::TransformResults &results,
1819  transform::TransformState &state) {
1820  auto transformOp = cast<TransformOpInterface>(getOperation());
1821  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1822 
1823  for (Operation *target : state.getPayloadOps(getTarget())) {
1824  auto linalgTarget = dyn_cast<LinalgOp>(target);
1825  if (!linalgTarget) {
1826  auto diag = emitSilenceableError() << "expected LinalgOp target";
1827  diag.attachNote(target->getLoc()) << "target op";
1828  return diag;
1829  }
1830 
1831  // Convert the integer packing flags to booleans.
1832  SmallVector<bool> nofoldFlags;
1833  for (int64_t packPadding :
1834  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1835  nofoldFlags.push_back(static_cast<bool>(packPadding));
1836 
1837  // Convert the padding values to attributes.
1838  SmallVector<Attribute> paddingValues;
1839  for (auto const &it :
1840  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1841  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1842  if (!attr) {
1843  emitOpError("expects padding values to be typed attributes");
1844  return DiagnosedSilenceableFailure::definiteFailure();
1845  }
1846  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1847  // Try to parse string attributes to obtain an attribute of element type.
1848  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1849  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1850  stringAttr, getContext(), elementType,
1851  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
1852  if (!parsedAttr || parsedAttr.getType() != elementType) {
1853  auto diag = this->emitOpError("expects a padding that parses to ")
1854  << elementType << ", got " << std::get<0>(it);
1855  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1856  return DiagnosedSilenceableFailure::definiteFailure();
1857  }
1858  paddingValues.push_back(parsedAttr);
1859  continue;
1860  }
1861  // Otherwise, add the attribute directly.
1862  if (attr.getType() != elementType) {
1863  auto diag = this->emitOpError("expects a padding value of type ")
1864  << elementType << ", got " << attr;
1865  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1866  return DiagnosedSilenceableFailure::definiteFailure();
1867  }
1868  paddingValues.push_back(attr);
1869  }
1870 
1871  // Extract the transpose vectors.
1872  SmallVector<SmallVector<int64_t>> transposePaddings;
1873  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1874  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1875  cast<ArrayAttr>(transposeVector)));
1876 
1877  LinalgOp paddedOp;
1878  LinalgPaddingOptions options;
1879  options.paddingDimensions =
1880  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1881 
1882  SmallVector<int64_t> padToMultipleOf;
1883  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
1884  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1885  if (!status.succeeded())
1886  return status;
1887  if (padToMultipleOf.empty())
1888  padToMultipleOf =
1889  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
1890 
1891  options.padToMultipleOf = padToMultipleOf;
1892  options.paddingValues = paddingValues;
1893  options.nofoldFlags = nofoldFlags;
1894  if (getCopyBackOp() ==
1895  bufferization::MaterializeInDestinationOp::getOperationName()) {
1896  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
1897  BufferizationMaterializeInDestination;
1898  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1899  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
1900  } else if (getCopyBackOp() == kCopyOpNone) {
1901  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
1902  } else {
1903  llvm_unreachable("unsupported copy_back op");
1904  }
1905 
1906  SmallVector<Value> replacements;
1907  SmallVector<tensor::PadOp> newPadOps;
1908  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
1909  replacements, newPadOps))) {
1910  auto diag = emitSilenceableError() << "failed to pad op";
1911  diag.attachNote(target->getLoc()) << "target op";
1912  return diag;
1913  }
1914 
1915  // We need to perform our own replacement here because this API is still
1916  // used in patterns that "pad and hoist", for which the replacement values
1917  // need to be different.
1918  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1919  // that we have more composable abstractions.
1920  rewriter.replaceOp(linalgTarget, replacements);
1921  paddedOps.push_back(paddedOp);
1922  padOps.append(newPadOps.begin(), newPadOps.end());
1923  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
1924  for (Value v : replacements) {
1925  Operation *copyBackOp = v.getDefiningOp();
1926  if (!llvm::is_contained(copyBackOps, copyBackOp))
1927  copyBackOps.push_back(copyBackOp);
1928  }
1929  }
1930  }
1931 
1932  results.set(cast<OpResult>(getPadded()), paddedOps);
1933  results.set(cast<OpResult>(getPad()), padOps);
1934  results.set(cast<OpResult>(getCopy()), copyBackOps);
1936 }
1937 
1938 LogicalResult transform::PadOp::verify() {
1939  SmallVector<int64_t> nofoldFlags =
1940  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
1941  if (any_of(nofoldFlags, [](int64_t packPadding) {
1942  return packPadding != 0 && packPadding != 1;
1943  })) {
1944  return emitOpError()
1945  << "expects nofold_flags to contain booleans (0/1), found "
1946  << getNofoldFlags();
1947  }
1948 
1949  SmallVector<int64_t> paddingDimensions =
1950  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1951  if (any_of(paddingDimensions,
1952  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
1953  return emitOpError() << "expects padding_dimensions to contain non-negative "
1954  "integers, found "
1955  << getPaddingDimensions();
1956  }
1957  if (!getMixedPadToMultipleOf().empty()) {
1958  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
1959  return emitOpError() << "expects as many multiples as padding_dimensions";
1960  }
1961  }
1962  ArrayAttr transposes = getTransposePaddings();
1963  for (Attribute attr : transposes) {
1964  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
1965  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1966  if (!std::is_permutation(sequence.begin(), sequence.end(),
1967  transpose.begin(), transpose.end())) {
1968  return emitOpError()
1969  << "expects transpose_paddings to be a permutation, found "
1970  << attr;
1971  }
1972  }
1973  if (getCopyBackOp() !=
1974  bufferization::MaterializeInDestinationOp::getOperationName() &&
1975  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1976  getCopyBackOp() != kCopyOpNone)
1977  return emitOpError() << "invalid copy_back_op";
1978  return success();
1979 }
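// For instance (values illustrative), transpose_paddings = [[1, 0, 2]] passes
// the permutation check above because [1, 0, 2] is a permutation of [0, 1, 2],
// whereas [[0, 2]] is rejected since it is not a permutation of [0, 1].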
1980 
1981 //===---------------------------------------------------------------------===//
1982 // HoistPadOp
1983 //===---------------------------------------------------------------------===//
1984 
1985 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
1986  transform::TransformRewriter &rewriter,
1987  transform::TransformResults &transformResults,
1988  transform::TransformState &state) {
1989  auto targetOps = state.getPayloadOps(getTarget());
1990  auto loopOps = state.getPayloadOps(getLoop());
1991  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1992  return emitDefiniteFailure()
1993  << "requires exactly one target and one loop handle (got "
1994  << llvm::range_size(targetOps) << " and "
1995  << llvm::range_size(loopOps) << ")";
1996  }
1997 
1998  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1999  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2000  if (!padOp || !loopOp)
2001  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2002 
2003  FailureOr<linalg::detail::PackingResult> result =
2004  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2005  getTranspose());
2006  if (failed(result))
2007  return emitDefiniteFailure() << "could not build packing loop nest";
2008 
2009  if (result->clonedLoopIvs.empty()) {
2010  transformResults.set(cast<OpResult>(getPackingLoop()),
2011  {result->hoistedPadOp.getOperation()});
2013  }
2014  auto outerPackedLoop =
2015  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2016  transformResults.set(cast<OpResult>(getPackingLoop()),
2017  {outerPackedLoop.getOperation()});
2019 }
2020 
2021 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2022  ArrayRef<int64_t> transpose = getTranspose();
2023  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2024  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2025  transpose.end())) {
2026  return emitOpError() << "expects transpose to be a permutation, found "
2027  << getTranspose();
2028  }
2029  return success();
2030 }
2031 
2032 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2034  transform::onlyReadsHandle(getTargetMutable(), effects);
2035  transform::onlyReadsHandle(getLoopMutable(), effects);
2036  transform::producesHandle(getOperation()->getOpResults(), effects);
2037  transform::modifiesPayload(effects);
2038 }
2039 
2041 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2042  tensor::PadOp target,
2044  transform::TransformState &state) {
2045  tensor::PadOp hoistedPadOp;
2046  SmallVector<TransposeOp> transposeOps;
2047  FailureOr<Value> result =
2048  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2049  hoistedPadOp, transposeOps);
2050  if (succeeded(result)) {
2051  // We need to perform our own replacement here because this API is still
2052  // used in patterns that "pad and hoist", for which the replacement values
2053  // need to be different.
2054  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2055  // that we have more composable abstractions.
2056  rewriter.replaceOp(target, *result);
2057  results.push_back(hoistedPadOp);
2059  }
2060  return emitDefaultSilenceableFailure(target);
2061 }
2062 
2063 LogicalResult transform::HoistPadOp::verify() {
2064  ArrayRef<int64_t> transpose = getTranspose();
2065  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2066  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2067  transpose.end())) {
2068  return emitOpError() << "expects transpose to be a permutation, found "
2069  << getTranspose();
2070  }
2071  return success();
2072 }
2073 
2074 //===----------------------------------------------------------------------===//
2075 // PromoteOp
2076 //===----------------------------------------------------------------------===//
2077 
2079 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2080  LinalgOp target,
2082  transform::TransformState &state) {
2083  LinalgPromotionOptions promotionOptions;
2084  if (!getOperandsToPromote().empty())
2085  promotionOptions = promotionOptions.setOperandsToPromote(
2086  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2087  if (getUseFullTilesByDefault())
2088  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2089  getUseFullTilesByDefault());
2090  if (getUseAlloca())
2091  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2092  if (!getUseFullTileBuffers().empty())
2093  promotionOptions = promotionOptions.setUseFullTileBuffers(
2094  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2095  if (getAlignment().has_value())
2096  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2097  if (getMemorySpace().has_value())
2098  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2099 
2100  if (getMapping().has_value()) {
2101  // The mapping should only contain one element.
2102  auto mapping = *getMapping();
2103  if (mapping.size() > 1)
2104  return emitDefaultDefiniteFailure(target);
2105 
2106  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2107 
2108  if (addressSpace.getAddressSpace() ==
2109  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2110  promotionOptions =
2111  promotionOptions
2115  .setUseFullTileBuffers({false, false});
2116  } else if (addressSpace.getAddressSpace() ==
2117  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2118  promotionOptions =
2119  promotionOptions
2123  .setUseFullTileBuffers({false, false});
2124  } else {
2125  return emitDefaultDefiniteFailure(target);
2126  }
2127  }
2128 
2129  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2130  return emitDefaultDefiniteFailure(target);
2131 
2132  rewriter.setInsertionPoint(target);
2133  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2134  if (failed(res))
2135  return emitDefaultDefiniteFailure(target);
2136  results.push_back(target);
2138 }
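// Short sketch of configuring the same options directly from C++; the values
// are illustrative and only setters already exercised above are used:
//
//   LinalgPromotionOptions opts;
//   opts = opts.setOperandsToPromote({0, 1})
//              .setUseFullTileBuffers({false, false})
//              .setAlignment(16);
//   // promoteSubviewsPrecondition(op, opts) is expected to succeed before
//   // promoteSubViews(rewriter, op, opts) is called, mirroring the flow above.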
2139 
2140 //===----------------------------------------------------------------------===//
2141 // ReplaceOp
2142 //===----------------------------------------------------------------------===//
2143 
2145 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2146  TransformResults &transformResults,
2147  TransformState &state) {
2148  auto payload = state.getPayloadOps(getTarget());
2149 
2150  // Check for invalid targets.
2151  for (Operation *target : payload) {
2152  if (target->getNumOperands() > 0)
2153  return emitDefiniteFailure() << "expected target without operands";
2154  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2155  target->getNumRegions() > 0)
2156  return emitDefiniteFailure()
2157  << "expected target that is isolated from above";
2158  }
2159 
2160  // Clone and replace.
2161  Operation *pattern = &getBodyRegion().front().front();
2162  SmallVector<Operation *> replacements;
2163  for (Operation *target : payload) {
2164  if (getOperation()->isAncestor(target))
2165  continue;
2166  rewriter.setInsertionPoint(target);
2167  Operation *replacement = rewriter.clone(*pattern);
2168  rewriter.replaceOp(target, replacement->getResults());
2169  replacements.push_back(replacement);
2170  }
2171  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2173 }
2174 
2175 void transform::ReplaceOp::getEffects(
2177  consumesHandle(getTargetMutable(), effects);
2178  producesHandle(getOperation()->getOpResults(), effects);
2179  modifiesPayload(effects);
2180 }
2181 
2182 LogicalResult transform::ReplaceOp::verify() {
2183  if (!getBodyRegion().hasOneBlock())
2184  return emitOpError() << "expected one block";
2185  if (std::distance(getBodyRegion().front().begin(),
2186  getBodyRegion().front().end()) != 1)
2187  return emitOpError() << "expected one operation in block";
2188  Operation *replacement = &getBodyRegion().front().front();
2189  if (replacement->getNumOperands() > 0)
2190  return replacement->emitOpError()
2191  << "expected replacement without operands";
2192  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2193  replacement->getNumRegions() > 0)
2194  return replacement->emitOpError()
2195  << "expected op that is isolated from above";
2196  return success();
2197 }
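// As a concrete illustration of the constraints verified above: a zero-operand
// op such as arith.constant, or an isolated-from-above op such as func.func,
// is a valid replacement body, while an op whose regions capture values from
// the enclosing scope is not.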
2198 
2199 //===----------------------------------------------------------------------===//
2200 // ScalarizeOp
2201 //===----------------------------------------------------------------------===//
2202 
2204 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2205  LinalgOp target,
2207  transform::TransformState &state) {
2208  scf::SCFTilingOptions tilingOptions;
2209  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2210  SmallVector<OpFoldResult> tileSizes;
2211  Location loc = target.getLoc();
2212  SmallVector<OpFoldResult> allShapeSizes =
2213  target.createFlatListOfOperandDims(b, loc);
2214  AffineMap map = target.getShapesToLoopsMap();
2215  if (!map)
2216  return tileSizes;
2217  SmallVector<OpFoldResult> shapeSizes =
2219  allShapeSizes);
2220  // If the shape size is dynamic, tile by 1.
2221  // Otherwise, do not tile (i.e. tile size 0).
2222  for (OpFoldResult shapeSize : shapeSizes) {
2223  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2224  : b.getIndexAttr(1));
2225  }
2226  return tileSizes;
2227  });
2228  SmallVector<int64_t> emptyTileSizes;
2229  rewriter.setInsertionPoint(target);
2230  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2231  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2232  if (failed(maybeTilingResult))
2233  return emitDefaultDefiniteFailure(target);
2234 
2235  if (target->getNumResults())
2236  rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
2237  else
2238  rewriter.eraseOp(target);
2239 
2240  results.reserve(maybeTilingResult->tiledOps.size());
2241  for (Operation *tiled : maybeTilingResult->tiledOps)
2242  results.push_back(tiled);
2244 }
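// Worked example of the tile-size callback above (shapes illustrative): for a
// target whose loop ranges are (4, ?, 8), the computed tile sizes are
// (0, 1, 0): statically sized loops are left untiled (size 0) and the dynamic
// loop is tiled by 1, i.e. scalarized.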
2245 
2246 //===----------------------------------------------------------------------===//
2247 // ConvertToLoopsOp
2248 //===----------------------------------------------------------------------===//
2249 
2251 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2252  transform::TransformResults &results,
2253  transform::TransformState &state) {
2254  SmallVector<Operation *> loops;
2255  for (Operation *target : state.getPayloadOps(getTarget())) {
2256  auto tilingOp = dyn_cast<TilingInterface>(*target);
2257  if (!tilingOp) {
2259  emitSilenceableError()
2260  << "expected the payload to implement TilingInterface";
2261  diag.attachNote(target->getLoc()) << "payload op";
2262  return diag;
2263  }
2264  rewriter.setInsertionPoint(target);
2265  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2266  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2267  if (failed(generatedLoops))
2268  return emitDefaultDefiniteFailure(target);
2269  for (scf::ForOp &loop : *generatedLoops) {
2270  loops.push_back(loop.getOperation());
2271  }
2272  rewriter.eraseOp(target);
2273  }
2274  results.set(cast<OpResult>(getResult()), loops);
2276 }
2277 
2278 //===----------------------------------------------------------------------===//
2279 // RewriteInDestinationPassingStyleOp
2280 //===----------------------------------------------------------------------===//
2281 
2283 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2284  transform::TransformRewriter &rewriter, Operation *target,
2286  transform::TransformState &state) {
2288  rewriter.setInsertionPoint(target);
2289  FailureOr<Operation *> maybeResult =
2291  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2292  [&rewriter](auto op) {
2293  return rewriteInDestinationPassingStyle(rewriter, op);
2294  });
2295  if (failed(maybeResult))
2296  return emitDefaultSilenceableFailure(target);
2297  results.push_back(*maybeResult);
2299 }
2300 
2301 //===----------------------------------------------------------------------===//
2302 // SplitOp
2303 //===----------------------------------------------------------------------===//
2304 
2306 SplitOp::apply(transform::TransformRewriter &rewriter,
2307  TransformResults &results, TransformState &state) {
2308  // Collect the dynamic split points if provided.
2309  SmallVector<Operation *> payload =
2310  llvm::to_vector(state.getPayloadOps(getTarget()));
2311 
2312  bool isMultiwaySplit = getMultiway();
2313 
2314  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2315  return mlir::emitSilenceableFailure(getLoc())
2316  << "requires exactly one target when "
2317  "multiway split is enabled (got "
2318  << llvm::range_size(payload) << ")";
2319  }
2320 
2321  SmallVector<OpFoldResult> chunkSizes;
2322 
2323  if (!isMultiwaySplit)
2324  chunkSizes.reserve(payload.size());
2325 
2326  if (getDynamicChunkSizes()) {
2327  auto diag = DiagnosedSilenceableFailure::success();
2328  if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2329  chunkSizes = llvm::to_vector(llvm::map_range(
2330  state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2331  if (op->getNumResults() != 1 ||
2332  !op->getResult(0).getType().isIndex()) {
2333  diag = emitSilenceableError()
2334  << "expected dynamic split point handle to point to a "
2335  "single-result index-typed op";
2336  diag.attachNote(op->getLoc()) << "dynamic split point";
2337  }
2338  return OpFoldResult(op->getResult(0));
2339  }));
2340  } else {
2341  chunkSizes = llvm::to_vector(
2342  llvm::map_range(state.getParams(getDynamicChunkSizes()),
2343  [](Attribute attr) { return OpFoldResult(attr); }));
2344  }
2345  if (diag.isSilenceableFailure())
2346  return diag;
2347 
2348  // For multiway split, a single payload is expected to have multiple
2349  // split points.
2350  if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2351  return emitDefiniteFailure()
2352  << "expected the dynamic split point handle to point to as "
2353  "many operations ("
2354  << chunkSizes.size() << ") as the target handle ("
2355  << payload.size() << ")";
2356  }
2357  } else {
2358  chunkSizes.resize(payload.size(),
2359  rewriter.getIndexAttr(getStaticChunkSizes()));
2360  }
2361 
2362  auto checkStructuredOpAndDimensions =
2363  [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2364  if (!linalgOp) {
2365  auto diag = emitSilenceableError() << "only applies to structured ops";
2366  diag.attachNote(loc) << "target op";
2367  return diag;
2368  }
2369 
2370  if (getDimension() >= linalgOp.getNumLoops()) {
2371  auto diag = emitSilenceableError() << "dimension " << getDimension()
2372  << " does not exist in target op";
2373  diag.attachNote(loc) << "target op";
2374  return diag;
2375  }
2377  };
2378 
2379  auto checkFailureInSplitting =
2380  [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2381  if (hasFailed) {
2382  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2383  diag.attachNote(loc) << "target op";
2384  return diag;
2385  }
2387  };
2388 
2389  SmallVector<Operation *> opList;
2390  if (isMultiwaySplit) {
2391 
2392  // Split a single target operation at multiple points.
2393  TilingInterface head, tail;
2394  Operation *target = payload.front();
2395 
2396  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2397 
2398  // Check that the target is a valid LinalgOp with correct dimensions.
2400  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2401  if (diag.isSilenceableFailure())
2402  return diag;
2403 
2404  for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2405 
2406  if (idx > 0)
2407  target = tail.getOperation();
2408 
2409  if (!target)
2410  break;
2411 
2412  linalgOp = cast<LinalgOp>(target);
2413  Location loc = target->getLoc();
2414 
2415  rewriter.setInsertionPoint(linalgOp);
2416  std::tie(head, tail) = linalg::splitOp(
2417  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2418  getDimension(), chunkSize);
2419 
2420  // Propagate errors.
2422  checkFailureInSplitting(!head && !tail, loc);
2423  if (diag.isDefiniteFailure())
2424  return diag;
2425 
2426  opList.push_back(head.getOperation());
2427  }
2428 
2429  // Append any leftover parts to the end of the result list.
2430  if (tail)
2431  opList.push_back(tail.getOperation());
2432 
2433  } else {
2434  // Split each target operation.
2435  SmallVector<Operation *> first, second;
2436  Operation *noSecondPart = nullptr;
2437  for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2438  Operation *target = std::get<0>(pair);
2439  Location loc = target->getLoc();
2440  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2442  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2443 
2444  if (diag.isSilenceableFailure())
2445  return diag;
2446 
2447  rewriter.setInsertionPoint(linalgOp);
2448  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2449  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2450  getDimension(), std::get<1>(pair));
2451 
2452  // Propagate errors.
2453  DiagnosedSilenceableFailure diagSplit =
2454  checkFailureInSplitting(!first.back() && !second.back(), loc);
2455  if (diagSplit.isDefiniteFailure())
2456  return diagSplit;
2457 
2458  // Do not add null second parts.
2459  if (!second.back()) {
2460  noSecondPart = target;
2461  second.pop_back();
2462  }
2463  }
2464 
2465  if (second.size() != first.size() && !second.empty()) {
2466  auto diag = emitSilenceableError()
2467  << "splitting does not produce the second part for a subset "
2468  "of targets";
2469  diag.attachNote()
2470  << "expected splitting to produce the second part of all "
2471  "or none of the targets";
2472  diag.attachNote(noSecondPart->getLoc())
2473  << "first target with no second part";
2474  return diag;
2475  }
2476 
2477  opList.append(first);
2478  if (second.size())
2479  opList.append(second);
2480  }
2481  results.set(cast<OpResult>(getSplitList()), opList);
2483 }
2484 
2485 void SplitOp::getEffects(
2487  consumesHandle(getTargetMutable(), effects);
2488  if (getDynamicChunkSizes())
2489  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2490  producesHandle(getOperation()->getOpResults(), effects);
2491  modifiesPayload(effects);
2492 }
2493 
2494 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2495  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2496  IntegerAttr staticChunkSizes;
2497  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2498  return failure();
2499 
2500  OptionalParseResult dynamicPointParseResult =
2501  parser.parseOptionalOperand(dynamicChunkSizes);
2502  if (!dynamicPointParseResult.has_value()) {
2503  int64_t staticChunkSizesValue;
2504  if (failed(parser.parseInteger(staticChunkSizesValue)))
2505  return failure();
2506 
2507  staticChunkSizes =
2508  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2509  }
2510 
2511  Type targetType;
2512  if (parser.parseOptionalAttrDict(result.attributes) ||
2513  parser.parseColonType(targetType) ||
2514  parser.resolveOperand(target, targetType, result.operands)) {
2515  return failure();
2516  }
2517  if (dynamicPointParseResult.has_value()) {
2518  Type ChunkSizesType;
2519  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2520  parser.parseType(ChunkSizesType) ||
2521  parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
2522  result.operands)) {
2523  return failure();
2524  }
2525 
2526  staticChunkSizes =
2527  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2528  }
2529 
2530  result.addAttribute(
2531  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2532  staticChunkSizes);
2533  result.addTypes(targetType);
2534  return success();
2535 }
2536 
2537 void SplitOp::print(OpAsmPrinter &printer) {
2538  printer << " " << getTarget() << " after ";
2539  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2540  if (staticChunkSize != ShapedType::kDynamic)
2541  printer << staticChunkSize;
2542  else
2543  printer << getDynamicChunkSizes();
2544  printer << " ";
2545  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2546  {getStaticChunkSizesAttrName()});
2547  printer << " : " << getTarget().getType();
2548  if (staticChunkSize == ShapedType::kDynamic)
2549  printer << ", " << getDynamicChunkSizes().getType();
2550 }
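// The parser and printer above accept a custom form along these lines; the
// attribute values and types are illustrative, not taken from a test:
//
//   %parts = transform.structured.split %target after 16 {dimension = 1}
//       : !transform.any_op
//   %parts = transform.structured.split %target after %sz {dimension = 1}
//       : !transform.any_op, !transform.param<i64>
//
// In the dynamic form, static_chunk_sizes is recorded as ShapedType::kDynamic,
// which is also what the verifier below keys on.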
2551 
2552 LogicalResult SplitOp::verify() {
2553  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2554  (getDynamicChunkSizes() == nullptr)) {
2555  return emitOpError() << "expects either a dynamic or a static split "
2556  "point to be provided";
2557  }
2558  return success();
2559 }
2560 
2561 //===----------------------------------------------------------------------===//
2562 // SplitReductionOp
2563 //===----------------------------------------------------------------------===//
2564 
2565 void transform::SplitReductionOp::build(
2566  OpBuilder &builder, OperationState &result, Value target,
2567  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2568  bool useScalingAlgorithm, bool useAlloc) {
2569  MLIRContext *ctx = builder.getContext();
2570  result.addOperands(target);
2571  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2572  builder.getI64IntegerAttr(splitFactor));
2573  result.addAttribute(
2574  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2575  builder.getI64IntegerAttr(insertSplitDimension));
2576  if (innerParallel) {
2577  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2578  builder.getUnitAttr());
2579  }
2580  if (useScalingAlgorithm) {
2581  result.addAttribute(
2582  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2583  builder.getUnitAttr());
2584  }
2585  if (useAlloc) {
2586  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2587  builder.getUnitAttr());
2588  }
2589  auto resultType = transform::AnyOpType::get(ctx);
2590  result.addTypes({resultType, resultType, resultType, resultType});
2591 }
2592 
2593 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2594  transform::TransformRewriter &rewriter, LinalgOp target,
2596  transform::TransformState &state) {
2597  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2598  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2599  unsigned(getInsertSplitDimension()),
2600  bool(getInnerParallel())};
2601  };
2602  rewriter.setInsertionPoint(target);
2603  FailureOr<SplitReductionResult> splitResult =
2604  (getUseScalingAlgorithm())
2605  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2606  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2607  if (failed(splitResult))
2608  return emitDefaultDefiniteFailure(target);
2609 
2610  results.push_back(splitResult->initOrAlloc);
2611  results.push_back(splitResult->fillOp);
2612  results.push_back(splitResult->splitLinalgOp);
2613  results.push_back(splitResult->resultCombiningLinalgOp);
2615 }
2616 
2617 //===----------------------------------------------------------------------===//
2618 // TileReductionUsingForOp
2619 //===----------------------------------------------------------------------===//
2620 
2621 void transform::TileReductionUsingForOp::build(
2622  OpBuilder &builder, OperationState &result, Value target,
2623  ArrayRef<int64_t> staticTileSizes) {
2624  // Call the default builder.
2625  // This is future-proof re mixed static-dynamic and setting up the proper
2626  // operands segment sizes attributes for multiple variadic operands.
2627  // In the absence of this, horrible bugs ensue.
2628  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2629  MLIRContext *ctx = builder.getContext();
2630  auto opTy = transform::AnyOpType::get(ctx);
2631  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2632  build(builder, result,
2633  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2634  /*target=*/target,
2635  /*tile_sizes=*/staticTileSizesAttr);
2636 }
2637 
2638 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2639  transform::TransformRewriter &rewriter, Operation *target,
2641  transform::TransformState &state) {
2642  rewriter.setInsertionPoint(target);
2643 
2644  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2645  if (!partialReductionOp) {
2646  return emitSilenceableFailure(
2647  target->getLoc(),
2648  "Operation should implement PartialReductionOpInterface");
2649  }
2650  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2651  rewriter, partialReductionOp,
2652  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2653 
2654  if (failed(result))
2655  return emitDefaultSilenceableFailure(target);
2656  rewriter.replaceOp(target, result->mergeResult.replacements);
2657  for (Value initValue : result->initialValues)
2658  results.push_back(initValue.getDefiningOp());
2659  for (auto parallelTiledOp : result->tiledOps)
2660  results.push_back(parallelTiledOp);
2661  for (auto mergeOp : result->mergeResult.mergeOps)
2662  results.push_back(mergeOp);
2663  results.push_back(result->loops.front());
2665 }
2666 
2667 //===----------------------------------------------------------------------===//
2668 // TileReductionUsingForallOp
2669 //===----------------------------------------------------------------------===//
2670 
2671 void transform::TileReductionUsingForallOp::build(
2672  OpBuilder &builder, OperationState &result, Value target,
2673  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2674  ArrayAttr mapping) {
2675  // Call the default builder.
2676  // This is future-proof re mixed static-dynamic and setting up the proper
2677  // operands segment sizes attributes for multiple variadic operands.
2678  // In the absence of this, horrible bugs ensue.
2679  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2680  MLIRContext *ctx = builder.getContext();
2681  auto opTy = transform::AnyOpType::get(ctx);
2682  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2683  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2684  build(builder, result,
2685  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2686  /*target=*/target,
2687  /*num_threads=*/staticNumThreadsAttr,
2688  /*tile_sizes=*/staticTileSizesAttr,
2689  /*mapping=*/mapping);
2690 }
2691 
2692 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2693  transform::TransformRewriter &rewriter, LinalgOp target,
2695  transform::TransformState &state) {
2696  rewriter.setInsertionPoint(target);
2697  SmallVector<OpFoldResult> numThreads =
2698  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2699  SmallVector<OpFoldResult> tileSizes =
2700  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2701  FailureOr<linalg::ForallReductionTilingResult> result =
2703  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2704  numThreads, tileSizes, getMapping());
2705 
2706  if (failed(result)) {
2707  auto diag = emitSilenceableError() << "could not tile reduction";
2708  diag.attachNote(target.getLoc()) << "target operation";
2709  return diag;
2710  }
2711  for (Value initValue : result->initialValues)
2712  results.push_back(initValue.getDefiningOp());
2713  for (auto parallelTiledOp : result->parallelTiledOps)
2714  results.push_back(parallelTiledOp);
2715  for (auto mergeOp : result->mergeOps)
2716  results.push_back(mergeOp);
2717  results.push_back(result->loops);
2719 }
2720 
2721 //===----------------------------------------------------------------------===//
2722 // ContinuousTileSizesOp
2723 //===----------------------------------------------------------------------===//
2724 
2726 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2727  TransformResults &transformResults,
2728  TransformState &state) {
2729 
2730  SmallVector<Operation *> targetOps =
2731  llvm::to_vector(state.getPayloadOps(getTarget()));
2732 
2733  if (!llvm::hasSingleElement(targetOps)) {
2734  return mlir::emitSilenceableFailure(getLoc())
2735  << "requires exactly one target (got " << llvm::range_size(targetOps)
2736  << ")";
2737  }
2738 
2739  Operation *target = *targetOps.begin();
2740  auto linalgOp = dyn_cast<LinalgOp>(target);
2741  auto tileableOp = dyn_cast<TilingInterface>(target);
2742 
2743  if (!linalgOp)
2744  return emitDefiniteFailure() << "expected Linalg Op";
2745 
2746  OpBuilder builder(linalgOp.getContext());
2747 
2748  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
2749  if (linalgOp.hasDynamicShape()) {
2750  auto diag = emitSilenceableError()
2751  << "cannot compute parametric tile sizes for dynamically "
2752  "shaped payload op";
2753  diag.attachNote(linalgOp->getLoc()) << "payload op";
2754  return diag;
2755  }
2756 
2757  FailureOr<StaticContinuousTileSizeSpecification> spec =
2758  computeStaticContinuousTileSizes(linalgOp, getDimension(),
2759  getTargetSize());
2760  if (failed(spec)) {
2761  return emitSilenceableError()
2762  << "failed to compute multi-size tiling sizes";
2763  }
2764 
2765  SmallVector<int64_t> chunkSizes;
2766 
2767  for (auto &&[tileSize, tripCount] :
2768  llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2769  chunkSizes.push_back(tileSize * tripCount);
2770 
2771  auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2772  return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2773  return builder.getI64IntegerAttr(value);
2774  });
2775  };
2776  transformResults.setParams(cast<OpResult>(getTileSizes()),
2777  getI64AttrsFromI64(spec->tileSizes));
2778  transformResults.setParams(cast<OpResult>(getChunkSizes()),
2779  getI64AttrsFromI64(chunkSizes));
2780 
2782  }
2783 
2784  builder.setInsertionPoint(linalgOp);
2785 
2786  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2787  unsigned dimension = getDimension();
2788 
2789  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2790  builder, tileableOp, dimension, targetSize, true);
2791  if (failed(spec)) {
2792  return emitSilenceableError() << "could not generate tile size computation";
2793  }
2794 
2795  AffineExpr s0 = builder.getAffineSymbolExpr(0);
2796  AffineExpr s1 = builder.getAffineSymbolExpr(1);
2797  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2798  return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2799  ofrs);
2800  };
2801 
2802  SmallVector<Value> chunkSizes;
2803  Value splitPoint;
2804  for (auto &&[tileSize, tripCount] :
2805  llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2806  splitPoint = apply(s0 * s1, {tileSize, tripCount});
2807  chunkSizes.push_back(splitPoint);
2808  }
2809 
2810  auto getDefiningOps = [&](ArrayRef<Value> values) {
2811  return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2812  return value.getDefiningOp();
2813  });
2814  };
2815 
2816  transformResults.set(cast<OpResult>(getTileSizes()),
2817  getDefiningOps(spec->tileSizes));
2818  transformResults.set(cast<OpResult>(getChunkSizes()),
2819  getDefiningOps(chunkSizes));
2820 
2822 }
2823 
2824 LogicalResult transform::ContinuousTileSizesOp::verify() {
2825 
2826  if (getTileSizes().getType() != getChunkSizes().getType()) {
2827  return emitOpError() << "expects all results type to be the same";
2828  }
2829 
2830  return success();
2831 }
2832 
2833 void transform::ContinuousTileSizesOp::getEffects(
2835  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2836  onlyReadsPayload(effects);
2837  else
2838  modifiesPayload(effects);
2839  onlyReadsHandle(getTargetMutable(), effects);
2840  producesHandle(getOperation()->getOpResults(), effects);
2841 }
2842 
2843 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2844  Type targetType, Type tile_sizes,
2845  Type) {
2846  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
2847 }
2848 
2849 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2850  Type &targetType,
2851  Type &tileSizesType,
2852  Type &chunkSizesType) {
2853  FunctionType funcType;
2854  llvm::SMLoc typeLoc = parser.getCurrentLocation();
2855  if (failed(parser.parseType<FunctionType>(funcType)))
2856  return failure();
2857 
2858  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2859  parser.emitError(typeLoc) << "expects a trailing functional type with one "
2860  "argument and one result";
2861  }
2862  targetType = funcType.getInput(0);
2863  tileSizesType = chunkSizesType = funcType.getResult(0);
2864 
2865  return success();
2866 }
2867 
2868 //===----------------------------------------------------------------------===//
2869 // TileUsingForOp
2870 //===----------------------------------------------------------------------===//
2871 
2872 void transform::TileUsingForOp::build(
2873  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2874  Value target, ArrayRef<int64_t> staticTileSizes,
2875  ArrayRef<int64_t> interchange,
2876  std::optional<ArrayRef<bool>> scalableSizes) {
2877  return build(builder, result, loopTypes,
2878  /*target=*/target,
2879  /*mixedTileSizes=*/
2880  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2881  interchange, scalableSizes);
2882 }
2883 
2884 void transform::TileUsingForOp::build(
2885  OpBuilder &builder, OperationState &result, Value target,
2886  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2887  std::optional<ArrayRef<bool>> scalableSizes) {
2888  build(builder, result, target,
2889  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2890  interchange, scalableSizes);
2891 }
2892 
2893 void transform::TileUsingForOp::build(
2894  OpBuilder &builder, OperationState &result, Value target,
2895  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2896  std::optional<ArrayRef<bool>> scalableSizes) {
2897  // Loop types are automatically splat by the callee, so setting up one is
2898  // enough.
2899  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2900  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2901  scalableSizes);
2902 }
2903 
2904 void transform::TileUsingForOp::build(
2905  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2906  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2907  ArrayRef<int64_t> interchange,
2908  std::optional<ArrayRef<bool>> scalableSizes) {
2909  SmallVector<int64_t> staticTileSizes;
2910  SmallVector<Value> dynamicTileSizes;
2911  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2912  // Call the default builder which sets up the proper operands segment sizes
2913  // attributes for multiple variadic operands. In the absence of this,
2914  // horrible bugs ensue.
2915  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2916  unsigned numExpectedLoops =
2917  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2918  SmallVector<Type> resultTypes;
2919  resultTypes.reserve(numExpectedLoops);
2920  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2921  "expected one loop type or as many as loops");
2922  if (loopTypes.size() == 1)
2923  resultTypes.append(numExpectedLoops, loopTypes[0]);
2924  else
2925  llvm::append_range(resultTypes, loopTypes);
2926  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2927  if (scalableSizes.has_value())
2928  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2929  build(builder, result, /*tiled_linalg_op=*/target.getType(),
2930  /*loops=*/resultTypes,
2931  /*target=*/target,
2932  /*dynamic_sizes=*/dynamicTileSizes,
2933  /*static_sizes=*/staticTileSizesAttr,
2934  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2935  /*scalable_sizes=*/expandedScalableSizes);
2936 }
2937 
2938 LogicalResult transform::TileUsingForOp::verify() {
2939  if (getMixedSizes().size() != getScalableSizes().size())
2940  return emitOpError("expected same number of sizes (")
2941  << getMixedSizes().size() << ") and scalable sizes ("
2942  << getScalableSizes().size() << ")";
2943  ArrayRef<int64_t> staticSizes = getStaticSizes();
2944  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
2945  if (getLoops().size() != numExpectedLoops)
2946  return emitOpError("expected number of loops to tile (")
2947  << numExpectedLoops << ") to match number of `loops` results ("
2948  << getLoops().size() << ")";
2949  return success();
2950 }
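// For example (sizes illustrative): static sizes [4, 0, 8] tile two loops,
// since a zero size means "do not tile", so the op must declare exactly two
// `loops` results and scalable_sizes must carry three entries, one per size.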
2951 
2953 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2954  TransformResults &transformResults,
2955  TransformState &state) {
2956  ArrayRef<int64_t> tileSizes = getStaticSizes();
2957 
2958  SmallVector<Operation *> targets =
2959  llvm::to_vector(state.getPayloadOps(getTarget()));
2960  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2961  SmallVector<SmallVector<int64_t>> paramSizes;
2962  dynamicSizeProducers.reserve(getDynamicSizes().size());
2963  paramSizes.reserve(getDynamicSizes().size());
2964  for (Value transformValue : getDynamicSizes()) {
2965  if (isa<ParamType>(transformValue.getType())) {
2966  dynamicSizeProducers.push_back({});
2967  ArrayRef<Attribute> params = state.getParams(transformValue);
2968  paramSizes.push_back(
2969  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
2970  return cast<IntegerAttr>(attr).getValue().getSExtValue();
2971  })));
2972 
2973  if (paramSizes.back().size() != targets.size()) {
2975  emitSilenceableError()
2976  << "expected as many parameter values ("
2977  << dynamicSizeProducers.back().size() << ") as target ops ("
2978  << targets.size() << ")";
2979  diag.attachNote(transformValue.getLoc()) << "for this parameter";
2980  return diag;
2981  }
2982 
2983  continue;
2984  }
2985  paramSizes.push_back({});
2986  dynamicSizeProducers.push_back(
2987  llvm::to_vector(state.getPayloadOps(transformValue)));
2988 
2989  if (dynamicSizeProducers.back().size() != targets.size()) {
2991  emitSilenceableError()
2992  << "expected as many dynamic size-producing operations ("
2993  << dynamicSizeProducers.back().size() << ") as target ops ("
2994  << targets.size() << ")";
2995  diag.attachNote(transformValue.getLoc()) << "for this handle";
2996  return diag;
2997  }
2998 
2999  for (Operation *op : dynamicSizeProducers.back()) {
3000  if (op->getNumResults() == 1 &&
3001  isa<IndexType>(op->getResult(0).getType())) {
3002  continue;
3003  }
3004 
3006  emitSilenceableError() << "expected sizes to be produced by ops "
3007  "with a single index-type result";
3008  diag.attachNote(op->getLoc()) << "size producer op";
3009  diag.attachNote(transformValue.getLoc()) << "for this handle";
3010  return diag;
3011  }
3012  }
3013 
3014  SmallVector<Operation *> tiled;
3015  SmallVector<SmallVector<Operation *>> loops;
3016  loops.resize(getLoops().size());
3017  auto scalableSizes = getScalableSizes();
3018  for (auto [i, op] : llvm::enumerate(targets)) {
3019  auto tilingInterface = dyn_cast<TilingInterface>(op);
3020  if (!tilingInterface) {
3021  DiagnosedSilenceableFailure diag =
3022  emitSilenceableError()
3023  << "only ops implementing TilingInterface are supported";
3024  diag.attachNote(op->getLoc()) << "target op";
3025  return diag;
3026  }
3027  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3028  DiagnosedSilenceableFailure diag =
3029  emitSilenceableError()
3030  << "too many tiles provided, expected at most "
3031  << tilingInterface.getLoopIteratorTypes().size() << " found "
3032  << tileSizes.size();
3033  diag.attachNote(op->getLoc()) << "target op";
3034  return diag;
3035  }
3036 
3037  scf::SCFTilingOptions tilingOptions;
3038  if (tileSizes.empty()) {
3039  tilingOptions.setTileSizeComputationFunction(
3040  [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3041  return {};
3042  });
3043  } else {
3044  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3045  Operation *) {
3046  SmallVector<OpFoldResult> sizes;
3047  sizes.reserve(tileSizes.size());
3048  unsigned dynamicIdx = 0;
3049 
3050  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3051  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3052  if (scalableSizes[ofrIdx]) {
3053  auto val = b.create<arith::ConstantIndexOp>(
3054  getLoc(), cast<IntegerAttr>(attr).getInt());
3055  Value vscale =
3056  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
3057  sizes.push_back(
3058  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
3059  } else {
3060  sizes.push_back(attr);
3061  }
3062  continue;
3063  }
3064  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3065  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3066  ++dynamicIdx;
3067  assert((dynamicSizes.empty() ^ params.empty()) &&
3068  "expected either dynamic sizes or parameters");
3069  if (!params.empty()) {
3070  sizes.push_back(b.getIndexAttr(params[index]));
3071  } else {
3072  sizes.push_back(dynamicSizes[index]->getResult(0));
3073  }
3074  }
3075  return sizes;
3076  });
3077  }
3078 
3079  tilingOptions.setInterchange(getInterchange());
3080  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3081  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3082  if (failed(maybeTilingResult))
3083  return DiagnosedSilenceableFailure::definiteFailure();
3084 
3085  rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
3086 
3087  tiled.append(maybeTilingResult->tiledOps);
3088  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3089  loops[en2.index()].push_back(en2.value());
3090  }
3091 
3092  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3093  for (const auto &en : llvm::enumerate(loops))
3094  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3095 
3096  return DiagnosedSilenceableFailure::success();
3097 }
3098 
3099 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3100  ValueRange dynamic = getDynamicSizes();
3101  ArrayRef<int64_t> tileSizes = getStaticSizes();
3102  SmallVector<OpFoldResult> results;
3103  results.reserve(tileSizes.size());
3104  unsigned dynamicPos = 0;
3105  Builder builder(getContext());
3106  for (int64_t size : tileSizes) {
3107  if (size == ShapedType::kDynamic) {
3108  results.push_back(dynamic[dynamicPos++]);
3109  } else {
3110  results.push_back(builder.getIndexAttr(size));
3111  }
3112  }
3113  return results;
3114 }
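 // Illustrative note (not from the original source): with static sizes
 // [4, kDynamic, 8] and a single dynamic size operand %d, getMixedSizes()
 // returns the OpFoldResult list {4, %d, 8}: index attributes for the static
 // entries and the SSA value for the dynamic one.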
3115 
3116 void transform::TileUsingForOp::getEffects(
3117  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3118  consumesHandle(getTargetMutable(), effects);
3119  onlyReadsHandle(getDynamicSizesMutable(), effects);
3120  producesHandle(getOperation()->getOpResults(), effects);
3121  modifiesPayload(effects);
3122 }
3123 
3124 //===----------------------------------------------------------------------===//
3125 // TileUsingForallOp
3126 //===----------------------------------------------------------------------===//
3127 
3128 void transform::TileUsingForallOp::build(OpBuilder &builder,
3129  OperationState &result, Value target,
3130  ArrayRef<int64_t> staticTileSizes,
3131  transform::TileSizesSpec,
3132  ArrayAttr mapping) {
3133  return build(builder, result,
3134  /*target=*/target,
3135  /*mixedTileSizes=*/
3136  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3137  /*_=*/TileSizesSpec(),
3138  /*mapping=*/mapping);
3139 }
3140 
3141 void transform::TileUsingForallOp::build(OpBuilder &builder,
3142  OperationState &result, Value target,
3143  ArrayRef<OpFoldResult> mixedTileSizes,
3144  transform::TileSizesSpec,
3145  ArrayAttr mapping) {
3146  SmallVector<int64_t> staticTileSizes;
3147  SmallVector<Value> dynamicTileSizes;
3148  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3149  // Call the default builder, which sets up the proper operand segment size
3150  // attributes for multiple variadic operands. In the absence of this,
3151  // horrible bugs ensue.
3152  MLIRContext *ctx = builder.getContext();
3153  auto operationType = transform::AnyOpType::get(ctx);
3154  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3155  build(builder, result,
3156  /*resultTypes=*/TypeRange{operationType, operationType},
3157  /*target=*/target,
3158  /*num_threads=*/ValueRange{},
3159  /*tile_sizes=*/dynamicTileSizes,
3160  /*packed_num_threads=*/Value(),
3161  /*packed_tile_sizes=*/Value(),
3162  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3163  /*static_tile_sizes=*/staticTileSizesAttr,
3164  /*mapping=*/mapping);
3165 }
3166 
3167 void transform::TileUsingForallOp::build(OpBuilder &builder,
3168  OperationState &result, Value target,
3169  ArrayRef<int64_t> staticNumThreads,
3170  transform::NumThreadsSpec,
3171  ArrayAttr mapping) {
3172  return build(builder, result, target,
3173  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3174  NumThreadsSpec(), mapping);
3175 }
3176 
3177 void transform::TileUsingForallOp::build(OpBuilder &builder,
3178  OperationState &result, Value target,
3179  ArrayRef<OpFoldResult> mixedNumThreads,
3180  transform::NumThreadsSpec,
3181  ArrayAttr mapping) {
3182  SmallVector<int64_t> staticNumThreads;
3183  SmallVector<Value> dynamicNumThreads;
3184  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3185  staticNumThreads);
3186  // Call the default builder, which sets up the proper operand segment size
3187  // attributes for multiple variadic operands. In the absence of this,
3188  // horrible bugs ensue.
3189  MLIRContext *ctx = builder.getContext();
3190  auto operationType = transform::AnyOpType::get(ctx);
3191  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3192  build(builder, result,
3193  /*resultTypes=*/TypeRange{operationType, operationType},
3194  /*target=*/target,
3195  /*num_threads=*/dynamicNumThreads,
3196  /*tile_sizes=*/ValueRange{},
3197  /*packed_num_threads=*/Value(),
3198  /*packed_tile_sizes=*/Value(),
3199  /*static_num_threads=*/staticNumThreadsAttr,
3200  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3201  /*mapping=*/mapping);
3202 }
3203 
3204 /// Given `lbs`, `ubs` and `steps` of loops, return (for each loop) the
3205 /// normalized upper bound.
3206 static SmallVector<OpFoldResult>
3207 normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3208  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3209  ArrayRef<OpFoldResult> steps) {
3210  AffineExpr s0, s1, s2;
3211  bindSymbols(rewriter.getContext(), s0, s1, s2);
3212  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3213  SmallVector<OpFoldResult> normalizedUbs;
3214  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3215  OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
3216  rewriter, loc, normalizedUbExpr, {lb, ub, step});
3217  normalizedUbs.push_back(normalizedUb);
3218  }
3219  return normalizedUbs;
3220 }
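 // Worked example (illustrative, not from the original source): for lb = 4,
 // ub = 15, step = 3 the expression (s1 - s0).ceilDiv(s2) evaluates to
 // ceil((15 - 4) / 3) = 4, so the normalized loop iterates over [0, 4) with
 // step 1 and performs the same four iterations as the original loop.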
3221 
3222 /// When a loop is normalized, the uses of the induction variable within the
3223 /// loop need to be replaced with `original_lb + old_iv * original_step`.
3224 static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3225  Location loc, ValueRange ivs,
3226  ArrayRef<OpFoldResult> lbs,
3227  ArrayRef<OpFoldResult> steps) {
3228  AffineExpr s0, s1;
3229  AffineExpr d0;
3230  bindSymbols(rewriter.getContext(), s0, s1);
3231  bindDims(rewriter.getContext(), d0);
3232  AffineExpr denormExpr = s0 + d0 * s1;
3233  SmallVector<Value> denormalizedIvs;
3234 
3235  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3236  OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
3237  rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3238  denormalizedIvs.push_back(
3239  getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3240  }
3241  return denormalizedIvs;
3242 }
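 // Worked example (illustrative, not from the original source): with original
 // lb = 4 and step = 3, the normalized induction variable values 0, 1, 2, 3
 // are mapped back through lb + iv * step to 4, 7, 10, 13, i.e. the original
 // induction variable values.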
3243 
3244 /// Given an `scf.forall` loop, return a loop op with the loop bounds
3245 /// normalized.
3246 /// TODO: Replace this with a general utility to normalize `scf.forall`.
3247 /// At the time of writing, this wasn't done since adding it to the `scf`
3248 /// dialect would disallow the use of `affine.apply` operations due
3249 /// to cyclic dependencies. To avoid churn in lit tests
3250 /// with the change this was added with, that is deferred to a follow-up.
3251 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3252  scf::ForallOp loop) {
3253  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3254  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3255  SmallVector<OpFoldResult> steps = loop.getMixedStep();
3256 
3257  if (llvm::all_of(
3258  lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
3259  llvm::all_of(
3260  steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
3261  return loop;
3262  }
3263 
3264  Location loc = loop.getLoc();
3265  SmallVector<OpFoldResult> normalizedUbs =
3266  normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3267  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3268  rewriter.getIndexAttr(0));
3269  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3270  rewriter.getIndexAttr(1));
3271 
3272  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
3273  loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3274  loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
3275 
3276  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3277  OpBuilder::InsertionGuard g(rewriter);
3278  Block *normalizedLoopBlock = normalizedForallOp.getBody();
3279  rewriter.setInsertionPointToStart(normalizedLoopBlock);
3280 
3281  SmallVector<Value> argValues =
3282  denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3283  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3284  normalizedForallOp.getRegionIterArgs().end());
3285  Block *origLoopBlock = loop.getBody();
3286  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3287 
3288  rewriter.replaceOp(loop, normalizedForallOp);
3289  return normalizedForallOp;
3290 }
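 // Rough sketch of the rewrite (illustrative, not from the original source):
 //   scf.forall (%i) = (%lb) to (%ub) step (%s) shared_outs(%o = %init) {...}
 // becomes
 //   scf.forall (%j) in (%normalized_ub) shared_outs(%o = %init) {...}
 // where %normalized_ub = ceildiv(%ub - %lb, %s) and every use of %i in the
 // body is replaced by %lb + %j * %s.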
3291 
3292 static DiagnosedSilenceableFailure tileToForallOpImpl(
3293  RewriterBase &rewriter, transform::TransformState &state,
3294  TransformOpInterface transformOp, Operation *target,
3295  ArrayRef<OpFoldResult> mixedNumThreads,
3296  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3297  scf::SCFTilingResult &tilingResult) {
3298  // Transform all targets one by one.
3299  auto tileableOp = dyn_cast<TilingInterface>(target);
3300  if (!tileableOp) {
3301  DiagnosedSilenceableFailure diag =
3302  transformOp.emitSilenceableError()
3303  << "only TilingInterface ops are supported";
3304  diag.attachNote(target->getLoc()) << "target op";
3305  return diag;
3306  }
3307  rewriter.setInsertionPoint(tileableOp);
3308  scf::SCFTilingOptions options;
3309  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3310  if (!mixedNumThreads.empty()) {
3311  options.setNumThreads(mixedNumThreads);
3312  } else {
3313  options.setTileSizes(mixedTileSizes);
3314  }
3315  if (mapping) {
3316  options.setMapping(mapping.value().getValue());
3317  }
3318  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3319  scf::tileUsingSCF(rewriter, tileableOp, options);
3320 
3321  if (failed(maybeTilingResult))
3322  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3323 
3324  rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3325 
3326  tilingResult = *maybeTilingResult;
3327 
3328  if (mixedNumThreads.empty()) {
3329  auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3330  OpBuilder::InsertionGuard g(rewriter);
3331  rewriter.setInsertionPoint(generatedForallOp);
3332  scf::ForallOp normalizedForallOp =
3333  normalizeForallLoopOp(rewriter, generatedForallOp);
3334  tilingResult.loops.front() = normalizedForallOp;
3335  }
3336 
3337  return DiagnosedSilenceableFailure::success();
3338 }
3339 
3340 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3341  transform::TransformRewriter &rewriter,
3342  transform::TransformResults &transformResults,
3343  transform::TransformState &state) {
3344  auto transformOp = cast<TransformOpInterface>(getOperation());
3345 
3346  // Result payload ops.
3347  SmallVector<Operation *> tileOps;
3348  SmallVector<Operation *> tiledOps;
3349 
3350  // Unpack handles.
3351  SmallVector<OpFoldResult> mixedNumThreads;
3352  DiagnosedSilenceableFailure status =
3353  getPackedNumThreads()
3354  ? unpackSingleIndexResultPayloadOperations(
3355  state, transformOp, mixedNumThreads, getPackedNumThreads())
3356  : unpackSingleIndexResultPayloadOperations(
3357  state, transformOp, mixedNumThreads, getMixedNumThreads());
3358  if (!status.succeeded())
3359  return status;
3360  SmallVector<OpFoldResult> mixedTileSizes;
3361  status = getPackedTileSizes()
3362  ? unpackSingleIndexResultPayloadOperations(
3363  state, transformOp, mixedTileSizes, getPackedTileSizes())
3364  : unpackSingleIndexResultPayloadOperations(
3365  state, transformOp, mixedTileSizes, getMixedTileSizes());
3366  if (!status.succeeded())
3367  return status;
3368 
3369  for (Operation *target : state.getPayloadOps(getTarget())) {
3370  scf::SCFTilingResult tilingResult;
3371  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3372  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3373  getMapping(), tilingResult);
3374  if (!diag.succeeded())
3375  return diag;
3376  tileOps.push_back(tilingResult.loops.front());
3377  tiledOps.append(tilingResult.tiledOps);
3378  }
3379 
3380  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3381  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3382 
3383  return DiagnosedSilenceableFailure::success();
3384 }
3385 
3386 void transform::TileUsingForallOp::getEffects(
3387  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3388  consumesHandle(getTargetMutable(), effects);
3389  onlyReadsHandle(getTileSizesMutable(), effects);
3390  onlyReadsHandle(getNumThreadsMutable(), effects);
3391  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3392  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3393  producesHandle(getOperation()->getOpResults(), effects);
3394  modifiesPayload(effects);
3395 }
3396 
3397 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3398  Builder b(getContext());
3399  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3400 }
3401 
3402 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3403  Builder b(getContext());
3404  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3405 }
3406 
3407 LogicalResult TileUsingForallOp::verify() {
3408  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3409  static_cast<int>(getPackedNumThreads() != Value());
3410  if (numThreadsSpec > 1)
3411  return emitOpError(
3412  "num_threads and packed_num_threads are mutually exclusive");
3413  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3414  static_cast<int>(getPackedTileSizes() != Value());
3415  if (tileSizesSpec > 1)
3416  return emitOpError(
3417  "tile_sizes and packed_tile_sizes are mutually exclusive");
3418  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3419  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3420  "must be specified");
3421  return success();
3422 }
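 // Illustrative usage (not from the original source; assumes an OpBuilder `b`,
 // an OperationState `opState`, and a handle `target`): exactly one of the
 // tile-size or num-threads specifications is provided, which is the invariant
 // the verifier above enforces.
 //   transform::TileUsingForallOp::build(b, opState, target,
 //                                       /*staticTileSizes=*/{8, 32},
 //                                       TileSizesSpec(), /*mapping=*/ArrayAttr());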
3423 
3424 //===----------------------------------------------------------------------===//
3425 // VectorizeChildrenAndApplyPatternsOp
3426 //===----------------------------------------------------------------------===//
3427 
3428 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3429  OpBuilder &builder, OperationState &result, Value target,
3430  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3431  result.addOperands(target);
3432  if (vectorizePadding) {
3433  result.addAttribute(
3434  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3435  result.name),
3436  builder.getUnitAttr());
3437  }
3438  if (vectorizeExtract) {
3439  result.addAttribute(
3440  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3441  result.name),
3442  builder.getUnitAttr());
3443  }
3444  if (flatten1DDepthwiseConv) {
3445  result.addAttribute(
3446  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3447  result.name),
3448  builder.getUnitAttr());
3449  }
3450  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3451 }
3452 
3453 namespace {
3454 /// This is a helper used only to call vectorize via a pattern inside
3455 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3456 struct VectorizationPattern : public RewritePattern {
3457  explicit VectorizationPattern(MLIRContext *context,
3458  bool vectorizeExtract = false,
3459  bool flattenConv = false)
3460  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3461  vectorizeNDExtract(vectorizeExtract),
3462  flatten1DDepthwiseConv(flattenConv) {}
3463  LogicalResult matchAndRewrite(Operation *op,
3464  PatternRewriter &rewriter) const override {
3465  if (!linalg::hasVectorizationImpl(op))
3466  return rewriter.notifyMatchFailure(op,
3467  "Unsupported Op, cannot vectorize");
3468  return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3469  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3470  flatten1DDepthwiseConv);
3471  }
3472 
3473 private:
3474  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3475  /// rank >= 2.
3476  bool vectorizeNDExtract = false;
3477  /// Controls whether to "flatten" the channel dimension when vectorising 1D
3478  /// depthwise convolutions. This should lead to better vectorization for
3479  /// tensors with a low number of channel dimensions.
3480  bool flatten1DDepthwiseConv = false;
3481 };
3482 } // namespace
3483 
3484 DiagnosedSilenceableFailure
3485 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3486  transform::TransformRewriter &rewriter, Operation *target,
3487  transform::ApplyToEachResultList &results,
3488  transform::TransformState &state) {
3489  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3490  auto diag = this->emitOpError("requires isolated-from-above targets");
3491  diag.attachNote(target->getLoc()) << "non-isolated target";
3492  return DiagnosedSilenceableFailure::definiteFailure();
3493  }
3494 
3495  MLIRContext *ctx = getContext();
3496  RewritePatternSet patterns(ctx);
3497  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3498  getFlatten_1dDepthwiseConv());
3499 
3500  if (!getDisableTransferPermutationMapLoweringPatterns())
3501  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3502 
3503  if (!getDisableMultiReductionToContractPatterns())
3504  vector::populateVectorReductionToContractPatterns(patterns);
3505 
3507 
3508  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3509  linalg::LinalgCopyVTWForwardingPattern>(ctx,
3510  /*benefit=*/2);
3511  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3512  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3514 
3516 
3517  if (getVectorizePadding()) {
3518  linalg::populatePadOpVectorizationPatterns(patterns);
3519  // This creates an alternative path for lowering tensor.pad - by
3520  // decomposing it into e.g. linalg.fill.
3521  linalg::populateDecomposePadPatterns(patterns);
3522  }
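 // Rough sketch of that decomposition (illustrative, not from the original
 // source): a statically shaped
 //   %p = tensor.pad %t low[1, 0] high[0, 1] { tensor.yield %cst }
 // becomes approximately
 //   %e = tensor.empty() : tensor<...>
 //   %f = linalg.fill ins(%cst : f32) outs(%e : tensor<...>) -> tensor<...>
 //   %p = tensor.insert_slice %t into %f[1, 0] [...] [1, 1]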
3524 
3525  TrackingListener listener(state, *this);
3526  GreedyRewriteConfig config;
3527  config.listener = &listener;
3528  if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
3529  return emitDefaultDefiniteFailure(target);
3530 
3531  results.push_back(target);
3532  return DiagnosedSilenceableFailure::success();
3533 }
3534 
3535 //===----------------------------------------------------------------------===//
3536 // VectorizeOp
3537 //===----------------------------------------------------------------------===//
3538 
3539 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3540  transform::TransformRewriter &rewriter,
3541  mlir::transform::TransformResults &transformResults,
3542  mlir::transform::TransformState &state) {
3543  auto targets = state.getPayloadOps(getTarget());
3544  if (std::empty(targets))
3545  return DiagnosedSilenceableFailure::success();
3546  auto transformOp = cast<TransformOpInterface>(getOperation());
3547  SmallVector<int64_t> vectorSizes;
3548  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3549  state, transformOp, getMixedVectorSizes(), vectorSizes);
3550  if (!status.succeeded())
3551  return status;
3552 
3553  // TODO: Check that the correct number of vectorSizes was provided.
3554  for (Operation *target : targets) {
3555  if (!linalg::hasVectorizationImpl(target)) {
3556  return mlir::emitSilenceableFailure(target->getLoc())
3557  << "Unsupported Op, cannot vectorize";
3558  }
3559 
3560  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3561  getScalableSizes(),
3562  getVectorizeNdExtract().value_or(false)))) {
3563  return mlir::emitSilenceableFailure(target->getLoc())
3564  << "Attempted to vectorize, but failed";
3565  }
3566  }
3567 
3568  return DiagnosedSilenceableFailure::success();
3569 }
3570 
3571 void transform::VectorizeOp::getEffects(
3572  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3573  consumesHandle(getTargetMutable(), effects);
3574  onlyReadsHandle(getVectorSizesMutable(), effects);
3575  modifiesPayload(effects);
3576 }
3577 
3578 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3579  OpBuilder b(getContext());
3580  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3581 }
3582 
3583 LogicalResult transform::VectorizeOp::verify() {
3584  if (getStaticVectorSizes().size() != getScalableSizes().size())
3585  return emitOpError("expected same number of vector sizes (")
3586  << getStaticVectorSizes().size() << ") and scalable sizes ("
3587  << getScalableSizes().size() << ")";
3588  return success();
3589 }
3590 
3591 //===----------------------------------------------------------------------===//
3592 // HoistRedundantVectorTransfersOp
3593 //===----------------------------------------------------------------------===//
3594 
3595 DiagnosedSilenceableFailure
3596 transform::HoistRedundantVectorTransfersOp::applyToOne(
3597  transform::TransformRewriter &rewriter, func::FuncOp target,
3598  transform::ApplyToEachResultList &results,
3599  transform::TransformState &state) {
3600  // WARNING: This hoisting does not model parallelism and is generally
3601  // incorrect when used on distributed loops with memref semantics!
3602  // TODO: obsolete and should be retired.
3603  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3604  results.push_back(target);
3605  return DiagnosedSilenceableFailure::success();
3606 }
3607 
3608 //===----------------------------------------------------------------------===//
3609 // HoistRedundantVectorBroadcastsOp
3610 //===----------------------------------------------------------------------===//
3611 
3612 DiagnosedSilenceableFailure
3613 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3614  transform::TransformRewriter &rewriter, mlir::Operation *target,
3615  transform::ApplyToEachResultList &results,
3616  transform::TransformState &state) {
3617  rewriter.setInsertionPoint(target);
3618  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3619  results.push_back(target);
3620  return DiagnosedSilenceableFailure::success();
3621 }
3622 
3623 //===----------------------------------------------------------------------===//
3624 // ConvertConv2DToImg2ColOp.
3625 //===----------------------------------------------------------------------===//
3626 
3627 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3628  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3629  transform::ApplyToEachResultList &results,
3630  transform::TransformState &state) {
3631  rewriter.setInsertionPoint(target);
3632  auto maybeTransformed =
3633  TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3634  target)
3635  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3636  return rewriteInIm2Col(rewriter, op);
3637  })
3638  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3639  return rewriteInIm2Col(rewriter, op);
3640  })
3641  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3642  return rewriteInIm2Col(rewriter, op);
3643  })
3644  .Case([&](linalg::Conv2DNchwFchwOp op) {
3645  return rewriteInIm2Col(rewriter, op);
3646  })
3647  .Default([&](Operation *op) {
3648  return rewriter.notifyMatchFailure(op, "not supported");
3649  });
3650  if (failed(maybeTransformed))
3651  return emitDefaultSilenceableFailure(target);
3652  // Handle to the operation producing the img2col tensor.
3653  results.push_back(maybeTransformed->first);
3654  // Handle to the operation that replaces the original convolution.
3655  results.push_back(maybeTransformed->second);
3656  return DiagnosedSilenceableFailure::success();
3657 }
3658 
3659 //===----------------------------------------------------------------------===//
3660 // FlattenElementwiseLinalgOp.
3661 //===----------------------------------------------------------------------===//
3662 
3663 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3664  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3665  transform::ApplyToEachResultList &results,
3666  transform::TransformState &state) {
3667  rewriter.setInsertionPoint(target);
3668  if (!isElementwise(target))
3669  return mlir::emitSilenceableFailure(target->getLoc())
3670  << "only elementwise flattening is supported";
3671 
3672  // If rank <= 1, do nothing
3673  if (target.getNumLoops() <= 1) {
3674  results.push_back(target);
3675  return DiagnosedSilenceableFailure::success();
3676  }
3677 
3678  // Attempt to flatten all dims to one.
3679  ReassociationIndices reassociation(target.getNumLoops());
3680  std::iota(reassociation.begin(), reassociation.end(), 0);
3681  auto maybeFlattened =
3682  collapseOpIterationDims(target, reassociation, rewriter);
3683  if (failed(maybeFlattened))
3684  return mlir::emitSilenceableFailure(target->getLoc())
3685  << "attempted to flatten, but failed";
3686  results.push_back(maybeFlattened->collapsedOp);
3687  rewriter.replaceOp(target, maybeFlattened->results);
3688  return DiagnosedSilenceableFailure::success();
3689 }
3690 
3691 //===----------------------------------------------------------------------===//
3692 // TransposeConv2DOp
3693 //===----------------------------------------------------------------------===//
3694 
3695 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3696  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3697  transform::ApplyToEachResultList &results,
3698  transform::TransformState &state) {
3699  rewriter.setInsertionPoint(target);
3700  auto maybeTransformed =
3701  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3702  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3703  return transposeConv2D(rewriter, op);
3704  })
3705  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3706  return transposeConv2D(rewriter, op);
3707  })
3708  .Default([&](Operation *op) {
3709  return rewriter.notifyMatchFailure(op, "not supported");
3710  });
3711  if (failed(maybeTransformed))
3712  return emitDefaultSilenceableFailure(target);
3713  // Handle to the new Conv2D operation with transposed filters
3714  results.push_back(*maybeTransformed);
3715  return DiagnosedSilenceableFailure::success();
3716 }
3717 
3718 //===----------------------------------------------------------------------===//
3719 // TransposeMatmulOp
3720 //===----------------------------------------------------------------------===//
3721 
3722 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3723  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3724  transform::ApplyToEachResultList &results,
3725  transform::TransformState &state) {
3726  rewriter.setInsertionPoint(target);
3727  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3728  auto maybeTransformed =
3729  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3730  .Case([&](linalg::MatmulOp op) {
3731  return transposeMatmul(rewriter, op, transposeLHS);
3732  })
3733  .Case([&](linalg::BatchMatmulOp op) {
3734  return transposeBatchMatmul(rewriter, op, transposeLHS);
3735  })
3736  .Default([&](Operation *op) { return failure(); });
3737  if (failed(maybeTransformed))
3738  return emitSilenceableFailure(target->getLoc()) << "not supported";
3739  // Handle to the new Matmul operation with transposed filters
3740  results.push_back(*maybeTransformed);
3741  return DiagnosedSilenceableFailure::success();
3742 }
3743 
3744 //===----------------------------------------------------------------------===//
3745 // InsertSliceToCopyOp
3746 //===----------------------------------------------------------------------===//
3747 template <typename OpTy>
3748 static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3749  transform::ApplyToEachResultList &results,
3750  transform::TransformState &state) {
3751  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3752  tensor::ParallelInsertSliceOp>() &&
3753  "wrong op type");
3754 
3755  if (auto copySource =
3756  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3757  results.push_back(copySource);
3758  return DiagnosedSilenceableFailure::success();
3759  }
3760 
3761  // If we are inside an InParallel region, temporarily set the insertion point
3762  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3763  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3764  rewriter.setInsertionPoint(
3765  target->template getParentOfType<scf::InParallelOp>());
3766  }
3767 
3768  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3769  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3770  target.getMixedSizes(), target.getMixedStrides());
3771  Value copied = rewriter
3772  .create<linalg::CopyOp>(target.getLoc(),
3773  target.getSource(), extracted)
3774  .getResult(0);
3775  // Reset the insertion point.
3776  rewriter.setInsertionPoint(target);
3777  rewriter.replaceOpWithNewOp<OpTy>(
3778  target, copied, target.getDest(), target.getMixedOffsets(),
3779  target.getMixedSizes(), target.getMixedStrides());
3780 
3781  results.push_back(copied.getDefiningOp());
3782  return DiagnosedSilenceableFailure::success();
3783 }
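 // Rough sketch of the rewrite (illustrative, not from the original source):
 //   %r = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
 //       : tensor<4x4xf32> into tensor<8x8xf32>
 // becomes
 //   %e = tensor.extract_slice %dest[0, 0] [4, 4] [1, 1]
 //       : tensor<8x8xf32> to tensor<4x4xf32>
 //   %c = linalg.copy ins(%src : tensor<4x4xf32>)
 //                    outs(%e : tensor<4x4xf32>) -> tensor<4x4xf32>
 //   %r = tensor.insert_slice %c into %dest[0, 0] [4, 4] [1, 1]
 //       : tensor<4x4xf32> into tensor<8x8xf32>
 // which exposes the data movement as a tileable linalg.copy op.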
3784 
3785 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3786  transform::TransformRewriter &rewriter, Operation *targetOp,
3787  transform::ApplyToEachResultList &results,
3788  transform::TransformState &state) {
3789 
3790  rewriter.setInsertionPoint(targetOp);
3791  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3792  return doit(rewriter, target, results, state);
3793  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3794  return doit(rewriter, target, results, state);
3795 
3796  DiagnosedSilenceableFailure diag =
3797  emitSilenceableError()
3798  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3799  diag.attachNote(targetOp->getLoc()) << "target op";
3800  return diag;
3801 }
3802 
3803 //===----------------------------------------------------------------------===//
3804 // MapCopyToThreadsOp
3805 //===----------------------------------------------------------------------===//
3806 
3807 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3808  transform::TransformRewriter &rewriter, Operation *target,
3809  transform::ApplyToEachResultList &results,
3810  transform::TransformState &state) {
3811  // Check if the op is supported.
3812  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3813  DiagnosedSilenceableFailure diag =
3814  emitSilenceableError()
3815  << "only linalg.copy and tensor.pad target ops are supported";
3816  diag.attachNote(target->getLoc()) << "target op";
3817  return diag;
3818  }
3819  assert(target->getNumResults() == 1 && "expected single result");
3820  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3821  if (!resultShapedType.hasStaticShape()) {
3822  DiagnosedSilenceableFailure diag =
3823  emitSilenceableError()
3824  << "only statically sized ops of rank <= 3 are supported";
3825  diag.attachNote(target->getLoc()) << "target op";
3826  return diag;
3827  }
3828 
3829  // Conservatively set the minimum viable desired bitwidth alignment.
3830  int64_t desiredBitAlignment = getDesiredBitAlignment();
3831  int64_t eltBitwidth =
3832  resultShapedType.getElementType().getIntOrFloatBitWidth();
3833  if (desiredBitAlignment % eltBitwidth != 0) {
3834  desiredBitAlignment = eltBitwidth;
3835  }
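 // Illustrative arithmetic (not from the original source): with a requested
 // alignment of 128 bits and f32 elements, 128 % 32 == 0 and the 128-bit
 // alignment is kept; if the element bitwidth does not divide the request,
 // the alignment falls back to the element bitwidth itself.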
3836 
3837  gpu::CopyMappingInfo mapping(
3838  /*ctx=*/getContext(),
3839  /*totalNumThreads=*/getTotalNumThreads(),
3840  /*alignment=*/desiredBitAlignment,
3841  /*sizes=*/resultShapedType.getShape(),
3842  /*favorPredication=*/false,
3843  /*elementalBitwidth=*/
3844  resultShapedType.getElementType().getIntOrFloatBitWidth());
3845  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3846  DiagnosedSilenceableFailure diag =
3847  emitSilenceableError()
3848  << "too few threads to map copy op to threads on the most minor "
3849  "dimension, given alignment and vector size constraints, try "
3850  "smaller tile size of mapping to more threads";
3851  diag.attachNote(target->getLoc()) << "target op";
3852  return diag;
3853  }
3854 
3855  // OpBuilder only used to compute attributes.
3856  OpBuilder b(getContext());
3857  scf::SCFTilingResult tilingResult;
3858  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3859  /*rewriter=*/rewriter,
3860  /*state=*/state,
3861  /*transformOp=*/*this,
3862  /*target=*/target,
3863  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3864  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3865  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3866  /*tilingResult=*/tilingResult);
3867  if (!diag.succeeded())
3868  return diag;
3869 
3870  results.push_back(tilingResult.loops.front());
3871  for (auto op : tilingResult.tiledOps)
3872  results.push_back(op);
3873  return DiagnosedSilenceableFailure::success();
3874 }
3875 
3876 //===----------------------------------------------------------------------===//
3877 // WinogradConv2DOp
3878 //===----------------------------------------------------------------------===//
3879 
3880 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
3881  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3882  transform::ApplyToEachResultList &results,
3883  transform::TransformState &state) {
3884  rewriter.setInsertionPoint(target);
3885  FailureOr<Operation *> maybeTransformed = failure();
3886  bool supported = TypeSwitch<Operation *, bool>(target)
3887  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3888  maybeTransformed =
3889  winogradConv2D(rewriter, op, getM(), getR());
3890  return true;
3891  })
3892  .Default([&](Operation *op) { return false; });
3893 
3894  if (!supported) {
3895  return emitSilenceableError()
3896  << "this operation is not supported to convert to Winograd Conv2D";
3897  }
3898 
3899  if (failed(maybeTransformed)) {
3900  return emitSilenceableError() << "apply Winograd Conv2D failed";
3901  }
3902 
3903  results.push_back(*maybeTransformed);
3904  return DiagnosedSilenceableFailure::success();
3905 }
3906 
3907 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
3908  transform::TransformRewriter &rewriter, Operation *target,
3909  transform::ApplyToEachResultList &results,
3910  transform::TransformState &state) {
3911  rewriter.setInsertionPoint(target);
3912  FailureOr<Operation *> maybeTransformed = failure();
3913  bool supported =
3914  TypeSwitch<Operation *, bool>(target)
3915  .Case([&](linalg::WinogradFilterTransformOp op) {
3916  maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
3917  return true;
3918  })
3919  .Case([&](linalg::WinogradInputTransformOp op) {
3920  maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
3921  return true;
3922  })
3923  .Case([&](linalg::WinogradOutputTransformOp op) {
3924  maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
3925  return true;
3926  })
3927  .Default([&](Operation *op) { return false; });
3928 
3929  if (!supported) {
3930  DiagnosedSilenceableFailure diag =
3931  emitSilenceableError()
3932  << "this operation is not supported to decompose into other operations";
3933  diag.attachNote(target->getLoc()) << "target op";
3934  return diag;
3935  }
3936 
3937  if (failed(maybeTransformed)) {
3938  DiagnosedSilenceableFailure diag =
3939  emitSilenceableError() << "decompose Winograd operations failed";
3940  diag.attachNote(target->getLoc()) << "target op";
3941  return diag;
3942  }
3943 
3944  results.push_back(*maybeTransformed);
3945  return DiagnosedSilenceableFailure::success();
3946 }
3947 
3948 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3949 
3950 #define GET_OP_CLASSES
3951 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
#define DOWNSCALE(trans)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
#define DBGS()
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:364
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:302
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
This class represents a saved insertion point.
Definition: Builders.h:325
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:271
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:412
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:656
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type front()
Return first type in the range.
Definition: TypeRange.h:149
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1255
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1158
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1208
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
Definition: Padding.cpp:153
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:470
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:677
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:260
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:356
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:860
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In the case of GPU private memory there is no need to deallocate since the memory is freed when going out of scope.
Definition: Promotion.cpp:511
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:495
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
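As an illustration, a hedged sketch of the Winograd entry point above (the helper name and the F(4x4, 3x3) parameters are assumptions, not taken from this file):
static LogicalResult winogradExample(RewriterBase &rewriter,
                                     linalg::Conv2DNhwcFhwcOp convOp) {
  // F(m x m, r x r): m = 4 output tile, r = 3 filter size (i.e. 3x3 filters).
  FailureOr<Operation *> transformed =
      linalg::winogradConv2D(rewriter, convOp, /*m=*/4, /*r=*/3);
  if (failed(transformed))
    return failure();
  // The produced winograd_{filter,input,output}_transform ops can then be
  // lowered with the decomposeWinograd*TransformOp entry points listed here.
  return success();
}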
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociative reshape ops.
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensure data integrity.
Definition: Promotion.cpp:486
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
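A hedged sketch of how the two entry points above are typically combined (the helper name is illustrative; default vector sizes are used, so the op is assumed to have static shapes):
static LogicalResult vectorizeIfSupported(RewriterBase &rewriter,
                                          Operation *candidate) {
  // Skip ops the Linalg vectorizer has no dedicated logic for.
  if (!linalg::hasVectorizationImpl(candidate))
    return failure();
  // No explicit vector sizes / scalable dims: derive them from static shapes.
  return linalg::vectorize(rewriter, candidate);
}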
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
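For illustration, a hedged round-trip sketch using generalizeNamedOp and specializeGenericOp from this list (the helper name is assumed):
static LogicalResult generalizeThenSpecialize(RewriterBase &rewriter,
                                              linalg::LinalgOp namedOp) {
  // Named op -> linalg.generic.
  FailureOr<linalg::GenericOp> generic =
      linalg::generalizeNamedOp(rewriter, namedOp);
  if (failed(generic))
    return failure();
  // linalg.generic -> named op, when the region matches a known named form.
  if (failed(linalg::specializeGenericOp(rewriter, *generic)))
    return failure();
  return success();
}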
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace linalg.batch_matmul with a transposed variant (batch_matmul_transpose_a/_b) by materializing a transpose of the LHS or RHS operand.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:399
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
Definition: Promotion.cpp:503
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:223
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and indexing_maps dimensions and adapt the index accesses of op.
Definition: Interchange.cpp:50
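A hedged usage sketch for the interchange entry point above (the helper name and the {1, 0} permutation are illustrative; the op is assumed to have at least two loops):
static LogicalResult interchangeExample(RewriterBase &rewriter,
                                        linalg::GenericOp genericOp) {
  // Swap the two outermost loops; the vector must be a permutation of
  // [0, numLoops).
  SmallVector<unsigned> interchange = {1, 0};
  FailureOr<linalg::GenericOp> swapped =
      linalg::interchangeGenericOp(rewriter, genericOp, interchange);
  return failed(swapped) ? failure() : success();
}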
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:242
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:162
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:111
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:97
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition: Tiling.cpp:594
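A hedged sketch of the forall-based reduction tiling above (the helper name and the thread count are assumptions; the op is assumed to have a single reduction loop):
static LogicalResult tileReductionForallExample(
    RewriterBase &rewriter, PartialReductionOpInterface op) {
  // Distribute the reduction over 8 threads; partial results are combined by
  // a merge operation created by the transform.
  SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(8)};
  FailureOr<linalg::ForallReductionTilingResult> tiled =
      linalg::tileReductionUsingForall(rewriter, op, numThreads);
  return failed(tiled) ? failure() : success();
}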
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel dimensions and k is a proper reduction dimension.
Definition: Transforms.cpp:768
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:479
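A hedged sketch of packing a matmul-like op with the pack entry point above (the helper name and the 8/16/32 packed sizes are illustrative):
static LogicalResult packExample(RewriterBase &rewriter,
                                 linalg::LinalgOp matmulOp) {
  // One packed size per iteration dimension: (m, n, k) -> (8, 16, 32).
  SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(8),
                                           rewriter.getIndexAttr(16),
                                           rewriter.getIndexAttr(32)};
  FailureOr<linalg::PackResult> packed =
      linalg::pack(rewriter, matmulOp, packedSizes);
  return failed(packed) ? failure() : success();
}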
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:421
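A hedged sketch combining promoteSubviewsPrecondition, LinalgPromotionOptions, and promoteSubViews from this list (the helper name, promoted operand indices, and alignment are assumptions):
static FailureOr<linalg::LinalgOp> promoteExample(OpBuilder &b,
                                                  linalg::LinalgOp op) {
  // Promote the first two operands into full-tile buffers aligned to 16 bytes.
  linalg::LinalgPromotionOptions options;
  options.setOperandsToPromote({0, 1})
      .setUseFullTileBuffersByDefault(true)
      .setAlignment(16);
  if (failed(linalg::promoteSubviewsPrecondition(op, options)))
    return failure();
  b.setInsertionPoint(op);
  return linalg::promoteSubViews(b, op, options);
}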
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:442
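A hedged sketch wiring a ControlSplitReductionFn into splitReduction (the helper name, the ratio of 4, and the SplitReductionOptions field names are assumptions):
static LogicalResult splitReductionExample(RewriterBase &rewriter,
                                           linalg::LinalgOp op) {
  // Split the reduction by a factor of 4, inserting the new dimension at
  // index 0 and keeping the inner dimension as a reduction.
  linalg::ControlSplitReductionFn control =
      [](linalg::LinalgOp) -> linalg::SplitReductionOptions {
    return {/*ratio=*/4, /*index=*/0, /*innerParallel=*/false};
  };
  FailureOr<linalg::SplitReductionResult> result =
      linalg::splitReduction(rewriter, op, control);
  return failed(result) ? failure() : success();
}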
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In the case of GPU workgroup memory there is no need to deallocate.
Definition: Promotion.cpp:479
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::ForOp iteratively.
Definition: Hoisting.cpp:202
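A hedged sketch running the two hoisting helpers above over a function-like op (the helper name is assumed; `root` would typically be a func.func):
static void hoistVectorOpsExample(RewriterBase &rewriter, Operation *root) {
  // Move loop-invariant transfer_read/transfer_write pairs on buffers out of
  // their enclosing scf.for loops.
  linalg::hoistRedundantVectorTransfers(root);
  // Do the same for matching vector.extract/vector.broadcast pairs.
  linalg::hoistRedundantVectorBroadcasts(rewriter, root);
}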
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoint.
Definition: Split.cpp:67
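A hedged sketch for the splitOp entry point above (the helper name, the dimension 0, and the static split point 16 are illustrative):
static void splitExample(RewriterBase &rewriter, TilingInterface op) {
  // Split the outermost iteration dimension at 16; the resulting pair covers
  // [0, 16) and [16, end). A part may be null if the split is degenerate.
  auto [firstPart, secondPart] = linalg::splitOp(
      rewriter, op, /*dimension=*/0, /*splitPoint=*/rewriter.getIndexAttr(16));
  (void)firstPart;
  (void)secondPart;
}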
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetSize.
Definition: Tiling.cpp:268
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:219
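A hedged sketch lowering a pack/unpack pair with lowerPack and lowerUnPack from this list (the helper name is assumed):
static LogicalResult lowerPackUnPackExample(RewriterBase &rewriter,
                                            linalg::PackOp packOp,
                                            linalg::UnPackOp unPackOp) {
  // pack -> pad + reshape + transpose.
  if (failed(linalg::lowerPack(rewriter, packOp)))
    return failure();
  // unpack -> empty + transpose + reshape + extract_slice.
  if (failed(linalg::lowerUnPack(rewriter, unPackOp)))
    return failure();
  return success();
}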
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:78
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
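A hedged sketch of tiling through the TilingInterface using the SCFTilingOptions type listed further below (the helper name and the 32/64 tile sizes are assumptions):
static LogicalResult tileExample(RewriterBase &rewriter, TilingInterface op) {
  // Tile the first two loops by 32 and 64; a size of 0 means "do not tile".
  scf::SCFTilingOptions options;
  SmallVector<OpFoldResult> tileSizes = {rewriter.getIndexAttr(32),
                                         rewriter.getIndexAttr(64)};
  options.setTileSizes(tileSizes);
  FailureOr<scf::SCFTilingResult> tiled =
      scf::tileUsingSCF(rewriter, op, options);
  if (failed(tiled))
    return failure();
  // Callers typically replace the uses of `op` with the values produced by
  // the generated loops (see `loops` and `tiledOps` below) and then erase the
  // original op.
  return success();
}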
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:602
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract_slice ops above their producers.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:113
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert an int64_t to an integer attribute of index type and return it as an OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. N), where N is the number of expressions.
Definition: AffineExpr.h:348
GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist-driven manner until a fixed point is reached.
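A hedged sketch combining a few of the populate* entry points above with the greedy driver (the helper name is assumed; `isolatedOp` must be isolated from above, e.g. a func.func):
static LogicalResult runLinalgCleanups(Operation *isolatedOp) {
  MLIRContext *ctx = isolatedOp->getContext();
  RewritePatternSet patterns(ctx);
  // Canonicalizations that commonly fire after tiling, plus tensor.pad
  // vectorization patterns.
  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
  linalg::populatePadOpVectorizationPatterns(patterns);
  return applyPatternsGreedily(isolatedOp->getRegion(0), std::move(patterns));
}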
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
Parses a single MLIR attribute in the given MLIR context; returns a null attribute if parsing fails.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions [0 .. N), where N is the number of expressions.
Definition: AffineExpr.h:362
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true are replaced by the corresponding elements of dynamicValues.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on a best-effort basis.
Definition: Verifier.cpp:425
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:475
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:476
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:473
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1505
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise convolution ops.
Definition: Transforms.h:1437
Match and rewrite for the pattern:
Definition: Transforms.h:1627
Match and rewrite for the pattern:
Definition: Transforms.h:1655
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:379
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:385
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:398
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:418
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:368
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:392
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:408
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:357
Split Reduction options.
Definition: Transforms.h:427
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.