LinalgTransformOps.cpp
1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
41 #include "mlir/Support/LLVM.h"
42 #include "mlir/Support/TypeID.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
49 
50 using namespace mlir;
51 using namespace mlir::linalg;
52 using namespace mlir::transform;
53 
54 #define DEBUG_TYPE "linalg-transforms"
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56 #define DBGSNL() (llvm::dbgs() << "\n")
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
58 
59 /// Attempts to apply the pattern specified as template argument to the given
60 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
61 /// function that returns the "main" result or failure. Returns failure if the
62 /// pattern failed to apply. Extra arguments are forwarded to the pattern
63 /// constructor.
64 template <typename PatternTy, typename... Args>
65 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
66  // Check if the given operation has the type expected by the pattern.
67  using OpTy = typename llvm::function_traits<
68  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69  auto op = dyn_cast<OpTy>(operation);
70  if (!op)
71  return failure();
72 
73  // Apply the pattern directly to the op.
74  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
75  // We want to discourage direct use of PatternRewriter in APIs, but in this
76  // very specific case, an IRRewriter is not enough.
77  struct TrivialPatternRewriter : public PatternRewriter {
78  public:
79  explicit TrivialPatternRewriter(MLIRContext *context)
80  : PatternRewriter(context) {}
81  };
82  TrivialPatternRewriter rewriter(operation->getContext());
83  rewriter.setInsertionPoint(operation);
84  auto result = pattern.returningMatchAndRewrite(op, rewriter);
85  if (failed(result))
86  return failure();
87  return cast<LinalgOp>(result->getOperation());
88 }
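// Example (illustrative sketch, not part of the original file): given some
// `Operation *op`, DecomposeOp below applies rewrite patterns through this
// helper roughly as
//   FailureOr<LinalgOp> res =
//       tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                      Conv1DNwcWcfOp>>(op);
// and treats `failure()` as "this pattern does not apply here".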
89 
90 /// Assuming that `ofr` is an index attr or a param of index type
91 /// or a transform dialect handle mapped to exactly one op
92 /// with one index result, return that value.
93 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
94  transform::TransformState &state, TransformOpInterface transformOp,
95  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
96  for (OpFoldResult ofr : ofrs) {
97  if (auto attr = dyn_cast<Attribute>(ofr)) {
98  if (!isa<IntegerAttr>(attr))
99  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
100  result.push_back(ofr);
101  continue;
102  }
103 
104  Value transformValue = cast<Value>(ofr);
105  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
106  ArrayRef<Attribute> params = state.getParams(transformValue);
107  if (params.size() != 1)
108  return transformOp.emitDefiniteFailure()
109  << "requires exactly one parameter associated";
110  result.push_back(params[0]);
111  continue;
112  }
113 
114  auto payloadOps = state.getPayloadOps(transformValue);
115  if (!llvm::hasSingleElement(payloadOps)) {
116  DiagnosedSilenceableFailure diag =
117  transformOp.emitSilenceableError()
118  << "handle must be mapped to exactly one payload op";
119  diag.attachNote(transformValue.getLoc())
120  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
121  return diag;
122  }
123 
124  Operation *op = *payloadOps.begin();
125  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
126  DiagnosedSilenceableFailure diag =
127  transformOp.emitSilenceableError()
128  << "payload op must have exactly 1 index result";
129  diag.attachNote(op->getLoc())
130  << "has " << op->getNumResults() << " results";
131  return diag;
132  }
133  result.push_back(op->getResult(0));
134  }
135 
136  return DiagnosedSilenceableFailure::success();
137 }
138 
139 // Given a list of params that are index attrs or a list of OpFoldResults
140 // that are either index attrs or op handles, return a list of OpFoldResults
141 // of index attrs or a list of OpFoldResults where all op handles are
142 // replaced with the first (and only) OpResult of that payload op.
143 // (There must be exactly one parameter associated with the AnyParamType or
144 // one mapped payload op which must have exactly one index result.)
145 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
146  transform::TransformState &state, TransformOpInterface transformOp,
147  SmallVector<OpFoldResult> &result, Value packedHandle) {
148  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
149  ArrayRef<Attribute> params = state.getParams(packedHandle);
150  for (auto param : params) {
151  if (!isa<IntegerAttr>(param))
152  return transformOp.emitDefiniteFailure()
153  << "expected the parameter to be associated with an integer "
154  "attribute";
155  result.push_back(param);
156  }
157  return DiagnosedSilenceableFailure::success();
158  }
159 
160  for (Operation *op : state.getPayloadOps(packedHandle)) {
161  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
162  DiagnosedSilenceableFailure diag =
163  transformOp.emitSilenceableError()
164  << "payload op must have exactly 1 index result";
165  diag.attachNote(op->getLoc())
166  << "has " << op->getNumResults() << " results";
167  return diag;
168  }
169  result.push_back(op->getResult(0));
170  }
171 
172  return DiagnosedSilenceableFailure::success();
173 }
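// Usage sketch (for illustration only, assuming the hypothetical variable
// names shown): ops such as PackOp further below funnel their mixed
// static/dynamic sizes through these helpers, e.g.
//   SmallVector<OpFoldResult> packedSizes;
//   DiagnosedSilenceableFailure status =
//       unpackSingleIndexResultPayloadOperations(state, *this, packedSizes,
//                                                getMixedPackedSizes());
// so that params, index attributes and single-result index ops all reduce to
// plain OpFoldResults.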
174 
175 /// When possible, converts each `OpFoldResult` in `mixedResult` to
176 /// an integer if the value can be statically inferred. If a result
177 /// is a `Value` then it must be either a `ParamType` or a handle
178 /// to a constant-like op.
179 static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
180  TransformState &state, TransformOpInterface &transformOp,
181  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
182  for (OpFoldResult paramOrHandle : mixedResults) {
183  if (auto attr = dyn_cast<Attribute>(paramOrHandle)) {
184  reified.push_back(cast<IntegerAttr>(attr).getInt());
185  continue;
186  } else if (isa<ParamType>(cast<Value>(paramOrHandle).getType())) {
187  ArrayRef<Attribute> params = state.getParams(cast<Value>(paramOrHandle));
188  if (params.size() != 1)
189  return transformOp.emitSilenceableError() << "expected a single param";
190  reified.push_back(
191  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
192  continue;
193  }
194 
195  Value handle = cast<Value>(paramOrHandle);
196  if (!isa<TransformHandleTypeInterface>(handle.getType()))
197  return transformOp.emitSilenceableError() << "unexpected value handle";
198  auto payload = state.getPayloadOps(handle);
199  if (!llvm::hasSingleElement(payload))
200  return transformOp.emitSilenceableError()
201  << "requires param or handle that is mapped to 1 payload op";
202 
203  Operation *paramOrHandlePayloadOp = *payload.begin();
204  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
205  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
206  return transformOp.emitSilenceableError()
207  << "requires param or handle to be result of op with 1 index "
208  "result";
209  }
210 
211  IntegerAttr attr;
212  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
213  return transformOp.emitSilenceableError()
214  << "requires param or handle to be the result of a constant like "
215  "op";
216 
217  reified.push_back(attr.getInt());
218  }
219  return DiagnosedSilenceableFailure::success();
220 }
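// Worked example (illustrative, not from the original file): given a mix of
//   - an IntegerAttr 4,
//   - a !transform.param<i64> holding 8, and
//   - a handle mapped to a single `arith.constant 16 : index`,
// the helper above fills `reified` with {4, 8, 16}; anything else produces a
// silenceable error.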
221 
222 //===----------------------------------------------------------------------===//
223 // Apply...PatternsOp
224 //===----------------------------------------------------------------------===//
225 
226 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
227  RewritePatternSet &patterns) {
228  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
229 }
230 
231 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
232  RewritePatternSet &patterns) {
233  linalg::populateDecomposePackUnpackPatterns(patterns);
234 }
235 
236 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
237  RewritePatternSet &patterns) {
238  linalg::populateDecomposePadPatterns(patterns);
239 }
240 
241 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
242  RewritePatternSet &patterns) {
243  linalg::ControlDropUnitDims options;
244  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
245 }
246 
247 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
248  RewritePatternSet &patterns) {
249  linalg::ControlDropUnitDims options;
250  options.rankReductionStrategy =
251  linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
252  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
253 }
254 
255 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
256  RewritePatternSet &patterns) {
257  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
258 }
259 
260 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
261  RewritePatternSet &patterns) {
262  linalg::populateFoldAddIntoDestPatterns(patterns);
263 }
264 
265 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
266  RewritePatternSet &patterns) {
267  linalg::populatePadOpVectorizationPatterns(patterns);
268 }
269 
270 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
271  RewritePatternSet &patterns) {
272  linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
273 }
274 
275 void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
276  RewritePatternSet &patterns) {
277  linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
278 }
280 //===----------------------------------------------------------------------===//
281 // BufferizeToAllocationOp
282 //===----------------------------------------------------------------------===//
283 
284 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
285  OperationState &result,
286  Value target,
287  Attribute memorySpace) {
288  SmallVector<Type> resultTypes;
289  resultTypes.push_back(b.getType<transform::AnyValueType>());
290  resultTypes.push_back(b.getType<transform::AnyOpType>());
291  return build(b, result,
292  /*resultTypes=*/resultTypes,
293  /*target=*/target,
294  /*memorySpace=*/memorySpace);
295 }
296 
297 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
298  OperationState &result,
299  Value target,
300  int64_t memorySpace) {
301  SmallVector<Type> resultTypes;
302  resultTypes.push_back(b.getType<transform::AnyValueType>());
303  resultTypes.push_back(b.getType<transform::AnyOpType>());
304  return build(b, result,
305  /*resultTypes=*/resultTypes,
306  /*target=*/target,
307  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
308 }
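// Builder usage sketch (assumption, for illustration only; `b`, `loc` and
// `handle` are hypothetical): with the overload above, a C++ client can create
// the op without spelling out result types:
//   auto op = b.create<transform::BufferizeToAllocationOp>(
//       loc, /*target=*/handle, /*memorySpace=*/3);
// which yields the allocated-buffer and new-ops results typed as
// !transform.any_value and !transform.any_op respectively.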
309 
310 namespace {
311 class NewOpsListener : public RewriterBase::ForwardingListener {
312 public:
313  using RewriterBase::ForwardingListener::ForwardingListener;
314 
315  SmallVector<Operation *> getNewOps() const {
316  return SmallVector<Operation *>(newOps.begin(), newOps.end());
317  }
318 
319 private:
320  void notifyOperationInserted(Operation *op,
321  OpBuilder::InsertPoint previous) override {
322  ForwardingListener::notifyOperationInserted(op, previous);
323  // We only care about newly created ops.
324  if (previous.isSet())
325  return;
326  auto inserted = newOps.insert(op);
327  (void)inserted;
328  assert(inserted.second && "expected newly created op");
329  }
330 
331  void notifyOperationErased(Operation *op) override {
332  ForwardingListener::notifyOperationErased(op);
333  op->walk([&](Operation *op) { newOps.erase(op); });
334  }
335 
336  DenseSet<Operation *> newOps;
337 };
338 } // namespace
339 
340 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
341  transform::TransformRewriter &rewriter,
342  transform::TransformResults &results, transform::TransformState &state) {
343  // Attach listener to keep track of newly created ops.
344  OpBuilder::Listener *previousListener = rewriter.getListener();
345  auto resetListener =
346  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
347  NewOpsListener newOpsListener(previousListener);
348  rewriter.setListener(&newOpsListener);
349 
350  linalg::BufferizeToAllocationOptions options;
351  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
352  options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
353  MaterializeInDestination;
354  } else if (getMemcpyOp() == "memref.copy") {
355  options.memcpyOp =
356  linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
357  } else if (getMemcpyOp() == "linalg.copy") {
358  options.memcpyOp =
359  linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
360  } else {
361  llvm_unreachable("invalid memcpy op");
362  }
363  if (getAllocOp() == "memref.alloc") {
364  options.allocOp =
365  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
366  } else if (getAllocOp() == "memref.alloca") {
367  options.allocOp =
368  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
369  } else {
370  llvm_unreachable("invalid alloc op");
371  }
372  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
373  options.emitDealloc = getEmitDealloc();
374 
375  // Bufferize ops.
376  Attribute memorySpace =
377  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
378  SmallVector<Value> allocatedBuffers;
379  for (Operation *op : state.getPayloadOps(getTarget())) {
380  Value buffer =
381  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
382  if (!buffer) {
383  DiagnosedSilenceableFailure diag = emitSilenceableError()
384  << "failed to bufferize operation";
385  diag.attachNote(op->getLoc()) << "target payload op";
386  return diag;
387  }
388  allocatedBuffers.push_back(buffer);
389  }
390 
391  // Set results.
392  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
393  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
394  return DiagnosedSilenceableFailure::success();
395 }
396 
397 void transform::BufferizeToAllocationOp::getEffects(
398  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
399  if (getBufferizeDestinationOnly()) {
400  // The destination is replaced with a newly allocated buffer, but the op
401  // itself remains in place.
402  onlyReadsHandle(getTargetMutable(), effects);
403  } else {
404  consumesHandle(getTargetMutable(), effects);
405  }
406  producesHandle(getOperation()->getOpResults(), effects);
407  modifiesPayload(effects);
408 }
409 
410 LogicalResult transform::BufferizeToAllocationOp::verify() {
411  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
412  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
413  return emitOpError() << "unsupported memcpy op";
414  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
415  return emitOpError() << "unsupported alloc op";
416  return success();
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // DecomposeOp
421 //===----------------------------------------------------------------------===//
422 
423 DiagnosedSilenceableFailure
424 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
425  LinalgOp target,
426  transform::ApplyToEachResultList &results,
427  transform::TransformState &state) {
428 #define DOWNSCALE(trans) \
429  { \
430  FailureOr<LinalgOp> res = tryApply<trans>(target); \
431  if (succeeded(res)) { \
432  results.push_back(*res); \
433  return DiagnosedSilenceableFailure::success(); \
434  } \
435  }
436 
437 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
438 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
439 
440  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
441  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
442  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
443  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
444  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
445  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
446  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
447  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
448  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
449  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
450  DOWNSCALE(DownscaleConv2DOp)
451 #undef DOWNSCALE_NORMAL
452 #undef DOWNSCALE_CALL
453 #undef DOWNSCALE
454  return emitDefaultSilenceableFailure(target);
455 }
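// For reference (expansion sketch, not part of the original file):
// DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp) expands to roughly
//   {
//     FailureOr<LinalgOp> res =
//         tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                        Conv1DNwcWcfOp>>(target);
//     if (succeeded(res)) {
//       results.push_back(*res);
//       return DiagnosedSilenceableFailure::success();
//     }
//   }
// i.e. the first decomposition pattern that applies wins.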
456 
457 //===----------------------------------------------------------------------===//
458 // DecomposeInterfaceOp
459 //===----------------------------------------------------------------------===//
460 
461 // Decompose the target operation if it implements the AggregatedOpInterface.
462 // Push the decomposed operations (the ones that replace the values produced
463 // by \p target) into `results`.
464 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
465  transform::TransformRewriter &rewriter, Operation *target,
466  transform::ApplyToEachResultList &results,
467  transform::TransformState &state) {
468  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
469  if (!decomposableOp) {
470  failed(rewriter.notifyMatchFailure(target,
471  "payload is not a decomposable op"));
472  return emitDefaultSilenceableFailure(target);
473  }
474 
475  FailureOr<SmallVector<Value>> maybeNewResults =
476  decomposableOp.decomposeOperation(rewriter);
477  if (failed(maybeNewResults))
478  return emitDefaultSilenceableFailure(target);
479 
480  rewriter.replaceOp(decomposableOp, *maybeNewResults);
481  for (Value val : *maybeNewResults) {
482  Operation *definition = val.getDefiningOp();
483  if (definition)
484  results.push_back(definition);
485  }
486  return DiagnosedSilenceableFailure::success();
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // EliminateLinalgOpAnchoredEmptyTensorsOp
491 //===----------------------------------------------------------------------===//
492 
493 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
494  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
495  onlyReadsHandle(getTargetMutable(), effects);
496  modifiesPayload(effects);
497 }
498 
499 DiagnosedSilenceableFailure
500 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
501  transform::TransformRewriter &rewriter, TransformResults &transformResults,
502  TransformState &state) {
503  bufferization::OneShotBufferizationOptions options;
504  options.allowReturnAllocsFromLoops = true;
505 
506  for (Operation *target : state.getPayloadOps(getTarget())) {
507  bufferization::OneShotAnalysisState state(target, options);
508  if (failed(analyzeOp(target, state)))
509  return mlir::emitSilenceableFailure(target->getLoc())
510  << "failed to analyze op";
511  if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
512  rewriter, target, state)))
513  return mlir::emitSilenceableFailure(target->getLoc())
514  << "failed to eliminate LinalgOp anchored tensor.empty ops";
515  }
516  return DiagnosedSilenceableFailure::success();
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // FuseOp
521 //===----------------------------------------------------------------------===//
522 
523 /// Apply a tiling transformation to all payload ops and store both the
524 /// tiled operation as well as the created tile loops.
525 template <typename Range>
526 static LogicalResult applyTilingToAll(
527  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
528  unsigned numLoops, transform::TransformResults &transformResults,
529  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
530  applyFn) {
531  SmallVector<Operation *> tiledLinalgOps;
532  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
533 
534  for (Operation *target : payloadOps) {
535  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
536  if (!tilingInterfaceOp)
537  return transformOp->emitError("only TilingInterface ops are supported");
538 
539  rewriter.setInsertionPoint(target);
540  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
541  applyFn(tilingInterfaceOp);
542  if (failed(tiledResults))
543  return failure();
544 
545  // Perform the replacement of tiled and fused values.
546  SmallVector<Operation *> opsToReplace{target};
547  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
548  for (Operation *toReplace : opsToReplace) {
549  for (OpResult res : toReplace->getResults())
550  if (auto replacement = tiledResults->replacements.lookup(res))
551  rewriter.replaceAllUsesWith(res, replacement);
552  if (toReplace->use_empty()) {
553  rewriter.eraseOp(toReplace);
554  }
555  }
556 
557  // Report back the relevant handles to the transform op.
558  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
559  assert(tiledResults->loops.size() == numLoops &&
560  "Mismatched number of loops, tile and fuse transform should have "
561  "failed");
562  for (unsigned int i = 0; i < numLoops; ++i)
563  loopOps[i].push_back(tiledResults->loops[i]);
564  }
565 
566  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
567  for (unsigned int i = 0; i < numLoops; ++i)
568  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
569 
570  return success();
571 }
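// Result layout sketch (illustrative): for a transform op with tile_sizes
// [8, 0, 16], numLoops is 2, so result #0 receives the tiled/fused root ops
// and results #1 and #2 receive the generated loops (outermost first), with
// one handle entry per payload op processed above.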
572 
573 DiagnosedSilenceableFailure
574 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
575  mlir::transform::TransformResults &transformResults,
576  mlir::transform::TransformState &state) {
577  SmallVector<int64_t> tileSizes =
578  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
579  SmallVector<int64_t> tileInterchange =
580  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
581 
582  scf::SCFTilingOptions tilingOptions;
583  tilingOptions.interchangeVector = tileInterchange;
584  SmallVector<OpFoldResult> tileSizesOfr =
585  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
586  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
587  scf::SCFTileAndFuseOptions tileAndFuseOptions;
588  tileAndFuseOptions.tilingOptions = tilingOptions;
589 
590  if (getApplyCleanup()) {
591  MLIRContext *context = rewriter.getContext();
592  RewritePatternSet patterns(context);
593  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
594  tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
595  tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
596  tileAndFuseOptions.cleanupPatterns = std::move(patterns);
597  }
598 
599  LogicalResult result = applyTilingToAll(
600  rewriter, getOperation(), state.getPayloadOps(getTarget()),
601  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
602  [&](TilingInterface tilingInterfaceOp)
603  -> FailureOr<scf::SCFTileAndFuseResult> {
604  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
605  tileAndFuseOptions);
606  });
607  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
608  : DiagnosedSilenceableFailure::success();
609 }
610 
611 LogicalResult transform::FuseOp::verify() {
612  SmallVector<int64_t> permutation =
613  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
614  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
615  if (!std::is_permutation(sequence.begin(), sequence.end(),
616  permutation.begin(), permutation.end())) {
617  return emitOpError() << "expects interchange to be a permutation, found "
618  << getTileInterchange();
619  }
620 
621  SmallVector<int64_t> sizes =
622  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
623  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
624  if (numExpectedLoops != getNumResults() - 1)
625  return emitOpError() << "expects " << numExpectedLoops << " loop results";
626 
627  return success();
628 }
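// Worked example (illustrative): tile_sizes = [4, 0, 8] with
// tile_interchange = [0, 2, 1] passes the checks above: the interchange is a
// permutation of {0, 1, 2}, and numExpectedLoops = 3 - count(0) = 2, so the
// op must declare one fused-op result plus exactly 2 loop results.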
629 
630 //===----------------------------------------------------------------------===//
631 // FuseIntoContainingOp
632 //===----------------------------------------------------------------------===//
633 
634 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
635  OperationState &result,
636  Value producerOp,
637  Value containingOp) {
638  result.addOperands({producerOp, containingOp});
639  auto resultType = transform::AnyOpType::get(builder.getContext());
640  result.addTypes({resultType, resultType});
641 }
642 
643 /// Add new operands to the forall op for users of the producerOp
644 /// that are dominated by the containing scf.forall op.
645 static Operation *replaceForAllWithNewSignature(
646  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
647  Operation *containingOp, TilingResult &tileAndFuseResult,
648  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
649  SmallVector<OpFoldResult> &sizes) {
650 
651  // Count number of users not including the containing op
652  SetVector<Operation *> dominatedUsers;
653  DominanceInfo domInfo(containingOp);
654  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
655  if (!containingOp->isAncestor(user) &&
656  (domInfo.dominates(containingOp, user))) {
657  dominatedUsers.insert(user);
658  }
659  }
660  if (dominatedUsers.empty())
661  return nullptr;
662 
663  // Create new scf.forall op
664  auto forallOp = cast<scf::ForallOp>(containingOp);
665  OpBuilder::InsertionGuard g(rewriter);
666  rewriter.setInsertionPoint(forallOp);
667 
668  // Get new output
669  Location loc = forallOp.getLoc();
670  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
671  if (!genericOp)
672  return nullptr;
673  SmallVector<Value> outputs = genericOp.getOutputs();
674  SmallVector<Value> newOuts(forallOp.getOutputs());
675  newOuts.push_back(outputs[resultNumber]);
676 
677  // Create new scf.forall op
678  auto newforallOp = rewriter.create<scf::ForallOp>(
679  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
680  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
681  rewriter.eraseBlock(newforallOp.getBody());
682  newforallOp.getRegion().takeBody(forallOp.getRegion());
683 
684  // Add additional block argument for new value being returned
685  // and replaces all uses of the new output with corresponding bbArg
686  // inside the scf.forall to enable fusion into this new scf.forall.
687  newforallOp.getBody()->addArgument(newOuts.back().getType(),
688  newOuts.back().getLoc());
689  auto bbArgs = newforallOp.getBody()->getArguments();
690  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
691  [&](OpOperand &use) {
692  Operation *op = use.getOwner();
693  return newforallOp->isProperAncestor(op);
694  });
695 
696  // Fix terminator
697  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
698  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
699  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
700  Operation *firstYieldOp = yieldingOps.front();
701  rewriter.setInsertionPoint(firstYieldOp);
702  Value src = tileAndFuseResult.tiledValues[0];
703  Value dst = newforallOp.getRegionIterArgs().back();
704  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
705  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
706  dst, offsets, sizes, strides);
707 
708  for (auto result : llvm::enumerate(forallOp.getResults())) {
709  rewriter.replaceAllUsesWith(result.value(),
710  newforallOp->getResult(result.index()));
711  }
712  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
713  newforallOp->getResults().back(),
714  [&](OpOperand &use) {
715  Operation *user = use.getOwner();
716  return dominatedUsers.contains(user);
717  });
718  return newforallOp;
719 }
720 
721 /// Find the first "extract" user of `producerOp` and tile it right before its
722 /// use. The tiled op is fused under the `containingOp`.
723 /// Return this fused op on success or nullptr if anything fails.
724 /// If tiled op has uses that are dominated by `containingOp`, return
725 /// a new `containingOp` with results of the fused op appended to
726 /// results of the `containingOp` or nullptr if there are no dominated uses.
727 static std::tuple<SmallVector<Operation *>, Operation *>
728 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
729  Operation *producerOp, Operation *containingOp) {
730  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
731  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
732  if (!tileableProducer) {
733  diag.attachNote(producerOp->getLoc())
734  << "producer is not a TileableInterface: " << *producerOp;
735  return {};
736  }
737 
738  // Search the producer slices accessed within the containing operation.
739  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
740  // evolve into an interface.
741  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
742  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
743  return sliceOp && containingOp->isProperAncestor(sliceOp);
744  });
745 
746  // Find a fusion opportunity.
747  if (it == tileableProducer->getUsers().end()) {
748  diag.attachNote(tileableProducer->getLoc())
749  << "could not find fusion opportunity for: " << *tileableProducer;
750  return {};
751  }
752  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
753 
754  // Try to fuse the producer in-place.
755  OpBuilder::InsertionGuard guard(rewriter);
756  rewriter.setInsertionPoint(sliceOpToTile);
757 
758  // Tile the producer.
759  int64_t resultNumber =
760  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
761  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
762 
763  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
764  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
765 
766  FailureOr<TilingResult> tileAndFuseResult =
767  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
768  sizes);
769 
770  if (failed(tileAndFuseResult)) {
771  diag.attachNote(tileableProducer->getLoc())
772  << "failed to tile producer op: " << *tileableProducer;
773  return {};
774  }
775 
776 #ifndef NDEBUG
777  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
778  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
779  }
780 #endif
781 
782  // Replace the extract op.
783  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
784  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
785  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
786  if (failed(maybeRankReduced)) {
787  diag.attachNote(producerOp->getLoc())
788  << "shape types don't match (missing canonicalization?):\nTiledOp: "
789  << tileAndFuseResult->tiledValues[0]
790  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
791  return {};
792  }
793  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
794 
795  // Add new outputs to containing op, if required
796  Operation *newContainingOp = replaceForAllWithNewSignature(
797  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
798  resultNumber, offsets, sizes);
799 
800  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
801 }
802 
803 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
804 /// it is exactly the `containingOp`, otherwise bail.
805 /// Then, find the first "extract" user of the tied block argument and tile it
806 /// right before its "extract" use. The tiled op is fused under the
807 /// `containingOp`.
808 /// Return this fused op on success or nullptr if anything fails.
809 static SmallVector<Operation *>
810 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
811  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
812  Operation *containingOp) {
813  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
814 
815  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
816  if (!tileableProducer) {
817  diag.attachNote(producerOp->getLoc())
818  << "producer is not a TileableInterface: " << *producerOp;
819  return {};
820  }
821 
822  // Search the first use by a "scf::ForallOp" user.
823  scf::ForallOp forallOp;
824  auto itProducerUses =
825  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
826  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
827  return forallOp;
828  });
829  // If it's not from the containing op, return.
830  if (!forallOp || forallOp != containingOp) {
831  diag.attachNote(tileableProducer->getLoc())
832  << "could not find a use by the containing op: " << *tileableProducer;
833  return {};
834  }
835 
836  // Search the producer slices accessed within the containing
837  // operation.
838  // TODO: Generalize to more extract/insert/parallel_insert triples.
839  // Maybe evolve into an interface.
840  OpOperand *pUse = &(*itProducerUses);
841  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
842 
843  // Search the producer slices accessed within the containing operation.
844  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
845  // evolve into an interface.
846  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
847  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
848  return sliceOp && containingOp->isProperAncestor(sliceOp);
849  });
850 
851  // Find a fusion opportunity.
852  if (itBBArgUsers == bbArg.getUsers().end()) {
853  diag.attachNote(containingOp->getLoc())
854  << "could not find fusion opportunity for bbArg: " << bbArg;
855  return {};
856  }
857  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
858 
859  // Try to fuse the producer in-place.
860  OpBuilder::InsertionGuard guard(rewriter);
861  rewriter.setInsertionPoint(sliceOpToTile);
862 
863  // Replace the use in the tileableProducer before tiling: clone, replace and
864  // then tile.
865  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
866  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
867 
868  // Gather destination tensors.
869  SmallVector<Value> destinationTensors;
870  if (failed(tensor::getOrCreateDestinations(
871  rewriter, tileableProducer->getLoc(), tileableProducer,
872  destinationTensors))) {
873  diag.attachNote(tileableProducer->getLoc())
874  << "failed to get destination tensors for: " << *tileableProducer;
875  return {};
876  }
877 
878  IRMapping bvm;
879  bvm.map(destinationTensors[resultNumber], bbArg);
880  auto tileableProducerClone =
881  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
882  auto scopeGuard =
883  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
884 
885  // Tile the producer.
886  FailureOr<TilingResult> tileAndFuseResult =
887  tileableProducerClone.generateResultTileValue(
888  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
889  sliceOpToTile.getMixedSizes());
890  if (failed(tileAndFuseResult)) {
891  diag.attachNote(tileableProducer->getLoc())
892  << "failed to tile producer op: " << *tileableProducer;
893  return {};
894  }
895 
896  // Replace the extract op.
897  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
898  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
899  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
900  assert(succeeded(maybeRankReduced) && "unexpected shape");
901  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
902 
903  // Replace the use in containingOp.
904  rewriter.modifyOpInPlace(containingOp, [&]() {
905  containingOp->setOperand(pUse->getOperandNumber(),
906  destinationTensors.front());
907  });
908 
909  return tileAndFuseResult->tiledOps;
910 }
911 
912 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
913  Operation *producerOp,
914  Operation *containingOp) {
915  LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
916 
917  // Gather all uses inside the containing op.
918  SmallVector<OpOperand *> uses;
919  for (OpResult result : producerOp->getOpResults()) {
920  for (OpOperand &use : result.getUses()) {
921  if (containingOp->isProperAncestor(use.getOwner())) {
922  uses.push_back(&use);
923  continue;
924  }
925  // Cannot clone and fuse if the use is by the containing op itself: fail
926  // immediately.
927  if (containingOp == use.getOwner()) {
928  diag.attachNote(producerOp->getLoc())
929  << "producer op use by containing op cannot be fused by cloning";
930  return nullptr;
931  }
932  }
933  }
934 
935  // Check for a non-empty list of fusion opportunities.
936  if (uses.empty()) {
937  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
938  return nullptr;
939  }
940 
941  // Clone and fuse inside the containing op.
942  Operation *fusedOp = nullptr;
943  OpOperand *use = uses.front();
944  // Parallel insert slice is not a valid clone destination.
945  // TODO: Generalize to other type of ops.
946  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
947  "Parallel insert slice is not a valid clone destination");
948  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
949  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
950 
951  OpBuilder::InsertionGuard guard(rewriter);
952  rewriter.setInsertionPoint(use->getOwner());
953  fusedOp = rewriter.clone(*producerOp);
954  rewriter.modifyOpInPlace(
955  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
956 
957  return fusedOp;
958 }
959 
960 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
961  // Allow repeated handles since we are fusing everything anyway.
962  return true;
963 }
964 
965 DiagnosedSilenceableFailure
966 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
967  transform::TransformResults &results,
968  transform::TransformState &state) {
969  SmallVector<Operation *> fusedOps;
970  auto producerOps = state.getPayloadOps(getProducerOp());
971  auto containingOps = state.getPayloadOps(getContainingOp());
972  if (!llvm::hasSingleElement(containingOps)) {
973  return emitDefiniteFailure()
974  << "requires exactly one containing_op handle (got "
975  << llvm::range_size(containingOps) << ")";
976  }
977  Operation *containingOp = *containingOps.begin();
978 
979  // If nothing to fuse, propagate success.
980  if (std::empty(producerOps)) {
981  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
982  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
983  return DiagnosedSilenceableFailure::success();
984  }
985 
986  // Helper function to find the next producer that should be fused. Take any
987  // producer that has a use inside the containing op.
988  SetVector<Operation *> remainingProducers(producerOps.begin(),
989  producerOps.end());
990  auto getNextProducer = [&]() -> FailureOr<Operation *> {
991  for (const auto &it : enumerate(remainingProducers)) {
992  Operation *producerOp = it.value();
993  // The containing op may be a user of producerOp: use isAncestor.
994  int64_t numUsesInContainingOp =
995  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
996  return containingOp->isAncestor(op);
997  });
998  // TODO: When resolving the TODO below (no duplicate ops), take an op
999  // that has no use among the remaining producers. This is a topological
1000  // sorting.
1001  if (numUsesInContainingOp > 0) {
1002  if (numUsesInContainingOp == 1)
1003  remainingProducers.erase(remainingProducers.begin() + it.index());
1004  return producerOp;
1005  }
1006  }
1007  return failure();
1008  };
1009 
1010  while (!remainingProducers.empty()) {
1011  auto nextProducer = getNextProducer();
1012  if (failed(nextProducer)) {
1013  auto diag = mlir::emitSilenceableFailure(getLoc())
1014  << "could not find next producer to fuse into container";
1015  diag.attachNote(containingOp->getLoc()) << "containing op";
1016  return diag;
1017  }
1018 
1019  Operation *producerOp = *nextProducer;
1020 
1021  // Default diagnostic, to be complemented with more failure information.
1022  Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
1023  diag << "could not fuse " << *producerOp << " into " << *containingOp;
1024 
1025  // TODO: If there are multiple uses of the producer in the containing op,
1026  // we currently tile/clone the op multiple times (once per use). In some
1027  // cases, we can tile/clone once and reuse the value for each use.
1028  // Furthermore, producers should then be traversed according to a
1029  // topological sorting.
1030  auto [tiledOps, newContainingOp] =
1031  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1032  if (!tiledOps.empty()) {
1033  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
1034  fusedOps.append(tiledOps);
1035  if (newContainingOp) {
1036  // Update handles associated with the containing op so we don't need to
1037  // invalidate them. This is a hack to support better composability
1038  // between tiling and fusion while a proper mechanism is being
1039  // investigated.
1040  //
1041  // DO NOT replicate this elsewhere unless you understand what you are
1042  // doing.
1043  LogicalResult replacementStatus =
1044  rewriter.notifyPayloadOperationReplaced(containingOp,
1045  newContainingOp);
1046  (void)replacementStatus;
1047  assert(succeeded(replacementStatus) &&
1048  "unable to update transform state mapping");
1049  rewriter.eraseOp(containingOp);
1050  containingOp = newContainingOp;
1051  }
1052  continue;
1053  }
1054 
1055  SmallVector<Operation *> tiledContainingOpOperand =
1056  tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1057  rewriter, diag, producerOp, containingOp);
1058  if (!tiledContainingOpOperand.empty()) {
1059  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
1060  << *containingOp);
1061  fusedOps.append(tiledContainingOpOperand);
1062  continue;
1063  }
1064 
1065  Operation *cloned =
1066  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1067  if (cloned) {
1068  LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp);
1069  fusedOps.push_back(cloned);
1070  continue;
1071  }
1072  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1073  }
1074 
1075  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1076  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1077  return DiagnosedSilenceableFailure::success();
1078 }
1079 
1080 void transform::FuseIntoContainingOp::getEffects(
1081  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1082  consumesHandle(getProducerOpMutable(), effects);
1083  onlyReadsHandle(getContainingOpMutable(), effects);
1084  producesHandle(getOperation()->getOpResults(), effects);
1085  modifiesPayload(effects);
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // GeneralizeOp
1090 //===----------------------------------------------------------------------===//
1091 
1092 DiagnosedSilenceableFailure
1093 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1094  LinalgOp target,
1095  transform::ApplyToEachResultList &results,
1096  transform::TransformState &state) {
1097  // Exit early if no transformation is needed.
1098  if (isa<GenericOp>(target)) {
1099  results.push_back(target);
1100  return DiagnosedSilenceableFailure::success();
1101  }
1102  rewriter.setInsertionPoint(target);
1103  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1104  if (succeeded(generic)) {
1105  results.push_back(generic->getOperation());
1106  return DiagnosedSilenceableFailure::success();
1107  }
1108  return emitDefaultSilenceableFailure(target);
1109 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // SpecializeOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 DiagnosedSilenceableFailure
1116 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1117  LinalgOp target,
1118  transform::ApplyToEachResultList &results,
1119  transform::TransformState &state) {
1120  // Exit early if the operation is not a generic.
1121  if (!isa<GenericOp>(target)) {
1122  results.push_back(target);
1123  return DiagnosedSilenceableFailure::success();
1124  }
1125  rewriter.setInsertionPoint(target);
1126  FailureOr<LinalgOp> named =
1127  specializeGenericOp(rewriter, cast<GenericOp>(target));
1128  if (succeeded(named)) {
1129  results.push_back(named->getOperation());
1130  return DiagnosedSilenceableFailure::success();
1131  }
1132  return emitDefaultSilenceableFailure(target);
1133 }
1134 
1135 //===----------------------------------------------------------------------===//
1136 // InterchangeOp
1137 //===----------------------------------------------------------------------===//
1138 
1139 DiagnosedSilenceableFailure
1140 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1141  GenericOp target,
1142  transform::ApplyToEachResultList &results,
1143  transform::TransformState &state) {
1144  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1145  // Exit early if no transformation is needed.
1146  if (interchangeVector.empty()) {
1147  results.push_back(target);
1148  return DiagnosedSilenceableFailure::success();
1149  }
1150 
1151  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1152  if (interchangeVector.size() != numLoops) {
1153  return emitSilenceableError()
1154  << getIteratorInterchangeAttrName() << " has length ("
1155  << interchangeVector.size()
1156  << ") different from the number of loops in the target operation ("
1157  << numLoops << ")";
1158  }
1159  FailureOr<GenericOp> res = interchangeGenericOp(
1160  rewriter, target, SmallVector<unsigned>(interchangeVector));
1161  if (failed(res))
1162  return emitDefiniteFailure() << "failed to apply";
1163  results.push_back(res->getOperation());
1164  return DiagnosedSilenceableFailure::success();
1165 }
1166 
1167 LogicalResult transform::InterchangeOp::verify() {
1168  ArrayRef<int64_t> permutation = getIteratorInterchange();
1169  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1170  if (!std::is_permutation(sequence.begin(), sequence.end(),
1171  permutation.begin(), permutation.end())) {
1172  return emitOpError()
1173  << "expects iterator_interchange to be a permutation, found "
1174  << getIteratorInterchange();
1175  }
1176  return success();
1177 }
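// Worked example (illustrative): for a generic op with 3 loops,
// iterator_interchange = [2, 0, 1] is accepted by the check above, while
// [0, 0, 2] is rejected because it is not a permutation of seq(0, 3).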
1178 
1179 //===----------------------------------------------------------------------===//
1180 // LinalgCopyToMemrefOp
1181 //===----------------------------------------------------------------------===//
1182 
1183 DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1184  transform::TransformRewriter &rewriter, Operation *targetOp,
1185  transform::ApplyToEachResultList &results,
1186  transform::TransformState &state) {
1187 
1188  // Check if the target can be converted.
1189  if (!isa<linalg::CopyOp>(targetOp)) {
1190  DiagnosedSilenceableFailure diag =
1191  emitSilenceableError() << "only linalg.copy target ops are supported";
1192  diag.attachNote(targetOp->getLoc()) << "target op";
1193  return diag;
1194  }
1195 
1196  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1197  if (!copyOp.hasPureBufferSemantics()) {
1198  DiagnosedSilenceableFailure diag =
1199  emitSilenceableError()
1200  << "cannot transform a linalg.copy on tensors into a memref.copy";
1201  diag.attachNote(targetOp->getLoc()) << "target op";
1202  return diag;
1203  }
1204 
1205  SmallVector<Value> inputs = copyOp.getInputs();
1206  SmallVector<Value> outputs = copyOp.getOutputs();
1207  assert(inputs.size() == 1 && "expected linalg copy op with one input");
1208  assert(outputs.size() == 1 && "expected memref copy op with one output");
1209  Value input = inputs.front();
1210  Value output = outputs.front();
1211 
1212  // linalg.copy supports different element types on source/dest whereas
1213  // memref.copy does not, so we must check that the source and dest types can
1214  // be handled by memref.copy and otherwise reject the transformation.
1215  if (!dyn_cast<ShapedType>(input.getType())) {
1216  DiagnosedSilenceableFailure diag =
1217  emitSilenceableError()
1218  << "cannot transform a linalg.copy which input has no shape";
1219  diag.attachNote(targetOp->getLoc()) << "target op";
1220  return diag;
1221  }
1222 
1223  // linalg.copy destination must be a shaped type.
1224  assert(dyn_cast<ShapedType>(output.getType()));
1225 
1226  if (cast<ShapedType>(input.getType()).getElementType() !=
1227  cast<ShapedType>(output.getType()).getElementType()) {
1228  DiagnosedSilenceableFailure diag =
1229  emitSilenceableError()
1230  << "cannot transform a linalg.copy with different source and "
1231  "destination element types ";
1232  diag.attachNote(targetOp->getLoc()) << "target op";
1233  return diag;
1234  }
1235 
1236  // Target can be converted, do it.
1237  auto memrefCopyOp =
1238  rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1239 
1240  results.push_back(memrefCopyOp);
1241  return DiagnosedSilenceableFailure::success();
1242 }
1243 
1244 //===----------------------------------------------------------------------===//
1245 // LowerPackOp
1246 //===----------------------------------------------------------------------===//
1247 
1248 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1249  transform::TransformRewriter &rewriter, linalg::PackOp target,
1250  transform::ApplyToEachResultList &transformResults,
1251  transform::TransformState &state) {
1252  rewriter.setInsertionPoint(target);
1253  bool lowerPadLikeWithInsertSlice = getLowerPadLikeWithInsertSlice();
1254  FailureOr<LowerPackResult> res =
1255  lowerPack(rewriter, target, lowerPadLikeWithInsertSlice);
1256  if (failed(res)) {
1257  return mlir::emitSilenceableFailure(target->getLoc())
1258  << "cannot lower to pad + expand + transpose";
1259  }
1260  transformResults.push_back(res->padOp);
1261  transformResults.push_back(res->expandShapeOp);
1262  transformResults.push_back(res->transposeOp);
1263  return DiagnosedSilenceableFailure::success();
1264 }
1265 
1266 //===----------------------------------------------------------------------===//
1267 // LowerUnPackOp
1268 //===----------------------------------------------------------------------===//
1269 
1270 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1271  transform::TransformRewriter &rewriter, linalg::UnPackOp target,
1272  transform::ApplyToEachResultList &transformResults,
1273  transform::TransformState &state) {
1274  rewriter.setInsertionPoint(target);
1275  bool lowerUnpadLikeWithExtractSlice = getLowerUnpadLikeWithExtractSlice();
1276  FailureOr<LowerUnPackOpResult> res =
1277  lowerUnPack(rewriter, target, lowerUnpadLikeWithExtractSlice);
1278  if (failed(res)) {
1279  DiagnosedSilenceableFailure diag =
1280  emitSilenceableError()
1281  << "cannot lower to transpose + collapse + extract";
1282  diag.attachNote(target->getLoc()) << "target payload op";
1283  return diag;
1284  }
1285  transformResults.push_back(res->emptyOp);
1286  transformResults.push_back(res->transposeOp);
1287  transformResults.push_back(res->collapseShapeOp);
1288  transformResults.push_back(res->extractSliceOp);
1289  return DiagnosedSilenceableFailure::success();
1290 }
1291 
1292 //===---------------------------------------------------------------------===//
1293 // MatchOp
1294 //===---------------------------------------------------------------------===//
1295 
1296 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1297  Value target, ArrayRef<StringRef> opNames) {
1298  result.addOperands(target);
1299  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1300  builder.getStrArrayAttr(opNames));
1301  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1302 }
1303 
1304 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1305  TypeRange resultTypes, Value target,
1306  ArrayRef<StringRef> opNames) {
1307  result.addOperands(target);
1308  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1309  builder.getStrArrayAttr(opNames));
1310  result.addTypes(resultTypes);
1311 }
1312 
1313 DiagnosedSilenceableFailure
1314 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1315  transform::TransformResults &results,
1316  transform::TransformState &state) {
1317  llvm::StringSet<> strs;
1318  if (getOps().has_value())
1319  strs.insert_range(getOps()->getAsValueRange<StringAttr>());
1320 
1321  auto payloadOps = state.getPayloadOps(getTarget());
1322  if (!llvm::hasSingleElement(payloadOps)) {
1323  return emitDefiniteFailure("requires exactly one target handle");
1324  }
1325 
1326  SmallVector<Operation *> res;
1327  bool incorrectNumOperandTypes = false;
1328  auto matchFun = [&](Operation *op) {
1329  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1330  return;
1331 
1332  // Interfaces cannot be matched by name, just by ID.
1333  // So we specifically encode the interfaces we care about for this op.
1334  if (getInterface().has_value()) {
1335  auto iface = getInterface().value();
1336  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1337  !isa<LinalgOp>(op))
1338  return;
1339  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1340  !isa<TilingInterface>(op))
1341  return;
1342  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1343  !isa<LoopLikeOpInterface>(op))
1344  return;
1345  }
1346 
1347  // Check if all specified attributes match.
1348  if (getOpAttrs().has_value()) {
1349  DictionaryAttr opAttrs = getOpAttrs().value();
1350  for (NamedAttribute attr : opAttrs) {
1351  if (attr.getName() == getInterfaceAttrName() ||
1352  attr.getName() == getOpsAttrName())
1353  continue;
1354  if (!op->hasAttr(attr.getName()))
1355  return;
1356  if (op->getAttr(attr.getName()) != attr.getValue())
1357  return;
1358  }
1359  }
1360 
1361  if (getFilterResultType().has_value()) {
1362  Type t = getFilterResultType().value();
1363  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1364  return;
1365  }
1366 
1367  if (getFilterOperandTypes().has_value()) {
1368  mlir::ArrayAttr types = getFilterOperandTypes().value();
1369  auto operandTypes = op->getOperandTypes();
1370 
1371  if (types.size() == 1) {
1372  // All the operands must be equal to the specified type.
1373  auto typeattr =
1374  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1375  Type t = cast<::mlir::Type>(typeattr.getValue());
1376  if (!llvm::all_of(op->getOperandTypes(),
1377  [&](Type operandType) { return operandType == t; }))
1378  return;
1379  } else {
1380  // The operand types must match all the types in the list (in the same
1381  // order in which they are specified).
1382  if (types.size() != operandTypes.size()) {
1383  incorrectNumOperandTypes = true;
1384  return;
1385  }
1386 
1387  for (auto [attr, operandType] :
1388  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1389  auto typeattr = cast<mlir::TypeAttr>(attr);
1390  Type type = cast<::mlir::Type>(typeattr.getValue());
1391 
1392  if (type != operandType)
1393  return;
1394  }
1395  }
1396  }
1397 
1398  // All constraints are satisfied.
1399  res.push_back(op);
1400  return;
1401  };
1402 
1403  (*payloadOps.begin())->walk(matchFun);
1404  if (incorrectNumOperandTypes)
1405  return emitDefiniteFailure("If filter_operand_types contains more than a "
1406  "type, then it must contain as much types as "
1407  "the number of operands in the target ops");
1408  results.set(cast<OpResult>(getResult()), res);
1409  return DiagnosedSilenceableFailure::success();
1410 }
1411 
1412 //===---------------------------------------------------------------------===//
1413 // MultiTileSizesOp
1414 //===---------------------------------------------------------------------===//
1415 
1416 static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1417  Type targetType, Type lowSizeType, Type,
1418  Type) {
1419  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1420 }
1421 
1422 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1423  Type &targetType, Type &lowSizeType,
1424  Type &highSizeType,
1425  Type &splitPointType) {
1426  FunctionType funcType;
1427  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1428  if (failed(parser.parseType<FunctionType>(funcType)))
1429  return failure();
1430 
1431  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1432  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1433  "argument and one result";
1434  }
1435  targetType = funcType.getInput(0);
1436  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1437 
1438  return success();
1439 }
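// Format sketch (assumption, for illustration only): the custom print/parse
// pair above renders the op's trailing types as a single functional type such
// as
//   (!transform.any_op) -> !transform.param<i64>
// where the one result type is shared by the low-size, high-size and
// split-point results.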
1440 
1441 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1442  transform::TransformRewriter &rewriter, LinalgOp target,
1443  transform::ApplyToEachResultList &results, TransformState &state) {
1444  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1445  if (target.hasDynamicShape()) {
1446  auto diag = emitSilenceableError()
1447  << "cannot compute parametric tile sizes for dynamically "
1448  "shaped payload op";
1449  diag.attachNote(target->getLoc()) << "payload op";
1450  return diag;
1451  }
1452 
1453  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1454  target, getDimension(), getTargetSize(), getDivisor());
1455  if (failed(spec)) {
1456  return emitSilenceableError()
1457  << "failed to compute multi-size tiling sizes";
1458  }
1459 
1460  Builder builder(target.getContext());
1461  results.assign(llvm::map_range(
1462  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1463  spec->lowTileSize * spec->lowTripCount}),
1464  [&builder, this](int64_t value) {
1465  return builder.getIntegerAttr(
1466  cast<ParamType>(getLowSize().getType()).getType(), value);
1467  }));
1468  return DiagnosedSilenceableFailure::success();
1469  }
1470 
1471  OpBuilder builder(target.getContext());
1472  builder.setInsertionPoint(target);
1473  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1474  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1475  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1476  builder, target, getDimension(), targetSize, divisor);
1477  if (failed(spec)) {
1478  return emitSilenceableError() << "could not generate tile size computation";
1479  }
1480 
1481  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1482  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1483  Operation *splitPoint =
1484  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1485  {spec->lowTileSize, spec->lowTripCount});
1486  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1487  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1488  assert(lowTileSize && highTileSize && splitPoint &&
1489  "tile sizes are not produced by operations");
1490  results.reserve(results.size() + 3);
1491  results.push_back(lowTileSize);
1492  results.push_back(highTileSize);
1493  results.push_back(splitPoint);
1494  return DiagnosedSilenceableFailure::success();
1495 }
1496 
1497 void transform::MultiTileSizesOp::getEffects(
1498  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1499  onlyReadsHandle(getTargetMutable(), effects);
1500  producesHandle(getOperation()->getOpResults(), effects);
1501  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1502  onlyReadsPayload(effects);
1503  else
1504  modifiesPayload(effects);
1505 }
1506 
1507 LogicalResult transform::MultiTileSizesOp::verify() {
1508  if (getLowSize().getType() != getHighSize().getType() ||
1509  getLowSize().getType() != getSplitPoint().getType()) {
1510  return emitOpError() << "expects all results type to be the same";
1511  }
1512  return success();
1513 }
1514 
1515 //===---------------------------------------------------------------------===//
1516 // PackOp
1517 //===---------------------------------------------------------------------===//
1518 
1519 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1520  Value target,
1521  ArrayRef<OpFoldResult> mixedPackedSizes) {
1522  SmallVector<int64_t> staticPackedSizes;
1523  SmallVector<Value> dynamicPackedSizes;
1524  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1525  staticPackedSizes);
1526  // Call the default builder which sets up the proper operands segment sizes
1527  // attributes for multiple variadic operands. In the absence of this, horrible
1528  // bugs ensue.
1529  Type linalgOpHType = transform::OperationType::get(
1530  builder.getContext(), GenericOp::getOperationName());
1531  build(builder, result,
1532  /*resultType=*/linalgOpHType,
1533  /*target=*/target,
1534  /*dynamic_sizes=*/dynamicPackedSizes,
1535  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1536 }
1537 
1538 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1539  Builder b(getContext());
1540  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1541 }
1542 
1543 DiagnosedSilenceableFailure
1544 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1545  transform::TransformResults &transformResults,
1546  transform::TransformState &state) {
1547  auto targetOps = state.getPayloadOps(getTarget());
1548  // If nothing to pack, propagate success.
1549  if (std::empty(targetOps)) {
1550  transformResults.set(cast<OpResult>(getPackedOp()),
1551  ArrayRef<Operation *>({}));
1552  return DiagnosedSilenceableFailure::success();
1553  }
1554  // Fail on multi-op handles.
1555  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1556  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1557  return emitSilenceableError()
1558  << "requires target to map to exactly 1 LinalgOp (got "
1559  << llvm::range_size(targetOps) << ")";
1560  }
1561  // Fail on mismatched number of pack sizes.
1562  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1563  return emitSilenceableError()
1564  << "requires number of packed sizes match the number of loops ("
1565  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1566  << ")";
1567  }
1568 
1569  // Unpack handles to constants or actual SSA index values.
1570  SmallVector<OpFoldResult> packedSizes;
1571  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
1572  state, *this, packedSizes, getMixedPackedSizes());
1573 
1574  rewriter.setInsertionPoint(linalgOp);
1575  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1576  if (failed(maybeResult))
1577  return emitDefiniteFailure("data tiling failed");
1578 
1579  transformResults.set(cast<OpResult>(getPackedOp()),
1580  {maybeResult->packedLinalgOp.getOperation()});
1581  return DiagnosedSilenceableFailure::success();
1582 }
1583 
1584 void transform::PackOp::getEffects(
1585  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1586  transform::consumesHandle(getTargetMutable(), effects);
1587  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1588  transform::producesHandle(getOperation()->getOpResults(), effects);
1589  transform::modifiesPayload(effects);
1590 }
1591 
1592 //===---------------------------------------------------------------------===//
1593 // PackGreedilyOp.
1594 //===---------------------------------------------------------------------===//
1595 
1596 LogicalResult transform::PackGreedilyOp::verify() {
1597  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1598  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1599  << " is not a valid permutation";
1600  }
1601  // TODO: relax to allow empty once we have another strategy than just matmul.
1602  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1603  for (auto [s, nmo] :
1604  llvm::zip_equal(getMixedMatmulPackedSizes(),
1605  getMatmulPaddedSizesNextMultipleOf())) {
1606  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1607  if (nmo != 0 &&
1608  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1609  return emitOpError() << "at most one of the packed_size and the "
1610  "padded_sizes_next_multiple_of can be nonzero "
1611  "for the matmul strategy";
1612  }
1613  }
1614  }
1615  return success();
1616 }
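// The check above enforces, per matmul dimension, that at most one of
// `matmul_packed_sizes` and `matmul_padded_sizes_next_multiple_of` is
// nonzero. For example (hypothetical attribute values), packed sizes
// [8, 16, 0] with next-multiple-of [0, 0, 32] verify, while [8, 16, 0] with
// [4, 0, 32] are rejected because dimension 0 sets both.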
1617 
1618 DiagnosedSilenceableFailure
1619 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1620  transform::TransformResults &transformResults,
1621  transform::TransformState &state) {
1622  SmallVector<Operation *> results;
1623  for (Operation *op : state.getPayloadOps(getTarget())) {
1624  auto linalgOp = dyn_cast<LinalgOp>(op);
1625  if (!linalgOp)
1626  continue;
1627  // linalgOp will be replaced and the insertion point may be invalidated if
1628  // we set it before -> set it after.
1629  rewriter.setInsertionPointAfter(linalgOp);
1630  // Failing to pack greedily is perfectly fine.
1631  // In the future we will want to order packings according to some metric.
1632  FailureOr<PackResult> packResult = packMatmulGreedily(
1633  /*rewriter=*/rewriter,
1634  /*linalgOp=*/linalgOp,
1635  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1636  /*mnkPaddedSizesNextMultipleOf=*/
1637  getMatmulPaddedSizesNextMultipleOf(),
1638  /*mnkOrder=*/getMatmulInnerDimsOrder());
1639  if (succeeded(packResult)) {
1640  results.push_back(packResult->packedLinalgOp);
1641  continue;
1642  }
1643  results.push_back(linalgOp);
1644  }
1645  transformResults.set(cast<OpResult>(getPackedOp()), results);
1646  return DiagnosedSilenceableFailure::success();
1647 }
1648 
1649 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1650  Builder b(getContext());
1651  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1652  b);
1653 }
1654 
1655 void transform::PackGreedilyOp::getEffects(
1656  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1657  transform::consumesHandle(getTargetMutable(), effects);
1658  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1659  transform::producesHandle(getOperation()->getOpResults(), effects);
1660  transform::modifiesPayload(effects);
1661 }
1662 
1663 //===---------------------------------------------------------------------===//
1664 // PackTransposeOp
1665 //===---------------------------------------------------------------------===//
1666 
1667 LogicalResult transform::PackTransposeOp::verify() {
1668  if (!isPermutationVector(getInnerPerm())) {
1669  return emitOpError() << getInnerPermAttrName()
1670  << " is not a valid permutation";
1671  }
1672  if (!isPermutationVector(getOuterPerm())) {
1673  return emitOpError() << getOuterPermAttrName()
1674  << " is not a valid permutation";
1675  }
1676  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1677  return emitOpError() << " at least one of " << getInnerPermAttrName()
1678  << " or " << getOuterPermAttrName()
1679  << " must be specified";
1680  }
1681  return success();
1682 }
1683 
1684 namespace {
1685 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1686 } // namespace
1687 
1688 /// Return true if `permutation` is a valid permutation of the
1689 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1690 /// (OuterOrInnerPerm::Inner) of the `linalg.pack` or `linalg.unpack` `op`.
1691 /// This is the case when the `permutation` rank matches the rank expected by
1692 /// `op` and `permutation` is itself a permutation vector.
1693 /// Return true if either `op` or `permutation` are empty to allow a simpler
1694 /// polymorphic implementation.
1695 template <typename RelayoutOpTy>
1696 static bool isValidPackingPermutation(
1697  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1698  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1699  static_assert(
1700  llvm::is_one_of<RelayoutOpTy, linalg::PackOp, linalg::UnPackOp>::value,
1701  "applies to only pack or unpack operations");
1702  if (!op || permutation.empty())
1703  return true;
1704  size_t innerRank = op.getInnerDimsPos().size();
1705  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1706  return permutation.size() == innerRank && isPermutationVector(permutation);
1707  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1708  // Don't rely on it.
1709  if (std::is_same<RelayoutOpTy, linalg::PackOp>::value) {
1710  return permutation.size() == op.getSourceRank() &&
1711  isPermutationVector(permutation);
1712  }
1713  return permutation.size() == op.getDestRank() &&
1714  isPermutationVector(permutation);
1715 }
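// For instance, given a linalg.pack whose inner_dims_pos has 2 entries and
// whose source is 4-D: an inner permutation must be a permutation of {0, 1}
// (e.g. [1, 0]), while an outer permutation must be a permutation of
// {0, 1, 2, 3}; for a linalg.unpack the outer permutation is checked against
// the destination rank instead. Anything of the wrong length, or with
// repeated entries, fails the check.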
1716 
1717 DiagnosedSilenceableFailure
1718 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1719  transform::TransformResults &transformResults,
1720  transform::TransformState &state) {
1721  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1722  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1723  // Step 1. If nothing to pack, propagate success.
1724  if (std::empty(packOrUnpackOps)) {
1725  transformResults.set(cast<OpResult>(getPackedOp()), {});
1726  transformResults.set(cast<OpResult>(getPackOp()), {});
1727  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1728  return DiagnosedSilenceableFailure::success();
1729  }
1730 
1731  // Step 2. Bunch of runtime sanity checks and error messages.
1732  // Step 2.1. Fail on multi-op handles.
1733  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1734  !llvm::hasSingleElement(linalgOps)) {
1735  return emitSilenceableError()
1736  << "requires target to map to exactly 1 "
1737  "packing op and 1 packed op ("
1738  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1739  << llvm::range_size(linalgOps) << ")";
1740  }
1741 
1742  // Step 2.2. Fail on wrong type.
1743  auto packOp = dyn_cast<linalg::PackOp>(*packOrUnpackOps.begin());
1744  auto unPackOp = dyn_cast<linalg::UnPackOp>(*packOrUnpackOps.begin());
1745  if ((!packOp && !unPackOp)) {
1746  return emitSilenceableError() << "requires target to map to a "
1747  "linalg.pack or linalg.unpack";
1748  }
1749  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1750  if (!linalgOpTarget)
1751  return emitSilenceableError() << "requires a LinalgOp target";
1752 
1753  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1754  LinalgOp linalgOp;
1755  if (packOp && packOp.getResult().hasOneUse())
1756  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1757  else if (unPackOp)
1758  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1759  if (linalgOp != linalgOpTarget) {
1760  auto errorMsg =
1761  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1762  : StringLiteral{"not produced by the LinalgOp target"};
1763  return emitSilenceableError() << errorMsg;
1764  }
1765 
1766  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1767  // PackOp.
1768  if (unPackOp) {
1769  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1770  OpOperand *packUse = linalgOp.getDpsInitOperand(
1771  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1772  packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp());
1773  if (!packOp || !packOp.getResult().hasOneUse())
1774  return emitSilenceableError() << "could not find matching pack op";
1775  }
1776 
1777  // Step 2.5. Fail if any permutation does not validate.
1778  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1779  ArrayRef<int64_t> perm =
1780  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1781  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1782  ? StringLiteral{"invalid outer_perm"}
1783  : StringLiteral{"invalid inner_perm"};
1784  if (!isValidPackingPermutation(packOp, perm, permType) ||
1785  !isValidPackingPermutation(unPackOp, perm, permType)) {
1786  Operation *packOrUnpackOp =
1787  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1788  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1789  }
1790  }
1791 
1792  // From here on, packOp and linalgOp are always present, unPackOp may or may
1793  // not be present.
1794  assert(packOp && linalgOp && "unexpected null op");
1795 
1796  // Step 3. Actually transpose the ops.
1797  FailureOr<PackTransposeResult> res = packTranspose(
1798  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1799  // Preconditions have been checked, it is an error to fail here.
1800  assert(succeeded(res) && "unexpected packTranspose failure");
1801 
1802  // Step 4. Return results.
1803  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1804  transformResults.set(cast<OpResult>(getPackedOp()),
1805  {res->transposedLinalgOp});
1806  if (unPackOp) {
1807  transformResults.set(cast<OpResult>(getUnPackOp()),
1808  {res->transposedUnPackOp});
1809  } else {
1810  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1811  }
1812 
1813  return DiagnosedSilenceableFailure::success();
1814 }
1815 
1816 //===---------------------------------------------------------------------===//
1817 // PadOp
1818 //===---------------------------------------------------------------------===//
1819 
1820 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1821  ArrayRef<int64_t> paddingDimensions,
1822  ArrayRef<int64_t> padToMultipleOf,
1823  ArrayRef<int64_t> nofoldFlags,
1824  ArrayRef<Attribute> transposePaddings,
1825  StringRef copyBackOp) {
1826  auto resultType = transform::AnyOpType::get(b.getContext());
1827  return build(/*builder=*/b,
1828  /*result=*/result,
1829  /*types=*/TypeRange{resultType, resultType},
1830  /*target=*/target,
1831  /*paddingValues=*/ArrayAttr(), // let inference handle this
1832  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1833  /*padToMultipleOf=*/ValueRange{},
1834  /*padToMultipleOf=*/
1835  (padToMultipleOf.empty()
1836  ? DenseI64ArrayAttr()
1837  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1838  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1839  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1840  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1841 }
1842 
1843 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1844  ArrayRef<int64_t> paddingDimensions,
1845  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1846  ArrayRef<int64_t> nofoldFlags,
1847  ArrayRef<Attribute> transposePaddings,
1848  StringRef copyBackOp) {
1849  auto resultType = transform::AnyOpType::get(b.getContext());
1850  SmallVector<int64_t> staticPadToMultipleOf;
1851  SmallVector<Value> dynamicPadToMultipleOf;
1852  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1853  staticPadToMultipleOf);
1854  return build(/*builder=*/b,
1855  /*result=*/result,
1856  /*types=*/TypeRange{resultType, resultType},
1857  /*target=*/target,
1858  /*paddingValues=*/ArrayAttr(), // let inference handle this
1859  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1860  /*padToMultipleOf=*/dynamicPadToMultipleOf,
1861  /*padToMultipleOf=*/staticPadToMultipleOf,
1862  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1863  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1864  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1865 }
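// Example (sketch): constructing the op from C++ via the first builder
// overload above. `b`, `loc` and `targetHandle` are assumed to be provided by
// the caller; the copy-back op name must be one of the values accepted by the
// verifier further down.
//
//   auto padOp = b.create<transform::PadOp>(
//       loc, targetHandle,
//       /*paddingDimensions=*/ArrayRef<int64_t>{0, 1},
//       /*padToMultipleOf=*/ArrayRef<int64_t>{16, 16},
//       /*nofoldFlags=*/ArrayRef<int64_t>{1, 1},
//       /*transposePaddings=*/ArrayRef<Attribute>{},
//       /*copyBackOp=*/linalg::CopyOp::getOperationName());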
1866 
1867 void PadOp::getEffects(
1868  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1869  consumesHandle(getTargetMutable(), effects);
1870  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1871  producesHandle(getOperation()->getOpResults(), effects);
1872  modifiesPayload(effects);
1873 }
1874 
1875 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1876  Builder b(getContext());
1877  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1878 }
1879 
1880 DiagnosedSilenceableFailure
1881 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1882  transform::TransformResults &results,
1883  transform::TransformState &state) {
1884  auto transformOp = cast<TransformOpInterface>(getOperation());
1885  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1886 
1887  for (Operation *target : state.getPayloadOps(getTarget())) {
1888  auto linalgTarget = dyn_cast<LinalgOp>(target);
1889  if (!linalgTarget) {
1890  auto diag = emitSilenceableError() << "expected LinalgOp target";
1891  diag.attachNote(target->getLoc()) << "target op";
1892  return diag;
1893  }
1894 
1895  // Convert the integer packing flags to booleans.
1896  SmallVector<bool> nofoldFlags;
1897  for (int64_t packPadding :
1898  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1899  nofoldFlags.push_back(static_cast<bool>(packPadding));
1900 
1901  // Convert the padding values to attributes.
1902  SmallVector<Attribute> paddingValues;
1903  for (auto const &it :
1904  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1905  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1906  if (!attr) {
1907  emitOpError("expects padding values to be typed attributes");
1908  return DiagnosedSilenceableFailure::definiteFailure();
1909  }
1910  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1911  // Try to parse string attributes to obtain an attribute of element type.
1912  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1913  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1914  stringAttr, getContext(), elementType,
1915  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
1916  if (!parsedAttr || parsedAttr.getType() != elementType) {
1917  auto diag = this->emitOpError("expects a padding that parses to ")
1918  << elementType << ", got " << std::get<0>(it);
1919  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1920  return DiagnosedSilenceableFailure::definiteFailure();
1921  }
1922  paddingValues.push_back(parsedAttr);
1923  continue;
1924  }
1925  // Otherwise, add the attribute directly.
1926  if (attr.getType() != elementType) {
1927  auto diag = this->emitOpError("expects a padding value of type ")
1928  << elementType << ", got " << attr;
1929  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1930  return DiagnosedSilenceableFailure::definiteFailure();
1931  }
1932  paddingValues.push_back(attr);
1933  }
1934 
1935  // Extract the transpose vectors.
1936  SmallVector<SmallVector<int64_t>> transposePaddings;
1937  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1938  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1939  cast<ArrayAttr>(transposeVector)));
1940 
1941  LinalgOp paddedOp;
1942  LinalgPaddingOptions options;
1943  options.paddingDimensions =
1944  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1945 
1946  SmallVector<int64_t> padToMultipleOf;
1947  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
1948  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1949  if (!status.succeeded())
1950  return status;
1951  if (padToMultipleOf.empty())
1952  padToMultipleOf =
1953  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
1954 
1955  options.padToMultipleOf = padToMultipleOf;
1956  options.paddingValues = paddingValues;
1957  options.nofoldFlags = nofoldFlags;
1958  if (getCopyBackOp() ==
1959  bufferization::MaterializeInDestinationOp::getOperationName()) {
1960  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
1961  BufferizationMaterializeInDestination;
1962  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1963  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
1964  } else if (getCopyBackOp() == kCopyOpNone) {
1965  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
1966  } else {
1967  llvm_unreachable("unsupported copy_back op");
1968  }
1969 
1970  SmallVector<Value> replacements;
1971  SmallVector<tensor::PadOp> newPadOps;
1972  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
1973  replacements, newPadOps))) {
1974  auto diag = emitSilenceableError() << "failed to pad op";
1975  diag.attachNote(target->getLoc()) << "target op";
1976  return diag;
1977  }
1978 
1979  // We need to perform our own replacement here because this API is still
1980  // used in patterns that "pad and hoist", for which the replacement values
1981  // need to be different.
1982  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1983  // that we have more composable abstractions.
1984  rewriter.replaceOp(linalgTarget, replacements);
1985  paddedOps.push_back(paddedOp);
1986  padOps.append(newPadOps.begin(), newPadOps.end());
1987  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
1988  for (Value v : replacements) {
1989  Operation *copyBackOp = v.getDefiningOp();
1990  if (!llvm::is_contained(copyBackOps, copyBackOp))
1991  copyBackOps.push_back(copyBackOp);
1992  }
1993  }
1994  }
1995 
1996  results.set(cast<OpResult>(getPadded()), paddedOps);
1997  results.set(cast<OpResult>(getPad()), padOps);
1998  results.set(cast<OpResult>(getCopy()), copyBackOps);
1999  return DiagnosedSilenceableFailure::success();
2000 }
2001 
2002 LogicalResult transform::PadOp::verify() {
2003  SmallVector<int64_t> nofoldFlags =
2004  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
2005  if (any_of(nofoldFlags, [](int64_t packPadding) {
2006  return packPadding != 0 && packPadding != 1;
2007  })) {
2008  return emitOpError()
2009  << "expects nofold_flags to contain booleans (0/1), found "
2010  << getNofoldFlags();
2011  }
2012 
2013  SmallVector<int64_t> paddingDimensions =
2014  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
2015  if (any_of(paddingDimensions,
2016  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
2017  return emitOpError() << "expects padding_dimensions to contain positive "
2018  "integers, found "
2019  << getPaddingDimensions();
2020  }
2021  if (!getMixedPadToMultipleOf().empty()) {
2022  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
2023  return emitOpError() << "expects as many multiples as padding_dimensions";
2024  }
2025  }
2026  ArrayAttr transposes = getTransposePaddings();
2027  for (Attribute attr : transposes) {
2028  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
2029  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2030  if (!std::is_permutation(sequence.begin(), sequence.end(),
2031  transpose.begin(), transpose.end())) {
2032  return emitOpError()
2033  << "expects transpose_paddings to be a permutation, found "
2034  << attr;
2035  }
2036  }
2037  if (getCopyBackOp() !=
2038  bufferization::MaterializeInDestinationOp::getOperationName() &&
2039  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
2040  getCopyBackOp() != kCopyOpNone)
2041  return emitOpError() << "invalid copy_back_op";
2042  return success();
2043 }
2044 
2045 //===---------------------------------------------------------------------===//
2046 // HoistPadOp
2047 //===---------------------------------------------------------------------===//
2048 
2049 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
2050  transform::TransformRewriter &rewriter,
2051  transform::TransformResults &transformResults,
2052  transform::TransformState &state) {
2053  auto targetOps = state.getPayloadOps(getTarget());
2054  auto loopOps = state.getPayloadOps(getLoop());
2055  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
2056  return emitDefiniteFailure()
2057  << "requires exactly one target and one loop handle (got "
2058  << llvm::range_size(targetOps) << " and "
2059  << llvm::range_size(loopOps) << ")";
2060  }
2061 
2062  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
2063  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
2064  if (!padOp || !loopOp)
2065  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
2066 
2067  FailureOr<linalg::detail::PackingResult> result =
2068  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
2069  getTranspose());
2070  if (failed(result))
2071  return emitDefiniteFailure() << "could not build packing loop nest";
2072 
2073  if (result->clonedLoopIvs.empty()) {
2074  transformResults.set(cast<OpResult>(getPackingLoop()),
2075  {result->hoistedPadOp.getOperation()});
2076  return DiagnosedSilenceableFailure::success();
2077  }
2078  auto outerPackedLoop =
2079  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2080  transformResults.set(cast<OpResult>(getPackingLoop()),
2081  {outerPackedLoop.getOperation()});
2082  return DiagnosedSilenceableFailure::success();
2083 }
2084 
2085 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2086  ArrayRef<int64_t> transpose = getTranspose();
2087  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2088  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2089  transpose.end())) {
2090  return emitOpError() << "expects transpose to be a permutation, found "
2091  << getTranspose();
2092  }
2093  return success();
2094 }
2095 
2096 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2097  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2098  transform::onlyReadsHandle(getTargetMutable(), effects);
2099  transform::onlyReadsHandle(getLoopMutable(), effects);
2100  transform::producesHandle(getOperation()->getOpResults(), effects);
2101  transform::modifiesPayload(effects);
2102 }
2103 
2104 DiagnosedSilenceableFailure
2105 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2106  tensor::PadOp target,
2107  transform::ApplyToEachResultList &results,
2108  transform::TransformState &state) {
2109  tensor::PadOp hoistedPadOp;
2110  SmallVector<TransposeOp> transposeOps;
2111  FailureOr<Value> result =
2112  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2113  hoistedPadOp, transposeOps);
2114  if (succeeded(result)) {
2115  // We need to perform our own replacement here because this API is still
2116  // used in patterns that "pad and hoist", for which the replacement values
2117  // need to be different.
2118  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2119  // that we have more composable abstractions.
2120  rewriter.replaceOp(target, *result);
2121  results.push_back(hoistedPadOp);
2122  return DiagnosedSilenceableFailure::success();
2123  }
2124  return emitDefaultSilenceableFailure(target);
2125 }
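// In other words, a successful application replaces the matched tensor.pad
// with a value computed outside of `num_loops` enclosing loops: the padded
// tiles are materialized once into a packed temporary, and `transposeOps`
// collects any linalg.transpose ops created when a non-empty transpose
// permutation is requested. This is a high-level summary of
// hoistPaddingOnTensors, not an exhaustive description of its behavior.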
2126 
2127 LogicalResult transform::HoistPadOp::verify() {
2128  ArrayRef<int64_t> transpose = getTranspose();
2129  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2130  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2131  transpose.end())) {
2132  return emitOpError() << "expects transpose to be a permutation, found "
2133  << getTranspose();
2134  }
2135  return success();
2136 }
2137 
2138 //===----------------------------------------------------------------------===//
2139 // PromoteOp
2140 //===----------------------------------------------------------------------===//
2141 
2142 DiagnosedSilenceableFailure
2143 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2144  LinalgOp target,
2145  transform::ApplyToEachResultList &results,
2146  transform::TransformState &state) {
2147  LinalgPromotionOptions promotionOptions;
2148  if (!getOperandsToPromote().empty())
2149  promotionOptions = promotionOptions.setOperandsToPromote(
2150  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2151  if (getUseFullTilesByDefault())
2152  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2153  getUseFullTilesByDefault());
2154  if (getUseAlloca())
2155  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2156  if (!getUseFullTileBuffers().empty())
2157  promotionOptions = promotionOptions.setUseFullTileBuffers(
2158  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2159  if (getAlignment().has_value())
2160  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2161  if (getMemorySpace().has_value())
2162  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2163 
2164  if (getMapping().has_value()) {
2165  // The mapping should contain exactly one element.
2166  auto mapping = *getMapping();
2167  if (mapping.size() > 1)
2168  return emitDefaultDefiniteFailure(target);
2169 
2170  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2171 
2172  if (addressSpace.getAddressSpace() ==
2173  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2174  promotionOptions =
2175  promotionOptions
2176  .setAllocationDeallocationFns(allocateWorkgroupMemory,
2177  deallocateWorkgroupMemory)
2178  .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
2179  .setUseFullTileBuffers({false, false});
2180  } else if (addressSpace.getAddressSpace() ==
2181  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2182  promotionOptions =
2183  promotionOptions
2184  .setAllocationDeallocationFns(allocateGPUPrivateMemory,
2185  deallocateGPUPrivateMemory)
2186  .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
2187  .setUseFullTileBuffers({false, false});
2188  } else {
2189  return emitDefaultDefiniteFailure(target);
2190  }
2191  }
2192 
2193  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2194  return emitDefaultDefiniteFailure(target);
2195 
2196  rewriter.setInsertionPoint(target);
2197  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2198  if (failed(res))
2199  return emitDefaultDefiniteFailure(target);
2200  results.push_back(target);
2201  return DiagnosedSilenceableFailure::success();
2202 }
2203 
2204 //===----------------------------------------------------------------------===//
2205 // ReplaceOp
2206 //===----------------------------------------------------------------------===//
2207 
2208 DiagnosedSilenceableFailure
2209 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2210  TransformResults &transformResults,
2211  TransformState &state) {
2212  auto payload = state.getPayloadOps(getTarget());
2213 
2214  // Check for invalid targets.
2215  for (Operation *target : payload) {
2216  if (target->getNumOperands() > 0)
2217  return emitDefiniteFailure() << "expected target without operands";
2218  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2219  target->getNumRegions() > 0)
2220  return emitDefiniteFailure()
2221  << "expected target that is isolated from above";
2222  }
2223 
2224  // Clone and replace.
2225  Operation *pattern = &getBodyRegion().front().front();
2226  SmallVector<Operation *> replacements;
2227  for (Operation *target : payload) {
2228  if (getOperation()->isAncestor(target))
2229  continue;
2230  rewriter.setInsertionPoint(target);
2231  Operation *replacement = rewriter.clone(*pattern);
2232  rewriter.replaceOp(target, replacement->getResults());
2233  replacements.push_back(replacement);
2234  }
2235  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2236  return DiagnosedSilenceableFailure::success();
2237 }
2238 
2239 void transform::ReplaceOp::getEffects(
2240  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2241  consumesHandle(getTargetMutable(), effects);
2242  producesHandle(getOperation()->getOpResults(), effects);
2243  modifiesPayload(effects);
2244 }
2245 
2246 LogicalResult transform::ReplaceOp::verify() {
2247  if (!getBodyRegion().hasOneBlock())
2248  return emitOpError() << "expected one block";
2249  if (std::distance(getBodyRegion().front().begin(),
2250  getBodyRegion().front().end()) != 1)
2251  return emitOpError() << "expected one operation in block";
2252  Operation *replacement = &getBodyRegion().front().front();
2253  if (replacement->getNumOperands() > 0)
2254  return replacement->emitOpError()
2255  << "expected replacement without operands";
2256  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2257  replacement->getNumRegions() > 0)
2258  return replacement->emitOpError()
2259  << "expect op that is isolated from above";
2260  return success();
2261 }
2262 
2263 //===----------------------------------------------------------------------===//
2264 // ScalarizeOp
2265 //===----------------------------------------------------------------------===//
2266 
2268 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2269  LinalgOp target,
2270  transform::ApplyToEachResultList &results,
2271  transform::TransformState &state) {
2272  scf::SCFTilingOptions tilingOptions;
2273  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2274  SmallVector<OpFoldResult> tileSizes;
2275  Location loc = target.getLoc();
2276  SmallVector<OpFoldResult> allShapeSizes =
2277  target.createFlatListOfOperandDims(b, loc);
2278  AffineMap map = target.getShapesToLoopsMap();
2279  if (!map)
2280  return tileSizes;
2281  SmallVector<OpFoldResult> shapeSizes =
2282  affine::makeComposedFoldedMultiResultAffineApply(b, loc, map,
2283  allShapeSizes);
2284  // If the shape size is dynamic, tile by 1.
2285  // Otherwise, do not tile (i.e. tile size 0).
2286  for (OpFoldResult shapeSize : shapeSizes) {
2287  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2288  : b.getIndexAttr(1));
2289  }
2290  return tileSizes;
2291  });
2292  SmallVector<int64_t> emptyTileSizes;
2293  rewriter.setInsertionPoint(target);
2294  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2295  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2296  if (failed(maybeTilingResult))
2297  return emitDefaultDefiniteFailure(target);
2298 
2299  if (target->getNumResults())
2300  rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
2301  else
2302  rewriter.eraseOp(target);
2303 
2304  results.reserve(maybeTilingResult->tiledOps.size());
2305  for (Operation *tiled : maybeTilingResult->tiledOps)
2306  results.push_back(tiled);
2307  return DiagnosedSilenceableFailure::success();
2308 }
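// Effect of the tile-size callback above, on a hypothetical op whose loop
// ranges are (4, ?, 8): the static sizes 4 and 8 get tile size 0 (left
// untiled) and the dynamic size gets tile size 1, so every remaining dynamic
// dimension of the tiled op is reduced to unit extent, i.e. the op is
// "scalarized" along those dimensions.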
2309 
2310 //===----------------------------------------------------------------------===//
2311 // ConvertToLoopsOp
2312 //===----------------------------------------------------------------------===//
2313 
2314 DiagnosedSilenceableFailure
2315 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2316  transform::TransformResults &results,
2317  transform::TransformState &state) {
2318  SmallVector<Operation *> loops;
2319  for (Operation *target : state.getPayloadOps(getTarget())) {
2320  auto tilingOp = dyn_cast<TilingInterface>(*target);
2321  if (!tilingOp) {
2322  DiagnosedSilenceableFailure diag =
2323  emitSilenceableError()
2324  << "expected the payload to implement TilingInterface";
2325  diag.attachNote(target->getLoc()) << "payload op";
2326  return diag;
2327  }
2328  rewriter.setInsertionPoint(target);
2329  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2330  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2331  if (failed(generatedLoops))
2332  return emitDefaultDefiniteFailure(target);
2333  for (scf::ForOp &loop : *generatedLoops) {
2334  loops.push_back(loop.getOperation());
2335  }
2336  rewriter.eraseOp(target);
2337  }
2338  results.set(cast<OpResult>(getResult()), loops);
2339  return DiagnosedSilenceableFailure::success();
2340 }
2341 
2342 //===----------------------------------------------------------------------===//
2343 // RewriteInDestinationPassingStyleOp
2344 //===----------------------------------------------------------------------===//
2345 
2346 DiagnosedSilenceableFailure
2347 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2348  transform::TransformRewriter &rewriter, Operation *target,
2349  transform::ApplyToEachResultList &results,
2350  transform::TransformState &state) {
2352  rewriter.setInsertionPoint(target);
2353  FailureOr<Operation *> maybeResult =
2354  llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2355  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2356  [&rewriter](auto op) {
2357  return rewriteInDestinationPassingStyle(rewriter, op);
2358  });
2359  if (failed(maybeResult))
2360  return emitDefaultSilenceableFailure(target);
2361  results.push_back(*maybeResult);
2362  return DiagnosedSilenceableFailure::success();
2363 }
2364 
2365 //===----------------------------------------------------------------------===//
2366 // SplitOp
2367 //===----------------------------------------------------------------------===//
2368 
2369 DiagnosedSilenceableFailure
2370 SplitOp::apply(transform::TransformRewriter &rewriter,
2371  TransformResults &results, TransformState &state) {
2372  // Collect the dynamic split points if provided.
2373  SmallVector<Operation *> payload =
2374  llvm::to_vector(state.getPayloadOps(getTarget()));
2375 
2376  bool isMultiwaySplit = getMultiway();
2377 
2378  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2379  return mlir::emitSilenceableFailure(getLoc())
2380  << "requires exactly one target when "
2381  "multiway split is enabled (got "
2382  << llvm::range_size(payload) << ")";
2383  }
2384 
2385  SmallVector<OpFoldResult> chunkSizes;
2386 
2387  if (!isMultiwaySplit)
2388  chunkSizes.reserve(payload.size());
2389 
2390  if (getDynamicChunkSizes()) {
2391  auto diag = DiagnosedSilenceableFailure::success();
2392  if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2393  chunkSizes = llvm::to_vector(llvm::map_range(
2394  state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2395  if (op->getNumResults() != 1 ||
2396  !op->getResult(0).getType().isIndex()) {
2397  diag = emitSilenceableError()
2398  << "expected dynamic split point handle to point to a "
2399  "single-result index-typed op";
2400  diag.attachNote(op->getLoc()) << "dynamic split point";
2401  }
2402  return OpFoldResult(op->getResult(0));
2403  }));
2404  } else {
2405  chunkSizes = llvm::to_vector(
2406  llvm::map_range(state.getParams(getDynamicChunkSizes()),
2407  [](Attribute attr) { return OpFoldResult(attr); }));
2408  }
2409  if (diag.isSilenceableFailure())
2410  return diag;
2411 
2412  // For multiway split, a single payload is expected to have multiple
2413  // split points.
2414  if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2415  return emitDefiniteFailure()
2416  << "expected the dynamic split point handle to point to as "
2417  "many operations ("
2418  << chunkSizes.size() << ") as the target handle ("
2419  << payload.size() << ")";
2420  }
2421  } else {
2422  chunkSizes.resize(payload.size(),
2423  rewriter.getIndexAttr(getStaticChunkSizes()));
2424  }
2425 
2426  auto checkStructuredOpAndDimensions =
2427  [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2428  if (!linalgOp) {
2429  auto diag = emitSilenceableError() << "only applies to structured ops";
2430  diag.attachNote(loc) << "target op";
2431  return diag;
2432  }
2433 
2434  if (getDimension() >= linalgOp.getNumLoops()) {
2435  auto diag = emitSilenceableError() << "dimension " << getDimension()
2436  << " does not exist in target op";
2437  diag.attachNote(loc) << "target op";
2438  return diag;
2439  }
2440  return DiagnosedSilenceableFailure::success();
2441  };
2442 
2443  auto checkFailureInSplitting =
2444  [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2445  if (hasFailed) {
2446  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2447  diag.attachNote(loc) << "target op";
2448  return diag;
2449  }
2450  return DiagnosedSilenceableFailure::success();
2451  };
2452 
2453  SmallVector<Operation *> opList;
2454  if (isMultiwaySplit) {
2455 
2456  // Split a single target operation at multiple points.
2457  TilingInterface head, tail;
2458  Operation *target = payload.front();
2459 
2460  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2461 
2462  // Check that the target is a valid LinalgOp with correct dimensions.
2463  DiagnosedSilenceableFailure diag =
2464  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2465  if (diag.isSilenceableFailure())
2466  return diag;
2467 
2468  for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2469 
2470  if (idx > 0)
2471  target = tail.getOperation();
2472 
2473  if (!target)
2474  break;
2475 
2476  linalgOp = cast<LinalgOp>(target);
2477  Location loc = target->getLoc();
2478 
2479  rewriter.setInsertionPoint(linalgOp);
2480  std::tie(head, tail) = linalg::splitOp(
2481  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2482  getDimension(), chunkSize);
2483 
2484  // Propagate errors.
2485  DiagnosedSilenceableFailure diag =
2486  checkFailureInSplitting(!head && !tail, loc);
2487  if (diag.isDefiniteFailure())
2488  return diag;
2489 
2490  opList.push_back(head.getOperation());
2491  }
2492 
2493  // Append any leftover parts to the end of the result list.
2494  if (tail)
2495  opList.push_back(tail.getOperation());
2496 
2497  } else {
2498  // Split each target operation.
2499  SmallVector<Operation *> first, second;
2500  Operation *noSecondPart = nullptr;
2501  for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2502  Operation *target = std::get<0>(pair);
2503  Location loc = target->getLoc();
2504  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2505  DiagnosedSilenceableFailure diag =
2506  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2507 
2508  if (diag.isSilenceableFailure())
2509  return diag;
2510 
2511  rewriter.setInsertionPoint(linalgOp);
2512  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2513  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2514  getDimension(), std::get<1>(pair));
2515 
2516  // Propagate errors.
2517  DiagnosedSilenceableFailure diagSplit =
2518  checkFailureInSplitting(!first.back() && !second.back(), loc);
2519  if (diagSplit.isDefiniteFailure())
2520  return diag;
2521 
2522  // Do not add null second parts.
2523  if (!second.back()) {
2524  noSecondPart = target;
2525  second.pop_back();
2526  }
2527  }
2528 
2529  if (second.size() != first.size() && !second.empty()) {
2530  auto diag = emitSilenceableError()
2531  << "splitting does not produce the second part for a subset "
2532  "of targets";
2533  diag.attachNote()
2534  << "expected splitting to produce the second part of all "
2535  "or none of the targets";
2536  diag.attachNote(noSecondPart->getLoc())
2537  << "first target with no second part";
2538  return diag;
2539  }
2540 
2541  opList.append(first);
2542  if (second.size())
2543  opList.append(second);
2544  }
2545  results.set(cast<OpResult>(getSplitList()), opList);
2546  return DiagnosedSilenceableFailure::success();
2547 }
2548 
2549 void SplitOp::getEffects(
2550  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2551  consumesHandle(getTargetMutable(), effects);
2552  if (getDynamicChunkSizes())
2553  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2554  producesHandle(getOperation()->getOpResults(), effects);
2555  modifiesPayload(effects);
2556 }
2557 
2558 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2559  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2560  IntegerAttr staticChunkSizes;
2561  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2562  return failure();
2563 
2564  OptionalParseResult dynamicPointParseResult =
2565  parser.parseOptionalOperand(dynamicChunkSizes);
2566  if (!dynamicPointParseResult.has_value()) {
2567  int64_t staticChunkSizesValue;
2568  if (failed(parser.parseInteger(staticChunkSizesValue)))
2569  return failure();
2570 
2571  staticChunkSizes =
2572  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2573  }
2574 
2575  Type targetType;
2576  if (parser.parseOptionalAttrDict(result.attributes) ||
2577  parser.parseColonType(targetType) ||
2578  parser.resolveOperand(target, targetType, result.operands)) {
2579  return failure();
2580  }
2581  if (dynamicPointParseResult.has_value()) {
2582  Type ChunkSizesType;
2583  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2584  parser.parseType(ChunkSizesType) ||
2585  parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
2586  result.operands)) {
2587  return failure();
2588  }
2589 
2590  staticChunkSizes =
2591  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2592  }
2593 
2594  result.addAttribute(
2595  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2596  staticChunkSizes);
2597  result.addTypes(targetType);
2598  return success();
2599 }
2600 
2601 void SplitOp::print(OpAsmPrinter &printer) {
2602  printer << " " << getTarget() << " after ";
2603  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2604  if (staticChunkSize != ShapedType::kDynamic)
2605  printer << staticChunkSize;
2606  else
2607  printer << getDynamicChunkSizes();
2608  printer << " ";
2609  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2610  {getStaticChunkSizesAttrName()});
2611  printer << " : " << getTarget().getType();
2612  if (staticChunkSize == ShapedType::kDynamic)
2613  printer << ", " << getDynamicChunkSizes().getType();
2614 }
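// Round-trip sketch based on the custom parser/printer above (the op's tests
// remain the authoritative reference for the assembly format): with a static
// chunk size the op might print as
//
//   %1 = transform.structured.split %0 after 42 {dimension = 0 : i64}
//       : !transform.any_op
//
// while with a dynamic chunk size the handle/param and its type follow the
// target type after a comma.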
2615 
2616 LogicalResult SplitOp::verify() {
2617  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2618  (getDynamicChunkSizes() == nullptr)) {
2619  return emitOpError() << "expects either a dynamic or a static split "
2620  "point to be provided";
2621  }
2622  return success();
2623 }
2624 
2625 //===----------------------------------------------------------------------===//
2626 // SplitReductionOp
2627 //===----------------------------------------------------------------------===//
2628 
2629 void transform::SplitReductionOp::build(
2630  OpBuilder &builder, OperationState &result, Value target,
2631  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2632  bool useScalingAlgorithm, bool useAlloc) {
2633  MLIRContext *ctx = builder.getContext();
2634  result.addOperands(target);
2635  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2636  builder.getI64IntegerAttr(splitFactor));
2637  result.addAttribute(
2638  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2639  builder.getI64IntegerAttr(insertSplitDimension));
2640  if (innerParallel) {
2641  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2642  builder.getUnitAttr());
2643  }
2644  if (useScalingAlgorithm) {
2645  result.addAttribute(
2646  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2647  builder.getUnitAttr());
2648  }
2649  if (useAlloc) {
2650  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2651  builder.getUnitAttr());
2652  }
2653  auto resultType = transform::AnyOpType::get(ctx);
2654  result.addTypes({resultType, resultType, resultType, resultType});
2655 }
2656 
2657 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2658  transform::TransformRewriter &rewriter, LinalgOp target,
2659  transform::ApplyToEachResultList &results,
2660  transform::TransformState &state) {
2661  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2662  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2663  unsigned(getInsertSplitDimension()),
2664  bool(getInnerParallel())};
2665  };
2666  rewriter.setInsertionPoint(target);
2667  FailureOr<SplitReductionResult> splitResult =
2668  (getUseScalingAlgorithm())
2669  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2670  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2671  if (failed(splitResult))
2672  return emitDefaultDefiniteFailure(target);
2673 
2674  results.push_back(splitResult->initOrAlloc);
2675  results.push_back(splitResult->fillOp);
2676  results.push_back(splitResult->splitLinalgOp);
2677  results.push_back(splitResult->resultCombiningLinalgOp);
2678  return DiagnosedSilenceableFailure::success();
2679 }
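// Conceptually (a summary, not a verbatim transcript of splitReduction): a
// reduction such as a 1024-element sum with a split factor of 4 is rewritten
// into a partial reduction that produces 4 partial results along a new
// dimension (placed at insert_split_dimension), plus a final op combining
// them. The four result handles above expose the init/alloc, the fill, the
// split op and the combining op, in that order.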
2680 
2681 //===----------------------------------------------------------------------===//
2682 // TileReductionUsingForOp
2683 //===----------------------------------------------------------------------===//
2684 
2685 void transform::TileReductionUsingForOp::build(
2686  OpBuilder &builder, OperationState &result, Value target,
2687  ArrayRef<int64_t> staticTileSizes) {
2688  // Call the default builder.
2689  // This is future-proof re mixed static-dynamic and setting up the proper
2690  // operands segment sizes attributes for multiple variadic operands.
2691  // In the absence of this, horrible bugs ensue.
2692  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2693  MLIRContext *ctx = builder.getContext();
2694  auto opTy = transform::AnyOpType::get(ctx);
2695  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2696  build(builder, result,
2697  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2698  /*target=*/target,
2699  /*tile_sizes=*/staticTileSizesAttr);
2700 }
2701 
2702 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2703  transform::TransformRewriter &rewriter, Operation *target,
2704  transform::ApplyToEachResultList &results,
2705  transform::TransformState &state) {
2706  rewriter.setInsertionPoint(target);
2707 
2708  auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2709  if (!partialReductionOp) {
2710  return emitSilenceableFailure(
2711  target->getLoc(),
2712  "Operation should implement PartialReductionOpInterface");
2713  }
2714  FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2715  rewriter, partialReductionOp,
2716  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2717 
2718  if (failed(result))
2719  return emitDefaultSilenceableFailure(target);
2720  rewriter.replaceOp(target, result->mergeResult.replacements);
2721  for (Value initValue : result->initialValues)
2722  results.push_back(initValue.getDefiningOp());
2723  for (auto parallelTiledOp : result->tiledOps)
2724  results.push_back(parallelTiledOp);
2725  for (auto mergeOp : result->mergeResult.mergeOps)
2726  results.push_back(mergeOp);
2727  results.push_back(result->loops.front());
2728  return DiagnosedSilenceableFailure::success();
2729 }
2730 
2731 //===----------------------------------------------------------------------===//
2732 // TileReductionUsingForallOp
2733 //===----------------------------------------------------------------------===//
2734 
2735 void transform::TileReductionUsingForallOp::build(
2736  OpBuilder &builder, OperationState &result, Value target,
2737  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2738  ArrayAttr mapping) {
2739  // Call the default builder.
2740  // This is future-proof re mixed static-dynamic and setting up the proper
2741  // operands segment sizes attributes for multiple variadic operands.
2742  // In the absence of this, horrible bugs ensue.
2743  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2744  MLIRContext *ctx = builder.getContext();
2745  auto opTy = transform::AnyOpType::get(ctx);
2746  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2747  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2748  build(builder, result,
2749  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2750  /*target=*/target,
2751  /*num_threads=*/staticNumThreadsAttr,
2752  /*tile_sizes=*/staticTileSizesAttr,
2753  /*mapping=*/mapping);
2754 }
2755 
2756 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2757  transform::TransformRewriter &rewriter, LinalgOp target,
2758  transform::ApplyToEachResultList &results,
2759  transform::TransformState &state) {
2760  rewriter.setInsertionPoint(target);
2761  SmallVector<OpFoldResult> numThreads =
2762  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2763  SmallVector<OpFoldResult> tileSizes =
2764  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2765  FailureOr<linalg::ForallReductionTilingResult> result =
2766  linalg::tileReductionUsingForall(
2767  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2768  numThreads, tileSizes, getMapping());
2769 
2770  if (failed(result)) {
2771  auto diag = emitSilenceableError() << "could not tile reduction";
2772  diag.attachNote(target.getLoc()) << "target operation";
2773  return diag;
2774  }
2775  for (Value initValue : result->initialValues)
2776  results.push_back(initValue.getDefiningOp());
2777  for (auto parallelTiledOp : result->parallelTiledOps)
2778  results.push_back(parallelTiledOp);
2779  for (auto mergeOp : result->mergeOps)
2780  results.push_back(mergeOp);
2781  results.push_back(result->loops);
2782  return DiagnosedSilenceableFailure::success();
2783 }
2784 
2785 //===----------------------------------------------------------------------===//
2786 // ContinuousTileSizesOp
2787 //===----------------------------------------------------------------------===//
2788 
2789 DiagnosedSilenceableFailure
2790 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2791  TransformResults &transformResults,
2792  TransformState &state) {
2793 
2794  SmallVector<Operation *> targetOps =
2795  llvm::to_vector(state.getPayloadOps(getTarget()));
2796 
2797  if (!llvm::hasSingleElement(targetOps)) {
2798  return mlir::emitSilenceableFailure(getLoc())
2799  << "requires exactly one target (got " << llvm::range_size(targetOps)
2800  << ")";
2801  }
2802 
2803  Operation *target = *targetOps.begin();
2804  auto linalgOp = dyn_cast<LinalgOp>(target);
2805  auto tileableOp = dyn_cast<TilingInterface>(target);
2806 
2807  if (!linalgOp)
2808  return emitDefiniteFailure() << "expected Linalg Op";
2809 
2810  OpBuilder builder(linalgOp.getContext());
2811 
2812  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
2813  if (linalgOp.hasDynamicShape()) {
2814  auto diag = emitSilenceableError()
2815  << "cannot compute parametric tile sizes for dynamically "
2816  "shaped payload op";
2817  diag.attachNote(linalgOp->getLoc()) << "payload op";
2818  return diag;
2819  }
2820 
2821  FailureOr<StaticContinuousTileSizeSpecification> spec =
2822  computeStaticContinuousTileSizes(linalgOp, getDimension(),
2823  getTargetSize());
2824  if (failed(spec)) {
2825  return emitSilenceableError()
2826  << "failed to compute multi-size tiling sizes";
2827  }
2828 
2829  SmallVector<int64_t> chunkSizes;
2830 
2831  for (auto &&[tileSize, tripCount] :
2832  llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2833  chunkSizes.push_back(tileSize * tripCount);
2834 
2835  auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2836  return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2837  return builder.getI64IntegerAttr(value);
2838  });
2839  };
2840  transformResults.setParams(cast<OpResult>(getTileSizes()),
2841  getI64AttrsFromI64(spec->tileSizes));
2842  transformResults.setParams(cast<OpResult>(getChunkSizes()),
2843  getI64AttrsFromI64(chunkSizes));
2844 
2845  return DiagnosedSilenceableFailure::success();
2846  }
2847 
2848  builder.setInsertionPoint(linalgOp);
2849 
2850  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2851  unsigned dimension = getDimension();
2852 
2853  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2854  builder, tileableOp, dimension, targetSize, true);
2855  if (failed(spec)) {
2856  return emitSilenceableError() << "could not generate tile size computation";
2857  }
2858 
2859  AffineExpr s0 = builder.getAffineSymbolExpr(0);
2860  AffineExpr s1 = builder.getAffineSymbolExpr(1);
2861  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2862  return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2863  ofrs);
2864  };
2865 
2866  SmallVector<Value> chunkSizes;
2867  Value splitPoint;
2868  for (auto &&[tileSize, tripCount] :
2869  llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2870  splitPoint = apply(s0 * s1, {tileSize, tripCount});
2871  chunkSizes.push_back(splitPoint);
2872  }
2873 
2874  auto getDefiningOps = [&](ArrayRef<Value> values) {
2875  return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2876  return value.getDefiningOp();
2877  });
2878  };
2879 
2880  transformResults.set(cast<OpResult>(getTileSizes()),
2881  getDefiningOps(spec->tileSizes));
2882  transformResults.set(cast<OpResult>(getChunkSizes()),
2883  getDefiningOps(chunkSizes));
2884 
2885  return DiagnosedSilenceableFailure::success();
2886 }
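// Relationship between the two results computed above: for every tile size
// t_i with trip count c_i in the specification, the corresponding chunk size
// is t_i * c_i (the affine apply s0 * s1), so the chunk sizes partition the
// tiled dimension into contiguous regions each handled by one tile size.
// With hypothetical spec values tileSizes = {9, 4, 2, 1} and
// tripCounts = {5, 1, 1, 1}, the chunk sizes would be {45, 4, 2, 1}.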
2887 
2888 LogicalResult transform::ContinuousTileSizesOp::verify() {
2889 
2890  if (getTileSizes().getType() != getChunkSizes().getType()) {
2891  return emitOpError() << "expects all results type to be the same";
2892  }
2893 
2894  return success();
2895 }
2896 
2897 void transform::ContinuousTileSizesOp::getEffects(
2898  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2899  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2900  onlyReadsPayload(effects);
2901  else
2902  modifiesPayload(effects);
2903  onlyReadsHandle(getTargetMutable(), effects);
2904  producesHandle(getOperation()->getOpResults(), effects);
2905 }
2906 
2907 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2908  Type targetType, Type tile_sizes,
2909  Type) {
2910  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
2911 }
2912 
2913 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2914  Type &targetType,
2915  Type &tileSizesType,
2916  Type &chunkSizesType) {
2917  FunctionType funcType;
2918  llvm::SMLoc typeLoc = parser.getCurrentLocation();
2919  if (failed(parser.parseType<FunctionType>(funcType)))
2920  return failure();
2921 
2922  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2923  parser.emitError(typeLoc) << "expects a trailing functional type with one "
2924  "argument and one result";
2925  }
2926  targetType = funcType.getInput(0);
2927  tileSizesType = chunkSizesType = funcType.getResult(0);
2928 
2929  return success();
2930 }
2931 
2932 //===----------------------------------------------------------------------===//
2933 // TileUsingForOp
2934 //===----------------------------------------------------------------------===//
2935 
2936 void transform::TileUsingForOp::build(
2937  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2938  Value target, ArrayRef<int64_t> staticTileSizes,
2939  ArrayRef<int64_t> interchange,
2940  std::optional<ArrayRef<bool>> scalableSizes) {
2941  return build(builder, result, loopTypes,
2942  /*target=*/target,
2943  /*mixedTileSizes=*/
2944  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2945  interchange, scalableSizes);
2946 }
2947 
2948 void transform::TileUsingForOp::build(
2949  OpBuilder &builder, OperationState &result, Value target,
2950  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2951  std::optional<ArrayRef<bool>> scalableSizes) {
2952  build(builder, result, target,
2953  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2954  interchange, scalableSizes);
2955 }
2956 
2957 void transform::TileUsingForOp::build(
2958  OpBuilder &builder, OperationState &result, Value target,
2959  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2960  std::optional<ArrayRef<bool>> scalableSizes) {
2961  // Loop types are automatically splat by the callee; setting up one is
2962  // enough.
2963  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2964  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2965  scalableSizes);
2966 }
2967 
2968 void transform::TileUsingForOp::build(
2969  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2970  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2971  ArrayRef<int64_t> interchange,
2972  std::optional<ArrayRef<bool>> scalableSizes) {
2973  SmallVector<int64_t> staticTileSizes;
2974  SmallVector<Value> dynamicTileSizes;
2975  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2976  // Call the default builder which sets up the proper operands segment sizes
2977  // attributes for multiple variadic operands. In the absence of this,
2978  // horrible bugs ensue.
2979  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2980  unsigned numExpectedLoops =
2981  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2982  SmallVector<Type> resultTypes;
2983  resultTypes.reserve(numExpectedLoops);
2984  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2985  "expected one loop type or as many as loops");
2986  if (loopTypes.size() == 1)
2987  resultTypes.append(numExpectedLoops, loopTypes[0]);
2988  else
2989  llvm::append_range(resultTypes, loopTypes);
2990  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2991  if (scalableSizes.has_value())
2992  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2993  build(builder, result, /*tiled_linalg_op=*/target.getType(),
2994  /*loops=*/resultTypes,
2995  /*target=*/target,
2996  /*dynamic_sizes=*/dynamicTileSizes,
2997  /*static_sizes=*/staticTileSizesAttr,
2998  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2999  /*scalable_sizes=*/expandedScalableSizes);
3000 }
3001 
3002 LogicalResult transform::TileUsingForOp::verify() {
3003  if (getMixedSizes().size() != getScalableSizes().size())
3004  return emitOpError("expected same number of sizes (")
3005  << getMixedSizes().size() << ") and scalable sizes ("
3006  << getScalableSizes().size() << ")";
3007  ArrayRef<int64_t> staticSizes = getStaticSizes();
3008  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
3009  if (getLoops().size() != numExpectedLoops)
3010  return emitOpError("expected number of loops to tile (")
3011  << numExpectedLoops << ") to match number of `loops` results ("
3012  << getLoops().size() << ")";
3013  return success();
3014 }
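// Illustrative note (editor's addition): a static tile size of 0 means "do not
// tile that dimension" and generates no loop, so e.g. static sizes [4, 0, 8]
// give numExpectedLoops == 2 and therefore exactly two `loops` results.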
3015 
 3016 DiagnosedSilenceableFailure
 3017 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
3018  TransformResults &transformResults,
3019  TransformState &state) {
3020  ArrayRef<int64_t> tileSizes = getStaticSizes();
3021 
3022  SmallVector<Operation *> targets =
3023  llvm::to_vector(state.getPayloadOps(getTarget()));
3024  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
 3025  SmallVector<SmallVector<int64_t>> paramSizes;
 3026  dynamicSizeProducers.reserve(getDynamicSizes().size());
3027  paramSizes.reserve(getDynamicSizes().size());
3028  for (Value transformValue : getDynamicSizes()) {
3029  if (isa<ParamType>(transformValue.getType())) {
3030  dynamicSizeProducers.push_back({});
3031  ArrayRef<Attribute> params = state.getParams(transformValue);
3032  paramSizes.push_back(
3033  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
3034  return cast<IntegerAttr>(attr).getValue().getSExtValue();
3035  })));
3036 
3037  if (paramSizes.back().size() != targets.size()) {
 3038  DiagnosedSilenceableFailure diag =
 3039  emitSilenceableError()
 3040  << "expected as many parameter values ("
 3041  << paramSizes.back().size() << ") as target ops ("
3042  << targets.size() << ")";
3043  diag.attachNote(transformValue.getLoc()) << "for this parameter";
3044  return diag;
3045  }
3046 
3047  continue;
3048  }
3049  paramSizes.push_back({});
3050  dynamicSizeProducers.push_back(
3051  llvm::to_vector(state.getPayloadOps(transformValue)));
3052 
3053  if (dynamicSizeProducers.back().size() != targets.size()) {
 3054  DiagnosedSilenceableFailure diag =
 3055  emitSilenceableError()
3056  << "expected as many dynamic size-producing operations ("
3057  << dynamicSizeProducers.back().size() << ") as target ops ("
3058  << targets.size() << ")";
3059  diag.attachNote(transformValue.getLoc()) << "for this handle";
3060  return diag;
3061  }
3062 
3063  for (Operation *op : dynamicSizeProducers.back()) {
3064  if (op->getNumResults() == 1 &&
3065  isa<IndexType>(op->getResult(0).getType())) {
3066  continue;
3067  }
3068 
 3069  DiagnosedSilenceableFailure diag =
 3070  emitSilenceableError() << "expected sizes to be produced by ops "
3071  "with a single index-type result";
3072  diag.attachNote(op->getLoc()) << "size producer op";
3073  diag.attachNote(transformValue.getLoc()) << "for this handle";
3074  return diag;
3075  }
3076  }
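  // Illustrative note (editor's addition): at this point each `dynamic_sizes`
  // operand has been classified as either a !transform.param carrying one
  // integer per target op, or a handle mapped to one index-producing payload
  // op per target op; exactly one of the two vectors is non-empty per operand.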
3077 
 3078  SmallVector<Operation *> tiled;
 3079  SmallVector<SmallVector<Operation *, 4>, 4> loops;
 3080  loops.resize(getLoops().size());
3081  auto scalableSizes = getScalableSizes();
3082  for (auto [i, op] : llvm::enumerate(targets)) {
3083  auto tilingInterface = dyn_cast<TilingInterface>(op);
3084  if (!tilingInterface) {
 3085  DiagnosedSilenceableFailure diag =
 3086  emitSilenceableError()
3087  << "only ops implementing TilingInterface are supported";
3088  diag.attachNote(op->getLoc()) << "target op";
3089  return diag;
3090  }
3091  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
 3092  DiagnosedSilenceableFailure diag =
 3093  emitSilenceableError()
3094  << "too many tiles provided, expected at most "
3095  << tilingInterface.getLoopIteratorTypes().size() << " found "
3096  << tileSizes.size();
3097  diag.attachNote(op->getLoc()) << "target op";
3098  return diag;
3099  }
3100 
3101  scf::SCFTilingOptions tilingOptions;
3102  if (tileSizes.empty()) {
3103  tilingOptions.setTileSizeComputationFunction(
 3104  [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
 3105  return {};
3106  });
3107  } else {
3108  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3109  Operation *) {
 3110  SmallVector<OpFoldResult> sizes;
 3111  sizes.reserve(tileSizes.size());
3112  unsigned dynamicIdx = 0;
3113 
3114  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3115  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3116  if (scalableSizes[ofrIdx]) {
3117  auto val = b.create<arith::ConstantIndexOp>(
3118  getLoc(), cast<IntegerAttr>(attr).getInt());
3119  Value vscale =
3120  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
3121  sizes.push_back(
3122  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
3123  } else {
3124  sizes.push_back(attr);
3125  }
3126  continue;
3127  }
3128  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3129  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3130  ++dynamicIdx;
3131  assert((dynamicSizes.empty() ^ params.empty()) &&
3132  "expected either dynamic sizes or parameters");
3133  if (!params.empty()) {
3134  sizes.push_back(b.getIndexAttr(params[index]));
3135  } else {
3136  sizes.push_back(dynamicSizes[index]->getResult(0));
3137  }
3138  }
3139  return sizes;
3140  });
3141  }
3142 
3143  tilingOptions.setInterchange(getInterchange());
3144  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3145  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3146  if (failed(maybeTilingResult))
 3147  return DiagnosedSilenceableFailure::definiteFailure();
 3148 
3149  rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
3150 
3151  tiled.append(maybeTilingResult->tiledOps);
3152  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3153  loops[en2.index()].push_back(en2.value());
3154  }
3155 
3156  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3157  for (const auto &en : llvm::enumerate(loops))
3158  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3159 
 3160  return DiagnosedSilenceableFailure::success();
 3161 }
3162 
 3163 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
 3164  ValueRange dynamic = getDynamicSizes();
3165  ArrayRef<int64_t> tileSizes = getStaticSizes();
3166  SmallVector<OpFoldResult> results;
3167  results.reserve(tileSizes.size());
3168  unsigned dynamicPos = 0;
3169  Builder builder(getContext());
3170  for (int64_t size : tileSizes) {
3171  if (size == ShapedType::kDynamic) {
3172  results.push_back(dynamic[dynamicPos++]);
3173  } else {
3174  results.push_back(builder.getIndexAttr(size));
3175  }
3176  }
3177  return results;
3178 }
3179 
3180 void transform::TileUsingForOp::getEffects(
 3181  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 3182  consumesHandle(getTargetMutable(), effects);
3183  onlyReadsHandle(getDynamicSizesMutable(), effects);
3184  producesHandle(getOperation()->getOpResults(), effects);
3185  modifiesPayload(effects);
3186 }
3187 
3188 //===----------------------------------------------------------------------===//
3189 // TileUsingForallOp
3190 //===----------------------------------------------------------------------===//
3191 
3192 void transform::TileUsingForallOp::build(OpBuilder &builder,
3193  OperationState &result, Value target,
3194  ArrayRef<int64_t> staticTileSizes,
 3195  transform::TileSizesSpec,
 3196  ArrayAttr mapping) {
3197  return build(builder, result,
3198  /*target=*/target,
3199  /*mixedTileSizes=*/
3200  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3201  /*_=*/TileSizesSpec(),
3202  /*mapping=*/mapping);
3203 }
3204 
3205 void transform::TileUsingForallOp::build(OpBuilder &builder,
3206  OperationState &result, Value target,
3207  ArrayRef<OpFoldResult> mixedTileSizes,
 3208  transform::TileSizesSpec,
 3209  ArrayAttr mapping) {
3210  SmallVector<int64_t> staticTileSizes;
3211  SmallVector<Value> dynamicTileSizes;
3212  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3213  // Call the default builder which sets up the proper operands segment sizes
3214  // attributes for multiple variadic operands. In the absence of this,
3215  // horrible bugs ensue.
3216  MLIRContext *ctx = builder.getContext();
3217  auto operationType = transform::AnyOpType::get(ctx);
3218  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3219  build(builder, result,
3220  /*resultTypes=*/TypeRange{operationType, operationType},
3221  /*target=*/target,
3222  /*num_threads=*/ValueRange{},
3223  /*tile_sizes=*/dynamicTileSizes,
3224  /*packed_num_threads=*/Value(),
3225  /*packed_tile_sizes=*/Value(),
3226  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3227  /*static_tile_sizes=*/staticTileSizesAttr,
3228  /*mapping=*/mapping);
3229 }
3230 
3231 void transform::TileUsingForallOp::build(OpBuilder &builder,
3232  OperationState &result, Value target,
3233  ArrayRef<int64_t> staticNumThreads,
 3234  transform::NumThreadsSpec,
 3235  ArrayAttr mapping) {
3236  return build(builder, result, target,
3237  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3238  NumThreadsSpec(), mapping);
3239 }
3240 
3241 void transform::TileUsingForallOp::build(OpBuilder &builder,
3242  OperationState &result, Value target,
3243  ArrayRef<OpFoldResult> mixedNumThreads,
 3244  transform::NumThreadsSpec,
 3245  ArrayAttr mapping) {
3246  SmallVector<int64_t> staticNumThreads;
3247  SmallVector<Value> dynamicNumThreads;
3248  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3249  staticNumThreads);
3250  // Call the default builder which sets up the proper operands segment sizes
3251  // attributes for multiple variadic operands. In the absence of this,
3252  // horrible bugs ensue.
3253  MLIRContext *ctx = builder.getContext();
3254  auto operationType = transform::AnyOpType::get(ctx);
3255  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3256  build(builder, result,
3257  /*resultTypes=*/TypeRange{operationType, operationType},
3258  /*target=*/target,
3259  /*num_threads=*/dynamicNumThreads,
3260  /*tile_sizes=*/ValueRange{},
3261  /*packed_num_threads=*/Value(),
3262  /*packed_tile_sizes=*/Value(),
3263  /*static_num_threads=*/staticNumThreadsAttr,
3264  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3265  /*mapping=*/mapping);
3266 }
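// Illustrative sketch (editor's addition, not part of the upstream file): with
// hypothetical `b`, `loc` and `targetHandle`, the num-threads form might be
// built as
//
//   b.create<transform::TileUsingForallOp>(
//       loc, /*target=*/targetHandle,
//       /*staticNumThreads=*/ArrayRef<int64_t>{4, 8},
//       transform::NumThreadsSpec(), /*mapping=*/ArrayAttr());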
3267 
3268 /// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3269 /// normalized upper bound.
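/// For example (illustrative values): lb = 4, ub = 20, step = 3 yields the
/// normalized upper bound ceilDiv(20 - 4, 3) = 6, i.e. six iterations.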
 3270 static SmallVector<OpFoldResult>
 3271 normalizeUpperBounds(RewriterBase &rewriter, Location loc,
 3272  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
 3273  ArrayRef<OpFoldResult> steps) {
3274  AffineExpr s0, s1, s2;
3275  bindSymbols(rewriter.getContext(), s0, s1, s2);
3276  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3277  SmallVector<OpFoldResult> normalizedUbs;
3278  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
 3279  OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
 3280  rewriter, loc, normalizedUbExpr, {lb, ub, step});
3281  normalizedUbs.push_back(normalizedUb);
3282  }
3283  return normalizedUbs;
3284 }
3285 
3286 /// When a loop is normalized, the uses of the induction variable within the
 3287 /// loop need to be replaced with `original_lb + old_iv * original_step`.
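/// For example (illustrative values): with original_lb = 4 and
/// original_step = 3, the normalized IVs 0, 1, 2 map back to 4, 7, 10.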
 3288 static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
 3289  Location loc, ValueRange ivs,
 3290  ArrayRef<OpFoldResult> lbs,
 3291  ArrayRef<OpFoldResult> steps) {
3292  AffineExpr s0, s1;
3293  AffineExpr d0;
3294  bindSymbols(rewriter.getContext(), s0, s1);
3295  bindDims(rewriter.getContext(), d0);
3296  AffineExpr denormExpr = s0 + d0 * s1;
3297  SmallVector<Value> denormalizedIvs;
3298 
3299  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
 3300  OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
 3301  rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3302  denormalizedIvs.push_back(
3303  getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3304  }
3305  return denormalizedIvs;
3306 }
3307 
3308 /// Given a `scf.forall` loop return a loop op with the loop bounds
3309 /// normalized.
3310 /// TODO: Replace this with a general utility to normalize `scf.forall`.
 3311 /// At the time of writing, this wasn't done since adding this to the `scf`
 3312 /// dialect would disallow the use of `affine.apply` operations due
 3313 /// to cyclic dependencies. To avoid churn in the lit tests added together
 3314 /// with this change, that is deferred to a follow-up.
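/// For example (illustrative values): a loop over (lb = 2, ub = 10, step = 3)
/// is rewritten to iterate over (lb = 0, ub = 3, step = 1), and uses of the
/// old induction variable are rewritten to 2 + iv * 3.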
3315 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3316  scf::ForallOp loop) {
3317  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3318  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3319  SmallVector<OpFoldResult> steps = loop.getMixedStep();
3320 
3321  if (llvm::all_of(
3322  lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
3323  llvm::all_of(
3324  steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
3325  return loop;
3326  }
3327 
3328  Location loc = loop.getLoc();
3329  SmallVector<OpFoldResult> normalizedUbs =
3330  normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3331  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3332  rewriter.getIndexAttr(0));
3333  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3334  rewriter.getIndexAttr(1));
3335 
3336  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
3337  loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3338  loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
3339 
3340  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3341  OpBuilder::InsertionGuard g(rewriter);
3342  Block *normalizedLoopBlock = normalizedForallOp.getBody();
3343  rewriter.setInsertionPointToStart(normalizedLoopBlock);
3344 
3345  SmallVector<Value> argValues =
3346  denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3347  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3348  normalizedForallOp.getRegionIterArgs().end());
3349  Block *origLoopBlock = loop.getBody();
3350  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3351 
3352  rewriter.replaceOp(loop, normalizedForallOp);
3353  return normalizedForallOp;
3354 }
3355 
 3356 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
 3357  RewriterBase &rewriter, transform::TransformState &state,
3358  TransformOpInterface transformOp, Operation *target,
3359  ArrayRef<OpFoldResult> mixedNumThreads,
3360  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3361  scf::SCFTilingResult &tilingResult) {
3362  // Transform all targets one by one.
3363  auto tileableOp = dyn_cast<TilingInterface>(target);
3364  if (!tileableOp) {
 3365  DiagnosedSilenceableFailure diag =
 3366  transformOp.emitSilenceableError()
3367  << "only TilingInterface ops are supported";
3368  diag.attachNote(target->getLoc()) << "target op";
3369  return diag;
3370  }
3371  rewriter.setInsertionPoint(tileableOp);
 3372  scf::SCFTilingOptions options;
 3373  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
 3374  if (!mixedNumThreads.empty()) {
3375  options.setNumThreads(mixedNumThreads);
3376  } else {
3377  options.setTileSizes(mixedTileSizes);
3378  }
3379  if (mapping) {
3380  options.setMapping(mapping.value().getValue());
3381  }
3382  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3383  scf::tileUsingSCF(rewriter, tileableOp, options);
3384 
3385  if (failed(maybeTilingResult))
3386  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3387 
3388  rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
3389 
3390  tilingResult = *maybeTilingResult;
3391 
3392  if (mixedNumThreads.empty()) {
3393  auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3394  OpBuilder::InsertionGuard g(rewriter);
3395  rewriter.setInsertionPoint(generatedForallOp);
3396  scf::ForallOp normalizedForallOp =
3397  normalizeForallLoopOp(rewriter, generatedForallOp);
3398  tilingResult.loops.front() = normalizedForallOp;
3399  }
3400 
 3401  return DiagnosedSilenceableFailure::success();
 3402 }
3403 
3404 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3405  transform::TransformRewriter &rewriter,
3406  transform::TransformResults &transformResults,
3407  transform::TransformState &state) {
3408  auto transformOp = cast<TransformOpInterface>(getOperation());
3409 
3410  // Result payload ops.
3411  SmallVector<Operation *> tileOps;
3412  SmallVector<Operation *> tiledOps;
3413 
3414  // Unpack handles.
3415  SmallVector<OpFoldResult> mixedNumThreads;
 3416  DiagnosedSilenceableFailure status =
 3417  getPackedNumThreads()
 3418  ? unpackSingleIndexResultPayloadOperations(
 3419  state, transformOp, mixedNumThreads, getPackedNumThreads())
 3420  : unpackSingleIndexResultPayloadOperations(
 3421  state, transformOp, mixedNumThreads, getMixedNumThreads());
3422  if (!status.succeeded())
3423  return status;
3424  SmallVector<OpFoldResult> mixedTileSizes;
 3425  status = getPackedTileSizes()
 3426  ? unpackSingleIndexResultPayloadOperations(
 3427  state, transformOp, mixedTileSizes, getPackedTileSizes())
 3428  : unpackSingleIndexResultPayloadOperations(
 3429  state, transformOp, mixedTileSizes, getMixedTileSizes());
3430  if (!status.succeeded())
3431  return status;
3432 
3433  for (Operation *target : state.getPayloadOps(getTarget())) {
3434  scf::SCFTilingResult tilingResult;
 3435  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
 3436  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3437  getMapping(), tilingResult);
3438  if (!diag.succeeded())
3439  return diag;
3440  tileOps.push_back(tilingResult.loops.front());
3441  tiledOps.append(tilingResult.tiledOps);
3442  }
3443 
3444  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3445  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3446 
 3447  return DiagnosedSilenceableFailure::success();
 3448 }
3449 
3450 void transform::TileUsingForallOp::getEffects(
 3451  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 3452  consumesHandle(getTargetMutable(), effects);
3453  onlyReadsHandle(getTileSizesMutable(), effects);
3454  onlyReadsHandle(getNumThreadsMutable(), effects);
3455  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3456  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3457  producesHandle(getOperation()->getOpResults(), effects);
3458  modifiesPayload(effects);
3459 }
3460 
3461 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3462  Builder b(getContext());
3463  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3464 }
3465 
3466 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3467  Builder b(getContext());
3468  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3469 }
3470 
3471 LogicalResult TileUsingForallOp::verify() {
3472  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3473  static_cast<int>(getPackedNumThreads() != Value());
3474  if (numThreadsSpec > 1)
3475  return emitOpError(
3476  "num_threads and packed_num_threads are mutually exclusive");
3477  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3478  static_cast<int>(getPackedTileSizes() != Value());
3479  if (tileSizesSpec > 1)
3480  return emitOpError(
3481  "tile_sizes and packed_tile_sizes are mutually exclusive");
3482  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3483  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3484  "must be specified");
3485  return success();
3486 }
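// Illustrative note (editor's addition): within each family the packed and
// unpacked forms are mutually exclusive, e.g. num_threads [4, 8] cannot be
// combined with packed_num_threads, and at least one of the num_threads or
// tile_sizes families must be present.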
3487 
3488 //===----------------------------------------------------------------------===//
3489 // VectorizeChildrenAndApplyPatternsOp
3490 //===----------------------------------------------------------------------===//
3491 
3492 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3493  OpBuilder &builder, OperationState &result, Value target,
3494  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3495  result.addOperands(target);
3496  if (vectorizePadding) {
3497  result.addAttribute(
3498  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3499  result.name),
3500  builder.getUnitAttr());
3501  }
3502  if (vectorizeExtract) {
3503  result.addAttribute(
3504  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3505  result.name),
3506  builder.getUnitAttr());
3507  }
3508  if (flatten1DDepthwiseConv) {
3509  result.addAttribute(
3510  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3511  result.name),
3512  builder.getUnitAttr());
3513  }
3514  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3515 }
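// Illustrative sketch (editor's addition, not part of the upstream file): with
// hypothetical `b`, `loc` and `targetHandle`, requesting pad vectorization
// amounts to
//
//   b.create<transform::VectorizeChildrenAndApplyPatternsOp>(
//       loc, targetHandle, /*vectorizePadding=*/true,
//       /*vectorizeExtract=*/false, /*flatten1DDepthwiseConv=*/false);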
3516 
3517 namespace {
 3518 /// This is a helper only to call vectorize via a pattern inside of
3519 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3520 struct VectorizationPattern : public RewritePattern {
3521  explicit VectorizationPattern(MLIRContext *context,
3522  bool vectorizeExtract = false,
3523  bool flattenConv = false)
3524  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3525  vectorizeNDExtract(vectorizeExtract),
3526  flatten1DDepthwiseConv(flattenConv) {}
3527  LogicalResult matchAndRewrite(Operation *op,
3528  PatternRewriter &rewriter) const override {
 3529  if (!linalg::hasVectorizationImpl(op))
 3530  return rewriter.notifyMatchFailure(op,
3531  "Unsupported Op, cannot vectorize");
3532  return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3533  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3534  flatten1DDepthwiseConv);
3535  }
3536 
3537 private:
3538  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3539  /// rank >= 2.
3540  bool vectorizeNDExtract = false;
3541  /// Controls whether to "flatten" the channel dimension when vectorising 1D
 3542 /// depthwise convolutions. This should lead to better vectorization for
3543  /// tensors with a low number of channel dimensions.
3544  bool flatten1DDepthwiseConv = false;
3545 };
3546 } // namespace
3547 
 3548 DiagnosedSilenceableFailure
 3549 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3550  transform::TransformRewriter &rewriter, Operation *target,
 3551  transform::ApplyToEachResultList &results,
 3552  transform::TransformState &state) {
3553  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3554  auto diag = this->emitOpError("requires isolated-from-above targets");
3555  diag.attachNote(target->getLoc()) << "non-isolated target";
 3556  return DiagnosedSilenceableFailure::definiteFailure();
 3557  }
3558 
3559  MLIRContext *ctx = getContext();
 3560  RewritePatternSet patterns(ctx);
 3561  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3562  getFlatten_1dDepthwiseConv());
3563 
3564  if (!getDisableTransferPermutationMapLoweringPatterns())
 3565  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
 3566 
3567  if (!getDisableMultiReductionToContractPatterns())
 3568  vector::populateVectorReductionToContractPatterns(patterns);
 3569 
3571 
 3572  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
 3573  linalg::LinalgCopyVTWForwardingPattern>(ctx,
 3574  /*benefit=*/2);
3575  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3576  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3578 
3580 
3581  if (getVectorizePadding()) {
 3582  linalg::populatePadOpVectorizationPatterns(patterns);
 3583  // This creates an alternative path for lowering tensor.pad - by
 3584  // decomposing it into e.g. linalg.fill.
 3585  linalg::populateDecomposePadPatterns(patterns);
 3586  }
3588 
3589  TrackingListener listener(state, *this);
 3590  GreedyRewriteConfig config;
 3591  config.listener = &listener;
3592  if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
3593  return emitDefaultDefiniteFailure(target);
3594 
3595  results.push_back(target);
 3596  return DiagnosedSilenceableFailure::success();
 3597 }
3598 
3599 //===----------------------------------------------------------------------===//
3600 // VectorizeOp
3601 //===----------------------------------------------------------------------===//
3602 
3603 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3604  transform::TransformRewriter &rewriter,
3605  mlir::transform::TransformResults &transformResults,
 3606  mlir::transform::TransformState &state) {
 3607  auto targets = state.getPayloadOps(getTarget());
 3608  if (std::empty(targets))
 3609  return DiagnosedSilenceableFailure::success();
 3610  auto transformOp = cast<TransformOpInterface>(getOperation());
3611  SmallVector<int64_t> vectorSizes;
 3612  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
 3613  state, transformOp, getMixedVectorSizes(), vectorSizes);
3614  if (!status.succeeded())
3615  return status;
3616 
3617  // TODO: Check that the correct number of vectorSizes was provided.
3618  for (Operation *target : targets) {
3619  if (!linalg::hasVectorizationImpl(target)) {
3620  return mlir::emitSilenceableFailure(target->getLoc())
3621  << "Unsupported Op, cannot vectorize";
3622  }
3623 
3624  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3625  getScalableSizes(),
3626  getVectorizeNdExtract().value_or(false)))) {
3627  return mlir::emitSilenceableFailure(target->getLoc())
3628  << "Attempted to vectorize, but failed";
3629  }
3630  }
3631 
 3632  return DiagnosedSilenceableFailure::success();
 3633 }
3634 
3635 void transform::VectorizeOp::getEffects(
 3636  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
 3637  consumesHandle(getTargetMutable(), effects);
3638  onlyReadsHandle(getVectorSizesMutable(), effects);
3639  modifiesPayload(effects);
3640 }
3641 
3642 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3643  OpBuilder b(getContext());
3644  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3645 }
3646 
3647 LogicalResult transform::VectorizeOp::verify() {
3648  if (getStaticVectorSizes().size() != getScalableSizes().size())
3649  return emitOpError("expected same number of vector sizes (")
3650  << getStaticVectorSizes().size() << ") and scalable sizes ("
3651  << getScalableSizes().size() << ")";
3652  return success();
3653 }
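// Illustrative note (editor's addition): static_vector_sizes and
// scalable_sizes must line up element-wise, e.g. sizes [8, 16] with
// scalable_sizes [false, true] requests a fixed size of 8 for the first
// dimension and a scalable size of 16 (i.e. 16 x vscale) for the second.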
3654 
3655 //===----------------------------------------------------------------------===//
3656 // HoistRedundantVectorTransfersOp
3657 //===----------------------------------------------------------------------===//
3658 
 3659 DiagnosedSilenceableFailure
 3660 transform::HoistRedundantVectorTransfersOp::applyToOne(
 3661  transform::TransformRewriter &rewriter, func::FuncOp target,
 3662  transform::ApplyToEachResultList &results,
 3663  transform::TransformState &state) {
3664  // WARNING: This hoisting does not model parallelism and is generally
3665  // incorrect when used on distributed loops with memref semantics!
3666  // TODO: obsolete and should be retired.
3667  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3668  results.push_back(target);
 3669  return DiagnosedSilenceableFailure::success();
 3670 }
3671 
3672 //===----------------------------------------------------------------------===//
3673 // HoistRedundantVectorBroadcastsOp
3674 //===----------------------------------------------------------------------===//
3675 
 3676 DiagnosedSilenceableFailure
 3677 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
 3678  transform::TransformRewriter &rewriter, mlir::Operation *target,
 3679  transform::ApplyToEachResultList &results,
 3680  transform::TransformState &state) {
3681  rewriter.setInsertionPoint(target);
3682  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3683  results.push_back(target);
 3684  return DiagnosedSilenceableFailure::success();
 3685 }
3686 
3687 //===----------------------------------------------------------------------===//
3688 // ConvertConv2DToImg2ColOp.
3689 //===----------------------------------------------------------------------===//
3690 
3691 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3692  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
 3693  transform::ApplyToEachResultList &results,
 3694  transform::TransformState &state) {
3695  rewriter.setInsertionPoint(target);
3696  auto maybeTransformed =
 3697  TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
 3698  target)
3699  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3700  return rewriteInIm2Col(rewriter, op);
3701  })
3702  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3703  return rewriteInIm2Col(rewriter, op);
3704  })
3705  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3706  return rewriteInIm2Col(rewriter, op);
3707  })
3708  .Case([&](linalg::Conv2DNchwFchwOp op) {
3709  return rewriteInIm2Col(rewriter, op);
3710  })
3711  .Default([&](Operation *op) {
3712  return rewriter.notifyMatchFailure(op, "not supported");
3713  });
3714  if (failed(maybeTransformed))
3715  return emitDefaultSilenceableFailure(target);
3716  // Handle to the operation producing the img2col tensor.
3717  results.push_back(maybeTransformed->first);
3718  // Handle to the operation that replaces the original convolution.
3719  results.push_back(maybeTransformed->second);
 3720  return DiagnosedSilenceableFailure::success();
 3721 }
3722 
3723 //===----------------------------------------------------------------------===//
3724 // FlattenElementwiseLinalgOp.
3725 //===----------------------------------------------------------------------===//
3726 
3727 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3728  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
 3729  transform::ApplyToEachResultList &results,
 3730  transform::TransformState &state) {
3731  rewriter.setInsertionPoint(target);
3732  if (!isElementwise(target))
3733  return mlir::emitSilenceableFailure(target->getLoc())
3734  << "only elementwise flattening is supported";
3735 
3736  // If rank <= 1, do nothing
3737  if (target.getNumLoops() <= 1) {
3738  results.push_back(target);
 3739  return DiagnosedSilenceableFailure::success();
 3740  }
3741 
3742  // Attempt to flatten all dims to one.
3743  ReassociationIndices reassociation(target.getNumLoops());
3744  std::iota(reassociation.begin(), reassociation.end(), 0);
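  // Illustrative note (editor's addition): for a rank-3 elementwise op this
  // builds the single reassociation group [0, 1, 2], i.e. all loops are
  // collapsed into one flattened loop.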
3745  auto maybeFlattened =
3746  collapseOpIterationDims(target, reassociation, rewriter);
3747  if (failed(maybeFlattened))
3748  return mlir::emitSilenceableFailure(target->getLoc())
3749  << "attempted to flatten, but failed";
3750  results.push_back(maybeFlattened->collapsedOp);
3751  rewriter.replaceOp(target, maybeFlattened->results);
 3752  return DiagnosedSilenceableFailure::success();
 3753 }
3754 
3755 //===----------------------------------------------------------------------===//
3756 // TransposeConv2DOp
3757 //===----------------------------------------------------------------------===//
3758 
3759 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3760  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
 3761  transform::ApplyToEachResultList &results,
 3762  transform::TransformState &state) {
3763  rewriter.setInsertionPoint(target);
3764  auto maybeTransformed =
 3765  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
 3766  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3767  return transposeConv2D(rewriter, op);
3768  })
3769  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3770  return transposeConv2D(rewriter, op);
3771  })
3772  .Default([&](Operation *op) {
3773  return rewriter.notifyMatchFailure(op, "not supported");
3774  });
3775  if (failed(maybeTransformed))
3776  return emitDefaultSilenceableFailure(target);
3777  // Handle to the new Conv2D operation with transposed filters
3778  results.push_back(*maybeTransformed);
 3779  return DiagnosedSilenceableFailure::success();
 3780 }
3781 
3782 //===----------------------------------------------------------------------===//
3783 // TransposeMatmulOp
3784 //===----------------------------------------------------------------------===//
3785 
3786 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3787  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
 3788  transform::ApplyToEachResultList &results,
 3789  transform::TransformState &state) {
3790  rewriter.setInsertionPoint(target);
3791  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3792  auto maybeTransformed =
 3793  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
 3794  .Case([&](linalg::MatmulOp op) {
3795  return transposeMatmul(rewriter, op, transposeLHS);
3796  })
3797  .Case([&](linalg::BatchMatmulOp op) {
3798  return transposeBatchMatmul(rewriter, op, transposeLHS);
3799  })
3800  .Default([&](Operation *op) { return failure(); });
3801  if (failed(maybeTransformed))
3802  return emitSilenceableFailure(target->getLoc()) << "not supported";
 3803  // Handle to the new Matmul operation with the transposed operand.
 3804  results.push_back(*maybeTransformed);
 3805  return DiagnosedSilenceableFailure::success();
 3806 }
3807 
3808 //===----------------------------------------------------------------------===//
3809 // InsertSliceToCopyOp
3810 //===----------------------------------------------------------------------===//
3811 template <typename OpTy>
 3812 DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
 3813  transform::ApplyToEachResultList &results,
 3814  transform::TransformState &state) {
3815  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3816  tensor::ParallelInsertSliceOp>() &&
3817  "wrong op type");
3818 
3819  if (auto copySource =
3820  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3821  results.push_back(copySource);
 3822  return DiagnosedSilenceableFailure::success();
 3823  }
3824 
3825  // If we are inside an InParallel region, temporarily set the insertion point
3826  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3827  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3828  rewriter.setInsertionPoint(
3829  target->template getParentOfType<scf::InParallelOp>());
3830  }
3831 
3832  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3833  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3834  target.getMixedSizes(), target.getMixedStrides());
3835  Value copied = rewriter
3836  .create<linalg::CopyOp>(target.getLoc(),
3837  target.getSource(), extracted)
3838  .getResult(0);
3839  // Reset the insertion point.
3840  rewriter.setInsertionPoint(target);
3841  rewriter.replaceOpWithNewOp<OpTy>(
3842  target, copied, target.getDest(), target.getMixedOffsets(),
3843  target.getMixedSizes(), target.getMixedStrides());
3844 
3845  results.push_back(copied.getDefiningOp());
 3846  return DiagnosedSilenceableFailure::success();
 3847 }
3848 
3849 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3850  transform::TransformRewriter &rewriter, Operation *targetOp,
 3851  transform::ApplyToEachResultList &results,
 3852  transform::TransformState &state) {
3853 
3854  rewriter.setInsertionPoint(targetOp);
3855  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3856  return doit(rewriter, target, results, state);
3857  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3858  return doit(rewriter, target, results, state);
3859 
 3860  DiagnosedSilenceableFailure diag =
 3861  emitSilenceableError()
3862  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3863  diag.attachNote(targetOp->getLoc()) << "target op";
3864  return diag;
3865 }
3866 
3867 //===----------------------------------------------------------------------===//
3868 // MapCopyToThreadsOp
3869 //===----------------------------------------------------------------------===//
3870 
3871 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3872  transform::TransformRewriter &rewriter, Operation *target,
 3873  transform::ApplyToEachResultList &results,
 3874  transform::TransformState &state) {
3875  // Check if the op is supported.
3876  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
 3877  DiagnosedSilenceableFailure diag =
 3878  emitSilenceableError()
3879  << "only linalg.copy and tensor.pad target ops are supported";
3880  diag.attachNote(target->getLoc()) << "target op";
3881  return diag;
3882  }
3883  assert(target->getNumResults() == 1 && "expected single result");
3884  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3885  if (!resultShapedType.hasStaticShape()) {
 3886  DiagnosedSilenceableFailure diag =
 3887  emitSilenceableError()
3888  << "only statically sized ops of rank <= 3 are supported";
3889  diag.attachNote(target->getLoc()) << "target op";
3890  return diag;
3891  }
3892 
3893  // Conservatively set the minimum viable desired bitwidth alignment.
3894  int64_t desiredBitAlignment = getDesiredBitAlignment();
3895  int64_t eltBitwidth =
3896  resultShapedType.getElementType().getIntOrFloatBitWidth();
3897  if (desiredBitAlignment % eltBitwidth != 0) {
3898  desiredBitAlignment = eltBitwidth;
3899  }
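  // Illustrative arithmetic (editor's addition): with a desired alignment of
  // 128 bits and f32 elements, 128 % 32 == 0 so 128 is kept; with an element
  // width that does not divide it, e.g. 48 bits, the alignment falls back to
  // the element width (48).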
3900 
3901  gpu::CopyMappingInfo mapping(
3902  /*ctx=*/getContext(),
3903  /*totalNumThreads=*/getTotalNumThreads(),
3904  /*alignment=*/desiredBitAlignment,
3905  /*sizes=*/resultShapedType.getShape(),
3906  /*favorPredication=*/false,
3907  /*elementalBitwidth=*/
3908  resultShapedType.getElementType().getIntOrFloatBitWidth());
3909  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
 3910  DiagnosedSilenceableFailure diag =
 3911  emitSilenceableError()
3912  << "too few threads to map copy op to threads on the most minor "
3913  "dimension, given alignment and vector size constraints, try "
3914  "smaller tile size of mapping to more threads";
3915  diag.attachNote(target->getLoc()) << "target op";
3916  return diag;
3917  }
3918 
3919  // OpBuilder only used to compute attributes.
3920  OpBuilder b(getContext());
3921  scf::SCFTilingResult tilingResult;
 3922  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
 3923  /*rewriter=*/rewriter,
3924  /*state=*/state,
3925  /*transformOp=*/*this,
3926  /*target=*/target,
3927  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3928  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3929  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3930  /*tilingResult=*/tilingResult);
3931  if (!diag.succeeded())
3932  return diag;
3933 
3934  results.push_back(tilingResult.loops.front());
3935  for (auto op : tilingResult.tiledOps)
3936  results.push_back(op);
 3937  return DiagnosedSilenceableFailure::success();
 3938 }
3939 
3940 //===----------------------------------------------------------------------===//
3941 // WinogradConv2DOp
3942 //===----------------------------------------------------------------------===//
3943 
3944 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
3945  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
 3946  transform::ApplyToEachResultList &results,
 3947  transform::TransformState &state) {
3948  rewriter.setInsertionPoint(target);
3949  FailureOr<Operation *> maybeTransformed = failure();
3950  bool supported = TypeSwitch<Operation *, bool>(target)
3951  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3952  maybeTransformed =
3953  winogradConv2D(rewriter, op, getM(), getR());
3954  return true;
3955  })
3956  .Default([&](Operation *op) { return false; });
3957 
3958  if (!supported) {
3959  return emitSilenceableError()
3960  << "this operation is not supported to convert to Winograd Conv2D";
3961  }
3962 
3963  if (failed(maybeTransformed)) {
3964  return emitSilenceableError() << "apply Winograd Conv2D failed";
3965  }
3966 
3967  results.push_back(*maybeTransformed);
 3968  return DiagnosedSilenceableFailure::success();
 3969 }
3970 
3971 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
3972  transform::TransformRewriter &rewriter, Operation *target,
 3973  transform::ApplyToEachResultList &results,
 3974  transform::TransformState &state) {
3975  rewriter.setInsertionPoint(target);
3976  FailureOr<Operation *> maybeTransformed = failure();
3977  bool supported =
 3978  TypeSwitch<Operation *, bool>(target)
 3979  .Case([&](linalg::WinogradFilterTransformOp op) {
3980  maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
3981  return true;
3982  })
3983  .Case([&](linalg::WinogradInputTransformOp op) {
3984  maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
3985  return true;
3986  })
3987  .Case([&](linalg::WinogradOutputTransformOp op) {
3988  maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
3989  return true;
3990  })
3991  .Default([&](Operation *op) { return false; });
3992 
3993  if (!supported) {
 3994  DiagnosedSilenceableFailure diag =
 3995  emitSilenceableError()
3996  << "this operation is not supported to decompose into other operations";
3997  diag.attachNote(target->getLoc()) << "target op";
3998  return diag;
3999  }
4000 
4001  if (failed(maybeTransformed)) {
 4002  DiagnosedSilenceableFailure diag =
 4003  emitSilenceableError() << "decompose Winograd operations failed";
4004  diag.attachNote(target->getLoc()) << "target op";
4005  return diag;
4006  }
4007 
4008  results.push_back(*maybeTransformed);
 4009  return DiagnosedSilenceableFailure::success();
 4010 }
4011 
4012 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
4013 
4014 #define GET_OP_CLASSES
4015 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
#define DOWNSCALE(trans)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
#define DBGS()
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:364
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:302
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:158
This class allows control over how the GreedyPatternRewriteDriver works.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
This class represents a saved insertion point.
Definition: Builders.h:325
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:335
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:433
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getOpResult(unsigned idx)
Definition: Operation.h:421
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:798
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:238
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:362
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:606
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:598
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:204
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1254
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1158
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1208
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
Definition: Padding.cpp:153
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:470
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:677
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:260
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:356
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:860
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:511
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:495
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:486
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
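A sketch of a small driver pairing hasVectorizationImpl with vectorize as listed above; candidates are collected before rewriting so the walk is not invalidated. The driver itself and its name are assumptions.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Vectorize every op under `root` that the Linalg vectorizer can handle,
// relying on static shapes (empty inputVectorSizes).
static void vectorizeAllCandidates(RewriterBase &rewriter, Operation *root) {
  SmallVector<Operation *> candidates;
  root->walk([&](Operation *op) {
    if (linalg::hasVectorizationImpl(op))
      candidates.push_back(op);
  });
  for (Operation *op : candidates) {
    rewriter.setInsertionPoint(op);
    (void)linalg::vectorize(rewriter, op); // Failures are skipped in this sketch.
  }
}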
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
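A small sketch of the generalization entry point above; specializeGenericOp (listed earlier) goes in the opposite direction. The wrapper name is illustrative, and `using namespace mlir;` plus the Linalg Transforms.h header are assumed.

// Rewrite a named Linalg op (e.g. linalg.matmul) into an equivalent
// linalg.generic so that generic-only patterns can apply to it.
static FailureOr<linalg::GenericOp> toGenericForm(RewriterBase &rewriter,
                                                  linalg::LinalgOp namedOp) {
  rewriter.setInsertionPoint(namedOp);
  return linalg::generalizeNamedOp(rewriter, namedOp);
}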
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Convert linalg.batch_matmul to a transposed variant by materializing a transpose of one operand (the LHS by default).
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:399
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
Definition: Promotion.cpp:503
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:223
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:50
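A sketch using the interchange entry point above to swap the two outermost loops; interchangeGenericOp expects a full permutation of the iteration space, so the vector covers all loops. Assumes `using namespace mlir;` and the usual LLVM/MLIR headers.

// Swap the two outermost loops of a linalg.generic.
static FailureOr<linalg::GenericOp>
swapOuterTwoLoops(RewriterBase &rewriter, linalg::GenericOp genericOp) {
  auto linalgOp = cast<linalg::LinalgOp>(genericOp.getOperation());
  unsigned numLoops = linalgOp.getNumLoops();
  if (numLoops < 2)
    return failure();
  // Identity permutation with the first two positions swapped.
  SmallVector<unsigned> perm =
      llvm::to_vector(llvm::seq<unsigned>(0, numLoops));
  std::swap(perm[0], perm[1]);
  return linalg::interchangeGenericOp(rewriter, genericOp, perm);
}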
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:242
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:162
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:111
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:97
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition: Tiling.cpp:594
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:768
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:479
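A sketch of the packing entry point above for an op with three iteration dimensions (e.g. a matmul); the 8/16/32 inner tile sizes and the wrapper name are illustrative. Assumes `using namespace mlir;`.

// Pack a LinalgOp with constant inner tile sizes along its three loops.
static FailureOr<linalg::PackResult>
packWithStaticTiles(RewriterBase &rewriter, linalg::LinalgOp linalgOp) {
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<OpFoldResult> packedSizes = {getAsIndexOpFoldResult(ctx, 8),
                                           getAsIndexOpFoldResult(ctx, 16),
                                           getAsIndexOpFoldResult(ctx, 32)};
  rewriter.setInsertionPoint(linalgOp);
  return linalg::pack(rewriter, linalgOp, packedSizes);
}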
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:421
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:442
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In the case of GPU workgroup memory there is no need to deallocate.
Definition: Promotion.cpp:479
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operations.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:202
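A sketch combining the two hoisting entry points listed above as a post-vectorization cleanup; `funcLike` stands for any op (e.g. a func.func) whose regions contain the loops and is an assumption of this sketch.

// Hoist loop-invariant vector broadcasts and redundant transfer pairs.
static void hoistVectorReuse(RewriterBase &rewriter, Operation *funcLike) {
  linalg::hoistRedundantVectorBroadcasts(rewriter, funcLike);
  linalg::hoistRedundantVectorTransfers(funcLike, /*verifyNonZeroTrip=*/false);
}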
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
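A sketch of splitOp above with a static split point; dimension 0, the helper name, and `using namespace mlir;` are illustrative assumptions.

// Split a tileable op along iteration dimension 0 at index `idx`; the first
// returned op covers the iterations before the split point, the second the
// remainder.
static std::pair<TilingInterface, TilingInterface>
splitAtIndex(RewriterBase &rewriter, TilingInterface op, int64_t idx) {
  OpFoldResult splitPoint = getAsIndexOpFoldResult(rewriter.getContext(), idx);
  return linalg::splitOp(rewriter, op, /*dimension=*/0, splitPoint);
}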
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:268
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:219
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:78
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
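A sketch of tileUsingSCF above combined with the SCFTilingOptions::setTileSizes setter listed further down this page; the 4x16 tile sizes are illustrative, and replacing the original op's uses is left to the caller.

#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

using namespace mlir;

static LogicalResult tileWithFixedSizes(RewriterBase &rewriter,
                                        TilingInterface op) {
  MLIRContext *ctx = rewriter.getContext();
  scf::SCFTilingOptions options;
  options.setTileSizes(
      {getAsIndexOpFoldResult(ctx, 4), getAsIndexOpFoldResult(ctx, 16)});
  rewriter.setInsertionPoint(op);
  FailureOr<scf::SCFTilingResult> tiled =
      scf::tileUsingSCF(rewriter, op, options);
  if (failed(tiled))
    return failure();
  // tiled->loops holds the generated loop nest and tiled->tiledOps the tiled
  // computation (see SCFTilingResult at the end of this page).
  return success();
}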
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:602
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
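A sketch of how a transform op typically wires these effect helpers together in its getEffects(); MyTransformOp and its getTargetMutable() accessor are hypothetical stand-ins for an ODS-defined op with a single consumed handle operand.

using namespace mlir;

// Hypothetical transform op: consumes its target handle, produces its result
// handles, and modifies the payload IR.
void MyTransformOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTargetMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}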
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is a constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. sizeof...(exprs)).
Definition: AffineExpr.h:311
GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
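A sketch that feeds a couple of the populate* entry points listed above into the greedy driver; the choice of populators, and the assumption that the parent op is isolated from above and has a single region, are illustrative.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult cleanupAfterTiling(Operation *isolatedParent) {
  MLIRContext *ctx = isolatedParent->getContext();
  RewritePatternSet patterns(ctx);
  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
  // The region must belong to an op that is isolated from above.
  return applyPatternsGreedily(isolatedParent->getRegion(0),
                               std::move(patterns));
}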
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute into an MLIR context, if it is valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
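A sketch of the usual diagnostic flow in a transform op implementation: emitSilenceableFailure for recoverable mismatches, emitDefiniteFailure for unrecoverable ones. The checked condition and function name are illustrative; `using namespace mlir;` is assumed.

// Report a silenceable failure if the payload op does not have exactly one
// result; callers may suppress or propagate it.
static DiagnosedSilenceableFailure checkSingleResult(Operation *payload) {
  if (payload->getNumResults() != 1)
    return emitSilenceableFailure(payload->getLoc())
           << "expected exactly one result";
  return DiagnosedSilenceableFailure::success();
}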
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions [0 .. sizeof...(exprs)).
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
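Two small sketches of the OpFoldResult helpers above: querying a constant and materializing a Value. The function names are illustrative; the helpers live in the mlir namespace, which is assumed to be in scope.

// True iff `ofr` is the constant 4 (equivalently: isConstantIntValue(ofr, 4)).
static bool isTileSizeFour(OpFoldResult ofr) {
  std::optional<int64_t> cst = getConstantIntValue(ofr);
  return cst && *cst == 4;
}

// Emits an arith.constant of index type if `ofr` wraps an attribute;
// otherwise simply returns the wrapped Value.
static Value materializeIndex(OpBuilder &b, Location loc, OpFoldResult ofr) {
  return getValueOrCreateConstantIndexOp(b, loc, ofr);
}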
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
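A sketch pairing matchPattern (listed above) with the m_Constant matcher, both to test for any constant-foldable producer and to bind the produced attribute; the helper names are illustrative.

#include "mlir/IR/Matchers.h"

using namespace mlir;

// True iff `v` is produced by a constant-foldable op.
static bool isAnyConstant(Value v) { return matchPattern(v, m_Constant()); }

// Bind the constant integer value, if any.
static std::optional<int64_t> getConstantInt(Value v) {
  IntegerAttr attr;
  if (matchPattern(v, m_Constant(&attr)))
    return attr.getInt();
  return std::nullopt;
}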
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:425
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:426
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:473
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1505
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1437
Match and rewrite for the pattern:
Definition: Transforms.h:1627
Match and rewrite for the pattern:
Definition: Transforms.h:1655
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:379
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:385
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:398
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:418
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:368
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:392
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:408
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:357
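A sketch chaining a few of the LinalgPromotionOptions setters above, then checking promoteSubviewsPrecondition and running promoteSubViews (both listed earlier). The promoted operand indices and the alignment value are illustrative; `using namespace mlir;` is assumed.

// Promote the first two operands of a linalg-on-buffers op into local buffers.
static FailureOr<linalg::LinalgOp> promoteFirstTwoOperands(OpBuilder &b,
                                                           linalg::LinalgOp op) {
  linalg::LinalgPromotionOptions options;
  options.setOperandsToPromote({0, 1})
      .setAlignment(16)
      .setUseFullTileBuffersByDefault(true);
  if (failed(linalg::promoteSubviewsPrecondition(op, options)))
    return failure();
  b.setInsertionPoint(op);
  return linalg::promoteSubViews(b, op, options);
}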
Split Reduction options.
Definition: Transforms.h:427
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.