1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 
38 #include "mlir/IR/PatternMatch.h"
39 #include "mlir/IR/TypeUtilities.h"
41 #include "mlir/Support/LLVM.h"
42 #include "mlir/Support/TypeID.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/ScopeExit.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/Debug.h"
48 #include <type_traits>
49 
50 using namespace mlir;
51 using namespace mlir::linalg;
52 using namespace mlir::transform;
53 
54 #define DEBUG_TYPE "linalg-transforms"
55 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
56 #define DBGSNL() (llvm::dbgs() << "\n")
57 #define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
58 
59 /// Attempts to apply the pattern specified as template argument to the given
60 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
61 /// function that returns the "main" result or failure. Returns failure if the
62 /// pattern failed to apply. Extra arguments are forwarded to the pattern
63 /// constructor.
64 template <typename PatternTy, typename... Args>
65 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
66  // Check if the given operation has the type expected by the pattern.
67  using OpTy = typename llvm::function_traits<
68  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
69  auto op = dyn_cast<OpTy>(operation);
70  if (!op)
71  return failure();
72 
73  // Apply the pattern directly to the op.
74  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
75  // We want to discourage direct use of PatternRewriter in APIs, but in this
76  // very specific case, an IRRewriter is not enough.
77  struct TrivialPatternRewriter : public PatternRewriter {
78  public:
79  explicit TrivialPatternRewriter(MLIRContext *context)
80  : PatternRewriter(context) {}
81  };
82  TrivialPatternRewriter rewriter(operation->getContext());
83  rewriter.setInsertionPoint(operation);
84  auto result = pattern.returningMatchAndRewrite(op, rewriter);
85  if (failed(result))
86  return failure();
87  return cast<LinalgOp>(result->getOperation());
88 }
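// Illustrative note (not part of the original listing): the DecomposeOp below
// instantiates this helper as, e.g.,
//   tryApply<DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
//                                                  Conv1DNwcWcfOp>>(target);
// which only succeeds when `target` has the op type expected by the pattern
// and the pattern's returningMatchAndRewrite applies.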
89 
90 /// Assuming that `ofr` is an index attr or a param of index type
91 /// or a transform dialect handle mapped to exactly one op
92 /// with one index result, return that value.
93 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
94  transform::TransformState &state, TransformOpInterface transformOp,
95  SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
96  for (OpFoldResult ofr : ofrs) {
97  if (ofr.is<Attribute>()) {
98  if (!isa<IntegerAttr>(ofr.get<Attribute>()))
99  return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
100  result.push_back(ofr);
101  continue;
102  }
103 
104  Value transformValue = ofr.get<Value>();
105  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
106  ArrayRef<Attribute> params = state.getParams(transformValue);
107  if (params.size() != 1)
108  return transformOp.emitDefiniteFailure()
109  << "requires exactly one parameter associated";
110  result.push_back(params[0]);
111  continue;
112  }
113 
114  auto payloadOps = state.getPayloadOps(transformValue);
115  if (!llvm::hasSingleElement(payloadOps)) {
116  DiagnosedSilenceableFailure diag =
117  transformOp.emitSilenceableError()
118  << "handle must be mapped to exactly one payload op";
119  diag.attachNote(transformValue.getLoc())
120  << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
121  return diag;
122  }
123 
124  Operation *op = *payloadOps.begin();
125  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
126  DiagnosedSilenceableFailure diag =
127  transformOp.emitSilenceableError()
128  << "payload op must have exactly 1 index result";
129  diag.attachNote(op->getLoc())
130  << "has " << op->getNumResults() << " results";
131  return diag;
132  }
133  result.push_back(op->getResult(0));
134  }
135 
136  return DiagnosedSilenceableFailure::success();
137 }
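// For illustration: each entry of `ofrs` may be (a) an IntegerAttr such as
// `rewriter.getIndexAttr(4)`, (b) a transform param value carrying exactly one
// integer attribute, or (c) a handle mapped to a single payload op with one
// `index` result (e.g. an `arith.constant 4 : index`); in all three cases the
// entry is unwrapped into `result` as an OpFoldResult.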
138 
139 // Given a list of params that are index attrs or a list of OpFoldResults
140 // that are either index attrs or op handles, return a list of OpFoldResults
141 // of index attrs or a list of OpFoldResults where all op handles are
142 // replaced with the first (and only) OpResult of that payload op.
143 // (There must be exactly one parameter associated with the AnyParamType or
144 // one mapped payload op which must have exactly one index result.)
145 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
146  transform::TransformState &state, TransformOpInterface transformOp,
147  SmallVector<OpFoldResult> &result, Value packedHandle) {
148  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
149  ArrayRef<Attribute> params = state.getParams(packedHandle);
150  for (auto param : params) {
151  if (!isa<IntegerAttr>(param))
152  return transformOp.emitDefiniteFailure()
153  << "expected the parameter to be associated with an integer "
154  "attribute";
155  result.push_back(param);
156  }
157  return DiagnosedSilenceableFailure::success();
158  }
159 
160  for (Operation *op : state.getPayloadOps(packedHandle)) {
161  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
162  DiagnosedSilenceableFailure diag =
163  transformOp.emitSilenceableError()
164  << "payload op must have exactly 1 index result";
165  diag.attachNote(op->getLoc())
166  << "has " << op->getNumResults() << " results";
167  return diag;
168  }
169  result.push_back(op->getResult(0));
170  }
171 
172  return DiagnosedSilenceableFailure::success();
173 }
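// Difference from the overload above (sketch): here a single packed value is
// expanded, e.g. one param carrying the attributes [2, 4, 8] or one handle
// mapped to several index-producing payload ops yields one OpFoldResult per
// attribute or per payload op.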
174 
175 /// When possible, converts each `OpFoldResult` in `mixedResult` to
176 /// an integer if the value can be statically inferred. If a result
177 /// is a `Value` then it must be either a `ParamType` or a handle
178 /// to a constant-like op.
179 static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
180  TransformState &state, TransformOpInterface &transformOp,
181  ArrayRef<OpFoldResult> mixedResults, SmallVectorImpl<int64_t> &reified) {
182  for (OpFoldResult paramOrHandle : mixedResults) {
183  if (isa<Attribute>(paramOrHandle)) {
184  reified.push_back(
185  cast<IntegerAttr>(paramOrHandle.get<Attribute>()).getInt());
186  continue;
187  } else if (isa<ParamType>(paramOrHandle.get<Value>().getType())) {
188  ArrayRef<Attribute> params = state.getParams(paramOrHandle.get<Value>());
189  if (params.size() != 1)
190  return transformOp.emitSilenceableError() << "expected a single param";
191  reified.push_back(
192  cast<IntegerAttr>(params.front()).getValue().getSExtValue());
193  continue;
194  }
195 
196  Value handle = paramOrHandle.get<Value>();
197  if (!isa<TransformHandleTypeInterface>(handle.getType()))
198  return transformOp.emitSilenceableError() << "unexpected value handle";
199  auto payload = state.getPayloadOps(handle);
200  if (!llvm::hasSingleElement(payload))
201  return transformOp.emitSilenceableError()
202  << "requires param or handle that is mapped to 1 payload op";
203 
204  Operation *paramOrHandlePayloadOp = *payload.begin();
205  if (paramOrHandlePayloadOp->getNumResults() != 1 ||
206  !paramOrHandlePayloadOp->getResult(0).getType().isIndex()) {
207  return transformOp.emitSilenceableError()
208  << "requires param or handle to be result of op with 1 index "
209  "result";
210  }
211 
212  IntegerAttr attr;
213  if (!matchPattern(paramOrHandlePayloadOp->getResult(0), m_Constant(&attr)))
214  return transformOp.emitSilenceableError()
215  << "requires param or handle to be the result of a constant like "
216  "op";
217 
218  reified.push_back(attr.getInt());
219  }
220  return DiagnosedSilenceableFailure::success();
221 }
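// Worked example (assumed payload): a handle mapped to a single
// `arith.constant 42 : index` reifies to the integer 42; a param carrying the
// attribute `8 : i64` reifies to 8; a handle whose payload result is not
// constant-like produces the silenceable failure emitted above.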
222 
223 //===----------------------------------------------------------------------===//
224 // Apply...PatternsOp
225 //===----------------------------------------------------------------------===//
226 
227 void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns(
228  RewritePatternSet &patterns) {
229  linalg::populateEraseUnnecessaryInputsPatterns(patterns);
230 }
231 
232 void transform::ApplyDecomposeTensorPackUnpackPatternsOp::populatePatterns(
233  RewritePatternSet &patterns) {
234  linalg::populateDecomposePackUnpackPatterns(patterns);
235 }
236 
237 void transform::ApplyDecomposeTensorPadPatternsOp::populatePatterns(
238  RewritePatternSet &patterns) {
239  linalg::populateDecomposePadPatterns(patterns);
240 }
241 
242 void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns(
243  RewritePatternSet &patterns) {
244  linalg::ControlDropUnitDims options;
245  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
246 }
247 
248 void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns(
249  RewritePatternSet &patterns) {
250  linalg::ControlDropUnitDims options;
251  options.rankReductionStrategy =
252  linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
253  linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
254 }
255 
256 void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
257  RewritePatternSet &patterns) {
258  linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
259 }
260 
261 void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
262  RewritePatternSet &patterns) {
263  linalg::populateFoldAddIntoDestPatterns(patterns);
264 }
265 
266 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
267  RewritePatternSet &patterns) {
268  linalg::populatePadOpVectorizationPatterns(patterns);
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // BufferizeToAllocationOp
274 //===----------------------------------------------------------------------===//
275 
276 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
277  OperationState &result,
278  Value target,
279  Attribute memorySpace) {
280  SmallVector<Type> resultTypes;
281  resultTypes.push_back(b.getType<transform::AnyValueType>());
282  resultTypes.push_back(b.getType<transform::AnyOpType>());
283  return build(b, result,
284  /*resultTypes=*/resultTypes,
285  /*target=*/target,
286  /*memorySpace=*/memorySpace);
287 }
288 
289 void transform::BufferizeToAllocationOp::build(OpBuilder &b,
290  OperationState &result,
291  Value target,
292  int64_t memorySpace) {
293  SmallVector<Type> resultTypes;
294  resultTypes.push_back(b.getType<transform::AnyValueType>());
295  resultTypes.push_back(b.getType<transform::AnyOpType>());
296  return build(b, result,
297  /*resultTypes=*/resultTypes,
298  /*target=*/target,
299  /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
300 }
301 
302 namespace {
303 class NewOpsListener : public RewriterBase::ForwardingListener {
304 public:
305  using RewriterBase::ForwardingListener::ForwardingListener;
306 
307  SmallVector<Operation *> getNewOps() const {
308  return SmallVector<Operation *>(newOps.begin(), newOps.end());
309  }
310 
311 private:
312  void notifyOperationInserted(Operation *op,
313  OpBuilder::InsertPoint previous) override {
314  ForwardingListener::notifyOperationInserted(op, previous);
315  // We only care about newly created ops.
316  if (previous.isSet())
317  return;
318  auto inserted = newOps.insert(op);
319  (void)inserted;
320  assert(inserted.second && "expected newly created op");
321  }
322 
323  void notifyOperationErased(Operation *op) override {
324  ForwardingListener::notifyOperationErased(op);
325  op->walk([&](Operation *op) { newOps.erase(op); });
326  }
327 
328  DenseSet<Operation *> newOps;
329 };
330 } // namespace
331 
332 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
333  transform::TransformRewriter &rewriter,
334  transform::TransformResults &results, transform::TransformState &state) {
335  // Attach listener to keep track of newly created ops.
336  OpBuilder::Listener *previousListener = rewriter.getListener();
337  auto resetListener =
338  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
339  NewOpsListener newOpsListener(previousListener);
340  rewriter.setListener(&newOpsListener);
341 
342  linalg::BufferizeToAllocationOptions options;
343  if (getMemcpyOp() == "bufferization.materialize_in_destination") {
344  options.memcpyOp = linalg::BufferizeToAllocationOptions::MemcpyOp::
345  MaterializeInDestination;
346  } else if (getMemcpyOp() == "memref.copy") {
347  options.memcpyOp =
348  linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy;
349  } else if (getMemcpyOp() == "linalg.copy") {
350  options.memcpyOp =
351  linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy;
352  } else {
353  llvm_unreachable("invalid memcpy op");
354  }
355  if (getAllocOp() == "memref.alloc") {
356  options.allocOp =
357  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc;
358  } else if (getAllocOp() == "memref.alloca") {
359  options.allocOp =
360  linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
361  } else {
362  llvm_unreachable("invalid alloc op");
363  }
364  options.bufferizeDestinationOnly = getBufferizeDestinationOnly();
365  options.emitDealloc = getEmitDealloc();
366 
367  // Bufferize ops.
368  Attribute memorySpace =
369  getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
370  SmallVector<Value> allocatedBuffers;
371  for (Operation *op : state.getPayloadOps(getTarget())) {
372  Value buffer =
373  linalg::bufferizeToAllocation(rewriter, options, op, memorySpace);
374  if (!buffer) {
375  DiagnosedSilenceableFailure diag = emitSilenceableError()
376  << "failed to bufferize operation";
377  diag.attachNote(op->getLoc()) << "target payload op";
378  return diag;
379  }
380  allocatedBuffers.push_back(buffer);
381  }
382 
383  // Set results.
384  results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
385  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
386  return DiagnosedSilenceableFailure::success();
387 }
388 
389 void transform::BufferizeToAllocationOp::getEffects(
390  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
391  if (getBufferizeDestinationOnly()) {
392  // The destination is replaced with a newly allocated buffer, but the op
393  // itself remains in place.
394  onlyReadsHandle(getTargetMutable(), effects);
395  } else {
396  consumesHandle(getTargetMutable(), effects);
397  }
398  producesHandle(getOperation()->getOpResults(), effects);
399  modifiesPayload(effects);
400 }
401 
402 LogicalResult transform::BufferizeToAllocationOp::verify() {
403  if (getMemcpyOp() != "bufferization.materialize_in_destination" &&
404  getMemcpyOp() != "memref.copy" && getMemcpyOp() != "linalg.copy")
405  return emitOpError() << "unsupported memcpy op";
406  if (getAllocOp() != "memref.alloc" && getAllocOp() != "memref.alloca")
407  return emitOpError() << "unsupported alloc op";
408  return success();
409 }
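// Transform IR sketch (syntax abridged; attribute names follow the accessors
// used above):
//   %buf, %new = transform.structured.bufferize_to_allocation %target
//       {memory_space = 3, memcpy_op = "linalg.copy", alloc_op = "memref.alloc"}
// which returns the allocated buffer value and a handle to the newly created
// ops recorded by the listener above.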
410 
411 //===----------------------------------------------------------------------===//
412 // DecomposeOp
413 //===----------------------------------------------------------------------===//
414 
415 DiagnosedSilenceableFailure
416 transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter,
417  LinalgOp target,
418  transform::ApplyToEachResultList &results,
419  transform::TransformState &state) {
420 #define DOWNSCALE(trans) \
421  { \
422  FailureOr<LinalgOp> res = tryApply<trans>(target); \
423  if (succeeded(res)) { \
424  results.push_back(*res); \
425  return DiagnosedSilenceableFailure::success(); \
426  } \
427  }
428 
429 #define DOWNSCALE_CALL(a, b) DownscaleSizeOneWindowed2DConvolution<a, b>
430 #define DOWNSCALE_NORMAL(a, b) DOWNSCALE(DOWNSCALE_CALL(a, b))
431 
432  DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp)
433  DOWNSCALE_NORMAL(Conv2DNchwFchwOp, Conv1DNcwFcwOp)
434  DOWNSCALE_NORMAL(PoolingNhwcSumOp, PoolingNwcSumOp)
435  DOWNSCALE_NORMAL(PoolingNchwSumOp, PoolingNcwSumOp)
436  DOWNSCALE_NORMAL(PoolingNhwcMaxOp, PoolingNwcMaxOp)
437  DOWNSCALE_NORMAL(PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp)
438  DOWNSCALE_NORMAL(PoolingNhwcMinOp, PoolingNwcMinOp)
439  DOWNSCALE_NORMAL(PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp)
440  DOWNSCALE_NORMAL(PoolingNchwMaxOp, PoolingNcwMaxOp)
441  DOWNSCALE(DownscaleDepthwiseConv2DNhwcHwcOp)
442  DOWNSCALE(DownscaleConv2DOp)
443 #undef DOWNSCALE_NORMAL
444 #undef DOWNSCALE_CALL
445 #undef DOWNSCALE
446  return emitDefaultSilenceableFailure(target);
447 }
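// For reference, DOWNSCALE_NORMAL(Conv2DNhwcHwcfOp, Conv1DNwcWcfOp) above
// expands to a tryApply of
// DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp, Conv1DNwcWcfOp>,
// i.e. a 2D convolution/pooling with a unit-size window dimension is rewritten
// to its 1D counterpart; the first pattern that applies wins.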
448 
449 //===----------------------------------------------------------------------===//
450 // DecomposeInterfaceOp
451 //===----------------------------------------------------------------------===//
452 
453 // Decompose the target operation if it implements the AggregatedOpInterface.
454 // Push the decomposed operations (the ones that replace the values produced by
455 // \p target) in the `results`.
456 DiagnosedSilenceableFailure transform::DecomposeInterfaceOp::applyToOne(
457  transform::TransformRewriter &rewriter, Operation *target,
458  transform::ApplyToEachResultList &results,
459  transform::TransformState &state) {
460  auto decomposableOp = dyn_cast<AggregatedOpInterface>(target);
461  if (!decomposableOp) {
462  failed(rewriter.notifyMatchFailure(target,
463  "payload is not a decomposable op"));
464  return emitDefaultSilenceableFailure(target);
465  }
466 
467  FailureOr<SmallVector<Value>> maybeNewResults =
468  decomposableOp.decomposeOperation(rewriter);
469  if (failed(maybeNewResults))
470  return emitDefaultSilenceableFailure(target);
471 
472  rewriter.replaceOp(decomposableOp, *maybeNewResults);
473  for (Value val : *maybeNewResults) {
474  Operation *definition = val.getDefiningOp();
475  if (definition)
476  results.push_back(definition);
477  }
478  return DiagnosedSilenceableFailure::success();
479 }
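// Example payload (for illustration): `linalg.softmax` implements
// AggregatedOpInterface, so decomposeOperation rewrites it into a sequence of
// simpler linalg ops whose defining ops are the values returned through
// `results` above.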
480 
481 //===----------------------------------------------------------------------===//
482 // EliminateLinalgOpAnchoredEmptyTensorsOp
483 //===----------------------------------------------------------------------===//
484 
485 void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects(
486  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
487  onlyReadsHandle(getTargetMutable(), effects);
488  modifiesPayload(effects);
489 }
490 
491 DiagnosedSilenceableFailure
492 transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply(
493  transform::TransformRewriter &rewriter, TransformResults &transformResults,
494  TransformState &state) {
495  bufferization::OneShotBufferizationOptions options;
496  options.allowReturnAllocsFromLoops = true;
497 
498  for (Operation *target : state.getPayloadOps(getTarget())) {
499  bufferization::OneShotAnalysisState state(target, options);
500  if (failed(analyzeOp(target, state)))
501  return mlir::emitSilenceableFailure(target->getLoc())
502  << "failed to analyze op";
503  if (failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep(
504  rewriter, target, state)))
505  return mlir::emitSilenceableFailure(target->getLoc())
506  << "failed to eliminate LinalgOp anchored tensor.empty ops";
507  }
508  return DiagnosedSilenceableFailure::success();
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // FuseOp
513 //===----------------------------------------------------------------------===//
514 
515 /// Apply a tiling transformation to all payload ops and store both the
516 /// tiled operation as well as the created tile loops.
517 template <typename Range>
518 static LogicalResult applyTilingToAll(
519  RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps,
520  unsigned numLoops, transform::TransformResults &transformResults,
521  function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
522  applyFn) {
523  SmallVector<Operation *> tiledLinalgOps;
524  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
525 
526  for (Operation *target : payloadOps) {
527  auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
528  if (!tilingInterfaceOp)
529  return transformOp->emitError("only TilingInterface ops are supported");
530 
531  rewriter.setInsertionPoint(target);
532  FailureOr<scf::SCFTileAndFuseResult> tiledResults =
533  applyFn(tilingInterfaceOp);
534  if (failed(tiledResults))
535  return failure();
536 
537  // Perform the replacement of tiled and fused values.
538  SmallVector<Operation *> opsToReplace{target};
539  llvm::append_range(opsToReplace, tiledResults->fusedProducers);
540  for (Operation *toReplace : opsToReplace) {
541  for (OpResult res : toReplace->getResults())
542  if (auto replacement = tiledResults->replacements.lookup(res))
543  rewriter.replaceAllUsesWith(res, replacement);
544  if (toReplace->use_empty()) {
545  rewriter.eraseOp(toReplace);
546  }
547  }
548 
549  // Report back the relevant handles to the transform op.
550  tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
551  assert(tiledResults->loops.size() == numLoops &&
552  "Mismatched number of loops, tile and fuse transform should have "
553  "failed");
554  for (unsigned int i = 0; i < numLoops; ++i)
555  loopOps[i].push_back(tiledResults->loops[i]);
556  }
557 
558  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
559  for (unsigned int i = 0; i < numLoops; ++i)
560  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
561 
562  return success();
563 }
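// Result layout produced by this helper (as set above): result #0 of the
// transform op maps to the tiled-and-fused payload ops, and results #1 through
// #numLoops each map to one of the generated loops.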
564 
565 DiagnosedSilenceableFailure
566 transform::FuseOp::apply(transform::TransformRewriter &rewriter,
567  mlir::transform::TransformResults &transformResults,
568  mlir::transform::TransformState &state) {
569  SmallVector<int64_t> tileSizes =
570  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
571  SmallVector<int64_t> tileInterchange =
572  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
573 
574  scf::SCFTilingOptions tilingOptions;
575  tilingOptions.interchangeVector = tileInterchange;
576  SmallVector<OpFoldResult> tileSizesOfr =
577  getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
578  tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
579  scf::SCFTileAndFuseOptions tileAndFuseOptions;
580  tileAndFuseOptions.tilingOptions = tilingOptions;
581 
582  if (getApplyCleanup()) {
583  MLIRContext *context = rewriter.getContext();
584  RewritePatternSet patterns(context);
585  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
587  tileAndFuseOptions.cleanupPatterns = std::move(patterns);
588  }
589 
590  LogicalResult result = applyTilingToAll(
591  rewriter, getOperation(), state.getPayloadOps(getTarget()),
592  tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
593  [&](TilingInterface tilingInterfaceOp)
594  -> FailureOr<scf::SCFTileAndFuseResult> {
595  return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
596  tileAndFuseOptions);
597  });
598  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
599  : DiagnosedSilenceableFailure::success();
600 }
601 
602 LogicalResult transform::FuseOp::verify() {
603  SmallVector<int64_t> permutation =
604  extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
605  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
606  if (!std::is_permutation(sequence.begin(), sequence.end(),
607  permutation.begin(), permutation.end())) {
608  return emitOpError() << "expects interchange to be a permutation, found "
609  << getTileInterchange();
610  }
611 
612  SmallVector<int64_t> sizes =
613  extractFromIntegerArrayAttr<int64_t>(getTileSizes());
614  size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
615  if (numExpectedLoops != getNumResults() - 1)
616  return emitOpError() << "expects " << numExpectedLoops << " loop results";
617 
618  return success();
619 }
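// Worked example for the check above: tile_sizes = [8, 0, 16] requests tiling
// of two dimensions (a zero entry means "do not tile"), so the op must declare
// 1 (fused op) + 2 (loops) = 3 results.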
620 
621 //===----------------------------------------------------------------------===//
622 // FuseIntoContainingOp
623 //===----------------------------------------------------------------------===//
624 
625 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
626  OperationState &result,
627  Value producerOp,
628  Value containingOp) {
629  result.addOperands({producerOp, containingOp});
630  auto resultType = transform::AnyOpType::get(builder.getContext());
631  result.addTypes({resultType, resultType});
632 }
633 
634 /// Add new operands to the forall op for users of the producerOp
635 /// that are dominated by the containing scf.forall op.
636 static Operation *replaceForAllWithNewSignature(
637  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
638  Operation *containingOp, TilingResult &tileAndFuseResult,
639  int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
640  SmallVector<OpFoldResult> &sizes) {
641 
642  // Count number of users not including the containing op
643  SetVector<Operation *> dominatedUsers;
644  DominanceInfo domInfo(containingOp);
645  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
646  if (!containingOp->isAncestor(user) &&
647  (domInfo.dominates(containingOp, user))) {
648  dominatedUsers.insert(user);
649  }
650  }
651  if (dominatedUsers.empty())
652  return nullptr;
653 
654  // Create new scf.forall op
655  auto forallOp = cast<scf::ForallOp>(containingOp);
656  OpBuilder::InsertionGuard g(rewriter);
657  rewriter.setInsertionPoint(forallOp);
658 
659  // Get new output
660  Location loc = forallOp.getLoc();
661  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
662  if (!genericOp)
663  return nullptr;
664  SmallVector<Value> outputs = genericOp.getOutputs();
665  SmallVector<Value> newOuts(forallOp.getOutputs());
666  newOuts.push_back(outputs[resultNumber]);
667 
668  // Create new scf.forall op
669  auto newforallOp = rewriter.create<scf::ForallOp>(
670  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
671  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
672  rewriter.eraseBlock(newforallOp.getBody());
673  newforallOp.getRegion().takeBody(forallOp.getRegion());
674 
675  // Add additional block argument for new value being returned
676  // and replaces all uses of the new output with corresponding bbArg
677  // inside the scf.forall to enable fusion into this new scf.forall.
678  newforallOp.getBody()->addArgument(newOuts.back().getType(),
679  newOuts.back().getLoc());
680  auto bbArgs = newforallOp.getBody()->getArguments();
681  rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
682  [&](OpOperand &use) {
683  Operation *op = use.getOwner();
684  return newforallOp->isProperAncestor(op);
685  });
686 
687  // Fix terminator
688  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
689  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
690  terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
691  Operation *firstYieldOp = yieldingOps.front();
692  rewriter.setInsertionPoint(firstYieldOp);
693  Value src = tileAndFuseResult.tiledValues[0];
694  Value dst = newforallOp.getRegionIterArgs().back();
695  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
696  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
697  dst, offsets, sizes, strides);
698 
699  for (auto result : llvm::enumerate(forallOp.getResults())) {
700  rewriter.replaceAllUsesWith(result.value(),
701  newforallOp->getResult(result.index()));
702  }
703  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
704  newforallOp->getResults().back(),
705  [&](OpOperand &use) {
706  Operation *user = use.getOwner();
707  return dominatedUsers.contains(user);
708  });
709  return newforallOp;
710 }
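// Sketch of the rewrite performed above: the original scf.forall is recreated
// with one extra shared output (the producer's init), a matching extra block
// argument, and an additional tensor.parallel_insert_slice of the tiled value
// in its terminator; uses of the producer result that are dominated by the
// forall are then redirected to the new, last forall result.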
711 
712 /// Find the first "extract" user of `producerOp` and tile it right before its
713 /// use. The tiled op is fused under the `containingOp`.
714 /// Return this fused op on success or nullptr if anything fails.
715 /// If tiled op has uses that are dominated by `containingOp`, return
716 /// a new `containingOp` with results of the fused op appended to
717 /// results of the `containingOp` or nullptr if there are no dominated uses.
718 static std::tuple<SmallVector<Operation *>, Operation *>
719 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
720  Operation *producerOp, Operation *containingOp) {
721  LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
722  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
723  if (!tileableProducer) {
724  diag.attachNote(producerOp->getLoc())
725  << "producer is not a TileableInterface: " << *producerOp;
726  return {};
727  }
728 
729  // Search the producer slices accessed within the containing operation.
730  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
731  // evolve into an interface.
732  auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
733  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
734  return sliceOp && containingOp->isProperAncestor(sliceOp);
735  });
736 
737  // Find a fusion opportunity.
738  if (it == tileableProducer->getUsers().end()) {
739  diag.attachNote(tileableProducer->getLoc())
740  << "could not find fusion opportunity for: " << *tileableProducer;
741  return {};
742  }
743  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
744 
745  // Try to fuse the producer in-place.
746  OpBuilder::InsertionGuard guard(rewriter);
747  rewriter.setInsertionPoint(sliceOpToTile);
748 
749  // Tile the producer.
750  int64_t resultNumber =
751  cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
752  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
753 
754  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
755  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
756 
757  FailureOr<TilingResult> tileAndFuseResult =
758  tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
759  sizes);
760 
761  if (failed(tileAndFuseResult)) {
762  diag.attachNote(tileableProducer->getLoc())
763  << "failed to tile producer op: " << *tileableProducer;
764  return {};
765  }
766 
767 #ifndef NDEBUG
768  for (auto *tiledOp : tileAndFuseResult->tiledOps) {
769  LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n");
770  }
771 #endif
772 
773  // Replace the extract op.
774  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
775  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
776  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
777  if (failed(maybeRankReduced)) {
778  diag.attachNote(producerOp->getLoc())
779  << "shape types don't match (missing canonicalization?):\nTiledOp: "
780  << tileAndFuseResult->tiledValues[0]
781  << "\nSliceOp: " << sliceOpToTile.getOperation() << '\n';
782  return {};
783  }
784  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
785 
786  // Add new outputs to containing op, if required
787  Operation *newContainingOp = replaceForAllWithNewSignature(
788  rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
789  resultNumber, offsets, sizes);
790 
791  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
792 }
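// Typical IR shape this handles (illustrative): a producer op defined outside
// an scf.forall whose result feeds a tensor.extract_slice inside the forall;
// the producer is re-tiled at the slice's offsets/sizes and the extract_slice
// is replaced by the tiled value.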
793 
794 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
795 /// it is exactly the `containingOp`, otherwise bail.
796 /// Then, find the first "extract" user of the tied block argument and tile it
797 /// right before its "extract" use. The tiled op is fused under the
798 /// `containingOp`.
799 /// Return this fused op on success or nullptr if anything fails.
800 static SmallVector<Operation *>
801 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
802  RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
803  Operation *containingOp) {
804  LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
805 
806  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
807  if (!tileableProducer) {
808  diag.attachNote(producerOp->getLoc())
809  << "producer is not a TileableInterface: " << *producerOp;
810  return {};
811  }
812 
813  // Search the first use by a "scf::ForallOp" user.
814  scf::ForallOp forallOp;
815  auto itProducerUses =
816  llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
817  forallOp = dyn_cast<scf::ForallOp>(use.getOwner());
818  return forallOp;
819  });
820  // If it's not from the containing op, return.
821  if (!forallOp || forallOp != containingOp) {
822  diag.attachNote(tileableProducer->getLoc())
823  << "could not find a use by the containing op: " << *tileableProducer;
824  return {};
825  }
826 
827  // Search the producer slices accessed within the containing
828  // operation.
829  // TODO: Generalize to more extract/insert/parallel_insert triples.
830  // Maybe evolve into an interface.
831  OpOperand *pUse = &(*itProducerUses);
832  BlockArgument bbArg = forallOp.getTiedBlockArgument(pUse);
833 
834  // Search the producer slices accessed within the containing operation.
835  // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
836  // evolve into an interface.
837  auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
838  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
839  return sliceOp && containingOp->isProperAncestor(sliceOp);
840  });
841 
842  // Find a fusion opportunity.
843  if (itBBArgUsers == bbArg.getUsers().end()) {
844  diag.attachNote(containingOp->getLoc())
845  << "could not find fusion opportunity for bbArg: " << bbArg;
846  return {};
847  }
848  auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
849 
850  // Try to fuse the producer in-place.
851  OpBuilder::InsertionGuard guard(rewriter);
852  rewriter.setInsertionPoint(sliceOpToTile);
853 
854  // Replace the use in the tileableProducer before tiling: clone, replace and
855  // then tile.
856  int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
857  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
858 
859  // Gather destination tensors.
860  SmallVector<Value> destinationTensors;
861  if (failed(tensor::getOrCreateDestinations(
862  rewriter, tileableProducer->getLoc(), tileableProducer,
863  destinationTensors))) {
864  diag.attachNote(tileableProducer->getLoc())
865  << "failed to get destination tensors for: " << *tileableProducer;
866  return {};
867  }
868 
869  IRMapping bvm;
870  bvm.map(destinationTensors[resultNumber], bbArg);
871  auto tileableProducerClone =
872  cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
873  auto scopeGuard =
874  llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
875 
876  // Tile the producer.
877  FailureOr<TilingResult> tileAndFuseResult =
878  tileableProducerClone.generateResultTileValue(
879  rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
880  sliceOpToTile.getMixedSizes());
881  if (failed(tileAndFuseResult)) {
882  diag.attachNote(tileableProducer->getLoc())
883  << "failed to tile producer op: " << *tileableProducer;
884  return {};
885  }
886 
887  // Replace the extract op.
888  auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
889  rewriter, sliceOpToTile->getLoc(), tileAndFuseResult->tiledValues[0],
890  cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
891  assert(succeeded(maybeRankReduced) && "unexpected shape");
892  rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
893 
894  // Replace the use in containingOp.
895  rewriter.modifyOpInPlace(containingOp, [&]() {
896  containingOp->setOperand(pUse->getOperandNumber(),
897  destinationTensors.front());
898  });
899 
900  return tileAndFuseResult->tiledOps;
901 }
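// Illustrative difference from tileAndFuseFirstExtractUse: here the producer
// feeds an scf.forall operand (shared output), so the slice to tile is found
// through the tied block argument, and the forall operand is redirected to the
// producer's destination tensor once the cloned producer has been tiled.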
902 
903 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
904  Operation *producerOp,
905  Operation *containingOp) {
906  LLVM_DEBUG(DBGS() << "Try to fuse a use by cloning\n");
907 
908  // Gather all uses inside the containing op.
909  SmallVector<OpOperand *> uses;
910  for (OpResult result : producerOp->getOpResults()) {
911  for (OpOperand &use : result.getUses()) {
912  if (containingOp->isProperAncestor(use.getOwner())) {
913  uses.push_back(&use);
914  continue;
915  }
916  // Cannot clone and fuse if the use is by the containing op itself: fail
917  // immediately.
918  if (containingOp == use.getOwner()) {
919  diag.attachNote(producerOp->getLoc())
920  << "producer op use by containing op cannot be fused by cloning";
921  return nullptr;
922  }
923  }
924  }
925 
926  // Check for a non-empty list of fusion opportunities.
927  if (uses.empty()) {
928  diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
929  return nullptr;
930  }
931 
932  // Clone and fuse inside the containing op.
933  Operation *fusedOp = nullptr;
934  OpOperand *use = uses.front();
935  // Parallel insert slice is not a valid clone destination.
936  // TODO: Generalize to other type of ops.
937  assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
938  "Parallel insert slice is not a valid clone destination");
939  unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber();
940  LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
941 
942  OpBuilder::InsertionGuard guard(rewriter);
943  rewriter.setInsertionPoint(use->getOwner());
944  fusedOp = rewriter.clone(*producerOp);
945  rewriter.modifyOpInPlace(
946  use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
947 
948  return fusedOp;
949 }
950 
951 bool transform::FuseIntoContainingOp::allowsRepeatedHandleOperands() {
952  // Allow repeated handles since we are fusing everything anyway.
953  return true;
954 }
955 
956 DiagnosedSilenceableFailure
957 transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
958  transform::TransformResults &results,
959  transform::TransformState &state) {
960  SmallVector<Operation *> fusedOps;
961  auto producerOps = state.getPayloadOps(getProducerOp());
962  auto containingOps = state.getPayloadOps(getContainingOp());
963  if (!llvm::hasSingleElement(containingOps)) {
964  return emitDefiniteFailure()
965  << "requires exactly one containing_op handle (got "
966  << llvm::range_size(containingOps) << ")";
967  }
968  Operation *containingOp = *containingOps.begin();
969 
970  // If nothing to fuse, propagate success.
971  if (std::empty(producerOps)) {
972  results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
973  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
974  return DiagnosedSilenceableFailure::success();
975  }
976 
977  // Helper function to find the next producer that should be fused. Take any
978  // producer that has a use inside the containing op.
979  SetVector<Operation *> remainingProducers(producerOps.begin(),
980  producerOps.end());
981  auto getNextProducer = [&]() -> FailureOr<Operation *> {
982  for (const auto &it : enumerate(remainingProducers)) {
983  Operation *producerOp = it.value();
984  // The containing op may be a user of producerOp: use isAncestor.
985  int64_t numUsesInContainingOp =
986  llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
987  return containingOp->isAncestor(op);
988  });
989  // TODO: When resolving the TODO below (no duplicate ops), take an op
990  // that has no use among the remaining producers. This is a topological
991  // sorting.
992  if (numUsesInContainingOp > 0) {
993  if (numUsesInContainingOp == 1)
994  remainingProducers.erase(remainingProducers.begin() + it.index());
995  return producerOp;
996  }
997  }
998  return failure();
999  };
1000 
1001  while (!remainingProducers.empty()) {
1002  auto nextProducer = getNextProducer();
1003  if (failed(nextProducer)) {
1004  auto diag = mlir::emitSilenceableFailure(getLoc())
1005  << "could not find next producer to fuse into container";
1006  diag.attachNote(containingOp->getLoc()) << "containing op";
1007  return diag;
1008  }
1009 
1010  Operation *producerOp = *nextProducer;
1011 
1012  // Default diagnostic, to be complemented with more failure information.
1013  Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
1014  diag << "could not fuse " << *producerOp << " into " << *containingOp;
1015 
1016  // TODO: If there are multiple uses of the producer in the containing op,
1017  // we currently tile/clone the op multiple times (once per use). In some
1018  // cases, we can tile/clone once and reuse the value for each use.
1019 // Furthermore, producers should then be traversed according to a
1020  // topological sorting.
1021  auto [tiledOps, newContainingOp] =
1022  tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
1023  if (!tiledOps.empty()) {
1024  LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
1025  fusedOps.append(tiledOps);
1026  if (newContainingOp) {
1027  // Update handles associated with the containing op so we don't need to
1028  // invalidate them. This is a hack to support better composability
1029  // between tiling and fusion while a proper mechanism is being
1030  // investigated.
1031  //
1032  // DO NOT replicate this elsewhere unless you understand what you are
1033  // doing.
1034  LogicalResult replacementStatus =
1035  rewriter.notifyPayloadOperationReplaced(containingOp,
1036  newContainingOp);
1037  (void)replacementStatus;
1038  assert(succeeded(replacementStatus) &&
1039  "unable to update transform state mapping");
1040  rewriter.eraseOp(containingOp);
1041  containingOp = newContainingOp;
1042  }
1043  continue;
1044  }
1045 
1046  SmallVector<Operation *> tiledContainingOpOperand =
1047  tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
1048  rewriter, diag, producerOp, containingOp);
1049  if (!tiledContainingOpOperand.empty()) {
1050  LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n"
1051  << *containingOp);
1052  fusedOps.append(tiledContainingOpOperand);
1053  continue;
1054  }
1055 
1056  Operation *cloned =
1057  cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
1058  if (cloned) {
1059  LLVM_DEBUG(DBGS() << "\nFused a use by cloning\n" << *containingOp);
1060  fusedOps.push_back(cloned);
1061  continue;
1062  }
1063  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1064  }
1065 
1066  results.set(cast<OpResult>(getFusedOp()), fusedOps);
1067  results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
1068  return DiagnosedSilenceableFailure::success();
1069 }
1070 
1071 void transform::FuseIntoContainingOp::getEffects(
1072  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1073  consumesHandle(getProducerOpMutable(), effects);
1074  onlyReadsHandle(getContainingOpMutable(), effects);
1075  producesHandle(getOperation()->getOpResults(), effects);
1076  modifiesPayload(effects);
1077 }
1078 
1079 //===----------------------------------------------------------------------===//
1080 // GeneralizeOp
1081 //===----------------------------------------------------------------------===//
1082 
1083 DiagnosedSilenceableFailure
1084 transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter,
1085  LinalgOp target,
1086  transform::ApplyToEachResultList &results,
1087  transform::TransformState &state) {
1088  // Exit early if no transformation is needed.
1089  if (isa<GenericOp>(target)) {
1090  results.push_back(target);
1091  return DiagnosedSilenceableFailure::success();
1092  }
1093  rewriter.setInsertionPoint(target);
1094  FailureOr<LinalgOp> generic = generalizeNamedOp(rewriter, target);
1095  if (succeeded(generic)) {
1096  results.push_back(generic->getOperation());
1097  return DiagnosedSilenceableFailure::success();
1098  }
1099  return emitDefaultSilenceableFailure(target);
1100 }
1101 
1102 //===----------------------------------------------------------------------===//
1103 // SpecializeOp
1104 //===----------------------------------------------------------------------===//
1105 
1106 DiagnosedSilenceableFailure
1107 transform::SpecializeOp::applyToOne(transform::TransformRewriter &rewriter,
1108  LinalgOp target,
1109  transform::ApplyToEachResultList &results,
1110  transform::TransformState &state) {
1111  // Exit early if the operation is not a generic.
1112  if (!isa<GenericOp>(target)) {
1113  results.push_back(target);
1114  return DiagnosedSilenceableFailure::success();
1115  }
1116  rewriter.setInsertionPoint(target);
1117  FailureOr<LinalgOp> named =
1118  specializeGenericOp(rewriter, cast<GenericOp>(target));
1119  if (succeeded(named)) {
1120  results.push_back(named->getOperation());
1121  return DiagnosedSilenceableFailure::success();
1122  }
1123  return emitDefaultSilenceableFailure(target);
1124 }
1125 
1126 //===----------------------------------------------------------------------===//
1127 // InterchangeOp
1128 //===----------------------------------------------------------------------===//
1129 
1130 DiagnosedSilenceableFailure
1131 transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter,
1132  GenericOp target,
1133  transform::ApplyToEachResultList &results,
1134  transform::TransformState &state) {
1135  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
1136  // Exit early if no transformation is needed.
1137  if (interchangeVector.empty()) {
1138  results.push_back(target);
1139  return DiagnosedSilenceableFailure::success();
1140  }
1141 
1142  unsigned numLoops = cast<LinalgOp>(target.getOperation()).getNumLoops();
1143  if (interchangeVector.size() != numLoops) {
1144  return emitSilenceableError()
1145  << getIteratorInterchangeAttrName() << " has length ("
1146  << interchangeVector.size()
1147  << ") different from the number of loops in the target operation ("
1148  << numLoops << ")";
1149  }
1150  FailureOr<GenericOp> res = interchangeGenericOp(
1151  rewriter, target, SmallVector<unsigned>(interchangeVector));
1152  if (failed(res))
1153  return emitDefiniteFailure() << "failed to apply";
1154  results.push_back(res->getOperation());
1155  return DiagnosedSilenceableFailure::success();
1156 }
1157 
1158 LogicalResult transform::InterchangeOp::verify() {
1159  ArrayRef<int64_t> permutation = getIteratorInterchange();
1160  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
1161  if (!std::is_permutation(sequence.begin(), sequence.end(),
1162  permutation.begin(), permutation.end())) {
1163  return emitOpError()
1164  << "expects iterator_interchange to be a permutation, found "
1165  << getIteratorInterchange();
1166  }
1167  return success();
1168 }
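// Example: on a generic with iterator order (i, j, k), iterator_interchange =
// [1, 0, 2] swaps the first two loops, giving (j, i, k); a value such as
// [0, 2] is rejected by the verifier above because it is not a permutation of
// the loop index sequence.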
1169 
1170 //===----------------------------------------------------------------------===//
1171 // LowerPackOp
1172 //===----------------------------------------------------------------------===//
1173 
1174 DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
1175  transform::TransformRewriter &rewriter, tensor::PackOp target,
1176  transform::ApplyToEachResultList &transformResults,
1177  transform::TransformState &state) {
1178  rewriter.setInsertionPoint(target);
1179  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
1180  if (failed(res)) {
1181  return mlir::emitSilenceableFailure(target->getLoc())
1182  << "cannot lower to pad + expand + transpose";
1183  }
1184  transformResults.push_back(res->padOp);
1185  transformResults.push_back(res->expandShapeOp);
1186  transformResults.push_back(res->transposeOp);
1187  return DiagnosedSilenceableFailure::success();
1188 }
1189 
1190 //===----------------------------------------------------------------------===//
1191 // LowerUnPackOp
1192 //===----------------------------------------------------------------------===//
1193 
1194 DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
1195  transform::TransformRewriter &rewriter, tensor::UnPackOp target,
1196  transform::ApplyToEachResultList &transformResults,
1197  transform::TransformState &state) {
1198  rewriter.setInsertionPoint(target);
1199  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
1200  if (failed(res)) {
1201  DiagnosedSilenceableFailure diag =
1202  emitSilenceableError()
1203  << "cannot lower to transpose + collapse + extract";
1204  diag.attachNote(target->getLoc()) << "target payload op";
1205  return diag;
1206  }
1207  transformResults.push_back(res->emptyOp);
1208  transformResults.push_back(res->transposeOp);
1209  transformResults.push_back(res->collapseShapeOp);
1210  transformResults.push_back(res->extractSliceOp);
1211  return DiagnosedSilenceableFailure::success();
1212 }
1213 
1214 //===---------------------------------------------------------------------===//
1215 // MatchOp
1216 //===---------------------------------------------------------------------===//
1217 
1218 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1219  Value target, ArrayRef<StringRef> opNames) {
1220  result.addOperands(target);
1221  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1222  builder.getStrArrayAttr(opNames));
1223  result.addTypes(transform::AnyOpType::get(builder.getContext()));
1224 }
1225 
1226 void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
1227  TypeRange resultTypes, Value target,
1228  ArrayRef<StringRef> opNames) {
1229  result.addOperands(target);
1230  result.addAttribute(MatchOp::getOpsAttrName(result.name),
1231  builder.getStrArrayAttr(opNames));
1232  result.addTypes(resultTypes);
1233 }
1234 
1235 DiagnosedSilenceableFailure
1236 transform::MatchOp::apply(transform::TransformRewriter &rewriter,
1237  transform::TransformResults &results,
1238  transform::TransformState &state) {
1239  llvm::StringSet<> strs;
1240  if (getOps().has_value())
1241  strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
1242  getOps()->getAsValueRange<StringAttr>().end());
1243 
1244  auto payloadOps = state.getPayloadOps(getTarget());
1245  if (!llvm::hasSingleElement(payloadOps)) {
1246  return emitDefiniteFailure("requires exactly one target handle");
1247  }
1248 
1249  SmallVector<Operation *> res;
1250  bool incorrectNumOperandTypes = false;
1251  auto matchFun = [&](Operation *op) {
1252  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
1253  return;
1254 
1255  // Interfaces cannot be matched by name, just by ID.
1256  // So we specifically encode the interfaces we care about for this op.
1257  if (getInterface().has_value()) {
1258  auto iface = getInterface().value();
1259  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
1260  !isa<LinalgOp>(op))
1261  return;
1262  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
1263  !isa<TilingInterface>(op))
1264  return;
1265  if (iface == transform::MatchInterfaceEnum::LoopLikeInterface &&
1266  !isa<LoopLikeOpInterface>(op))
1267  return;
1268  }
1269 
1270  // Check if all specified attributes match.
1271  if (getOpAttrs().has_value()) {
1272  DictionaryAttr opAttrs = getOpAttrs().value();
1273  for (NamedAttribute attr : opAttrs) {
1274  if (attr.getName() == getInterfaceAttrName() ||
1275  attr.getName() == getOpsAttrName())
1276  continue;
1277  if (!op->hasAttr(attr.getName()))
1278  return;
1279  if (op->getAttr(attr.getName()) != attr.getValue())
1280  return;
1281  }
1282  }
1283 
1284  if (getFilterResultType().has_value()) {
1285  Type t = getFilterResultType().value();
1286  if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
1287  return;
1288  }
1289 
1290  if (getFilterOperandTypes().has_value()) {
1291  mlir::ArrayAttr types = getFilterOperandTypes().value();
1292  auto operandTypes = op->getOperandTypes();
1293 
1294  if (types.size() == 1) {
1295  // All the operands must be equal to the specified type.
1296  auto typeattr =
1297  dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1298  Type t = cast<::mlir::Type>(typeattr.getValue());
1299  if (!llvm::all_of(op->getOperandTypes(),
1300  [&](Type operandType) { return operandType == t; }))
1301  return;
1302  } else {
1303  // The operand types must match all the types in the list (in the same
1304  // order in which they are specified).
1305  if (types.size() != operandTypes.size()) {
1306  incorrectNumOperandTypes = true;
1307  return;
1308  }
1309 
1310  for (auto [attr, operandType] :
1311  llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1312  auto typeattr = cast<mlir::TypeAttr>(attr);
1313  Type type = cast<::mlir::Type>(typeattr.getValue());
1314 
1315  if (type != operandType)
1316  return;
1317  }
1318  }
1319  }
1320 
1321  // All constraints are satisfied.
1322  res.push_back(op);
1323  return;
1324  };
1325 
1326  (*payloadOps.begin())->walk(matchFun);
1327  if (incorrectNumOperandTypes)
1328  return emitDefiniteFailure("If filter_operand_types contains more than "
1329  "one type, then it must contain as many types as "
1330  "the number of operands in the target ops");
1331  results.set(cast<OpResult>(getResult()), res);
1332  return DiagnosedSilenceableFailure::success();
1333 }
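// Transform IR sketch (syntax abridged; filter names follow the accessors
// above):
//   %m = transform.structured.match ops{["linalg.matmul"]} in %module
// walks the single payload op mapped to %module and returns a handle to every
// nested op that satisfies all of the specified filters (op name, interface,
// attributes, result/operand types).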
1334 
1335 //===---------------------------------------------------------------------===//
1336 // MultiTileSizesOp
1337 //===---------------------------------------------------------------------===//
1338 
1339 static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op,
1340  Type targetType, Type lowSizeType, Type,
1341  Type) {
1342  printer.printFunctionalType(TypeRange{targetType}, TypeRange{lowSizeType});
1343 }
1344 
1345 static ParseResult parseMultitileSizesTypes(OpAsmParser &parser,
1346  Type &targetType, Type &lowSizeType,
1347  Type &highSizeType,
1348  Type &splitPointType) {
1349  FunctionType funcType;
1350  llvm::SMLoc typeLoc = parser.getCurrentLocation();
1351  if (failed(parser.parseType<FunctionType>(funcType)))
1352  return failure();
1353 
1354  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
1355  parser.emitError(typeLoc) << "expects a trailing functional type with one "
1356  "argument and one result";
1357  }
1358  targetType = funcType.getInput(0);
1359  lowSizeType = highSizeType = splitPointType = funcType.getResult(0);
1360 
1361  return success();
1362 }
1363 
1364 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
1365  transform::TransformRewriter &rewriter, LinalgOp target,
1366  ApplyToEachResultList &results, TransformState &state) {
1367  if (isa<TransformParamTypeInterface>(getLowSize().getType())) {
1368  if (target.hasDynamicShape()) {
1369  auto diag = emitSilenceableError()
1370  << "cannot compute parametric tile sizes for dynamically "
1371  "shaped payload op";
1372  diag.attachNote(target->getLoc()) << "payload op";
1373  return diag;
1374  }
1375 
1376  FailureOr<StaticMultiSizeSpecification> spec = computeStaticMultiTileSizes(
1377  target, getDimension(), getTargetSize(), getDivisor());
1378  if (failed(spec)) {
1379  return emitSilenceableError()
1380  << "failed to compute multi-size tiling sizes";
1381  }
1382 
1383  Builder builder(target.getContext());
1384  results.assign(llvm::map_range(
1385  ArrayRef<int64_t>({spec->lowTileSize, spec->highTileSize,
1386  spec->lowTileSize * spec->lowTripCount}),
1387  [&builder, this](int64_t value) {
1388  return builder.getIntegerAttr(
1389  cast<ParamType>(getLowSize().getType()).getType(), value);
1390  }));
1391  return DiagnosedSilenceableFailure::success();
1392  }
1393 
1394  OpBuilder builder(target.getContext());
1395  builder.setInsertionPoint(target);
1396  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
1397  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
1398  FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
1399  builder, target, getDimension(), targetSize, divisor);
1400  if (failed(spec)) {
1401  return emitSilenceableError() << "could not generate tile size computation";
1402  }
1403 
1404  AffineExpr s0 = builder.getAffineSymbolExpr(0);
1405  AffineExpr s1 = builder.getAffineSymbolExpr(1);
1406  Operation *splitPoint =
1407  affine::makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
1408  {spec->lowTileSize, spec->lowTripCount});
1409  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
1410  Operation *highTileSize = spec->highTileSize.getDefiningOp();
1411  assert(lowTileSize && highTileSize && splitPoint &&
1412  "tile sizes are not produced by operations");
1413  results.reserve(results.size() + 3);
1414  results.push_back(lowTileSize);
1415  results.push_back(highTileSize);
1416  results.push_back(splitPoint);
1417  return DiagnosedSilenceableFailure::success();
1418 }
1419 
1420 void transform::MultiTileSizesOp::getEffects(
1421  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1422  onlyReadsHandle(getTargetMutable(), effects);
1423  producesHandle(getOperation()->getOpResults(), effects);
1424  if (isa<TransformParamTypeInterface>(getLowSize().getType()))
1425  onlyReadsPayload(effects);
1426  else
1427  modifiesPayload(effects);
1428 }
1429 
1430 LogicalResult transform::MultiTileSizesOp::verify() {
1431  if (getLowSize().getType() != getHighSize().getType() ||
1432  getLowSize().getType() != getSplitPoint().getType()) {
1433  return emitOpError() << "expects all results type to be the same";
1434  }
1435  return success();
1436 }
1437 
1438 //===---------------------------------------------------------------------===//
1439 // PackOp
1440 //===---------------------------------------------------------------------===//
1441 
1442 void transform::PackOp::build(OpBuilder &builder, OperationState &result,
1443  Value target,
1444  ArrayRef<OpFoldResult> mixedPackedSizes) {
1445  SmallVector<int64_t> staticPackedSizes;
1446  SmallVector<Value> dynamicPackedSizes;
1447  dispatchIndexOpFoldResults(mixedPackedSizes, dynamicPackedSizes,
1448  staticPackedSizes);
1449  // Call the default builder which sets up the proper operands segment sizes
1450  // attributes for multiple variadic operands. In the absence of this, horrible
1451  // bugs ensue.
1452  Type linalgOpHType = transform::OperationType::get(
1453  builder.getContext(), GenericOp::getOperationName());
1454  build(builder, result,
1455  /*resultType=*/linalgOpHType,
1456  /*target=*/target,
1457  /*dynamic_sizes=*/dynamicPackedSizes,
1458  /*static_sizes=*/builder.getDenseI64ArrayAttr(staticPackedSizes));
1459 }
1460 
1461 SmallVector<OpFoldResult> transform::PackOp::getMixedPackedSizes() {
1462  Builder b(getContext());
1463  return getMixedValues(getStaticPackedSizes(), getPackedSizes(), b);
1464 }
1465 
1466 DiagnosedSilenceableFailure
1467 transform::PackOp::apply(transform::TransformRewriter &rewriter,
1468  transform::TransformResults &transformResults,
1469  transform::TransformState &state) {
1470  auto targetOps = state.getPayloadOps(getTarget());
1471  // If nothing to pack, propagate success.
1472  if (std::empty(targetOps)) {
1473  transformResults.set(cast<OpResult>(getPackedOp()),
1474  ArrayRef<Operation *>({}));
1475  return DiagnosedSilenceableFailure::success();
1476  }
1477  // Fail on multi-op handles.
1478  auto linalgOp = dyn_cast<LinalgOp>(*targetOps.begin());
1479  if (!llvm::hasSingleElement(targetOps) || !linalgOp) {
1480  return emitSilenceableError()
1481  << "requires target to map to exactly 1 LinalgOp (got "
1482  << llvm::range_size(targetOps) << ")";
1483  }
1484  // Fail on mismatched number of pack sizes.
1485  if (getMixedPackedSizes().size() != linalgOp.getNumLoops()) {
1486  return emitSilenceableError()
1487  << "requires number of packed sizes match the number of loops ("
1488  << getMixedPackedSizes().size() << " vs " << linalgOp.getNumLoops()
1489  << ")";
1490  }
1491 
1492  // Unpack handles to constants or actual SSA index values.
1493  SmallVector<OpFoldResult> packedSizes;
1494  DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations(
1495  state, *this, packedSizes, getMixedPackedSizes());
1496 
1497  rewriter.setInsertionPoint(linalgOp);
1498  FailureOr<PackResult> maybeResult = pack(rewriter, linalgOp, packedSizes);
1499  if (failed(maybeResult))
1500  return emitDefiniteFailure("data tiling failed");
1501 
1502  transformResults.set(cast<OpResult>(getPackedOp()),
1503  {maybeResult->packedLinalgOp.getOperation()});
1504  return DiagnosedSilenceableFailure::success();
1505 }
1506 
1507 void transform::PackOp::getEffects(
1508  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1509  transform::consumesHandle(getTargetMutable(), effects);
1510  transform::onlyReadsHandle(getPackedSizesMutable(), effects);
1511  transform::producesHandle(getOperation()->getOpResults(), effects);
1512  transform::modifiesPayload(effects);
1513 }
1514 
1515 //===---------------------------------------------------------------------===//
1516 // PackGreedilyOp.
1517 //===---------------------------------------------------------------------===//
1518 
1519 LogicalResult transform::PackGreedilyOp::verify() {
1520  if (!isPermutationVector(getMatmulInnerDimsOrder())) {
1521  return emitOpError() << getMatmulInnerDimsOrderAttrName()
1522  << " is not a valid permutation";
1523  }
1524  // TODO: relax to allow empty once we have another strategy than just matmul.
1525  if (!getMatmulPaddedSizesNextMultipleOf().empty()) {
1526  for (auto [s, nmo] :
1527  llvm::zip_equal(getMixedMatmulPackedSizes(),
1528  getMatmulPaddedSizesNextMultipleOf())) {
1529  std::optional<int64_t> maybeStaticPackedSize = getConstantIntValue(s);
1530  if (nmo != 0 &&
1531  (!maybeStaticPackedSize.has_value() || *maybeStaticPackedSize != 0)) {
1532  return emitOpError() << "at most one of the packed_size and the "
1533  "padded_sizes_next_multiple_of can be nonzero "
1534  "for the matmul strategy";
1535  }
1536  }
1537  }
1538  return success();
1539 }
1540 
1541 DiagnosedSilenceableFailure
1542 PackGreedilyOp::apply(transform::TransformRewriter &rewriter,
1543  transform::TransformResults &transformResults,
1544  transform::TransformState &state) {
1545  SmallVector<Operation *> results;
1546  for (Operation *op : state.getPayloadOps(getTarget())) {
1547  auto linalgOp = dyn_cast<LinalgOp>(op);
1548  if (!linalgOp)
1549  continue;
1550  // linalgOp will be replaced and the insertion point may be invalidated if
1551  // we set it before -> set it after.
1552  rewriter.setInsertionPointAfter(linalgOp);
1553  // Failing to pack greedily is perfectly fine.
1554  // In the future we will want to order packings according to some metric.
1555  FailureOr<PackResult> packResult = packMatmulGreedily(
1556  /*rewriter=*/rewriter,
1557  /*linalgOp=*/linalgOp,
1558  /*mnkPackedSizes=*/getMixedMatmulPackedSizes(),
1559  /*mnkPaddedSizesNextMultipleOf=*/
1560  getMatmulPaddedSizesNextMultipleOf(),
1561  /*mnkOrder=*/getMatmulInnerDimsOrder());
1562  if (succeeded(packResult)) {
1563  results.push_back(packResult->packedLinalgOp);
1564  continue;
1565  }
1566  results.push_back(linalgOp);
1567  }
1568  transformResults.set(cast<OpResult>(getPackedOp()), results);
1569  return DiagnosedSilenceableFailure::success();
1570 }
1571 
1572 SmallVector<OpFoldResult> PackGreedilyOp::getMixedMatmulPackedSizes() {
1573  Builder b(getContext());
1574  return getMixedValues(getStaticMatmulPackedSizes(), getMatmulPackedSizes(),
1575  b);
1576 }
1577 
1578 void transform::PackGreedilyOp::getEffects(
1579  SmallVector<MemoryEffects::EffectInstance> &effects) {
1580  transform::consumesHandle(getTargetMutable(), effects);
1581  transform::onlyReadsHandle(getMatmulPackedSizesMutable(), effects);
1582  transform::producesHandle(getOperation()->getOpResults(), effects);
1583  transform::modifiesPayload(effects);
1584 }
1585 
1586 //===---------------------------------------------------------------------===//
1587 // PackTransposeOp
1588 //===---------------------------------------------------------------------===//
1589 
1590 LogicalResult transform::PackTransposeOp::verify() {
1591  if (!isPermutationVector(getInnerPerm())) {
1592  return emitOpError() << getInnerPermAttrName()
1593  << " is not a valid permutation";
1594  }
1595  if (!isPermutationVector(getOuterPerm())) {
1596  return emitOpError() << getOuterPermAttrName()
1597  << " is not a valid permutation";
1598  }
1599  if (getInnerPerm().empty() && getOuterPerm().empty()) {
1600  return emitOpError() << " at least one of " << getInnerPermAttrName()
1601  << " or " << getOuterPermAttrName()
1602  << " must be specified";
1603  }
1604  return success();
1605 }
1606 
1607 namespace {
1608 enum class OuterOrInnerPerm { Outer = 0, Inner = 1 };
1609 } // namespace
1610 
1611 /// Return true if `permutation` is a valid permutation of the
1612 /// `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos`
1613 /// (OuterOrInnerPerm::Inner) of the `tensor.pack` or `tensor.unpack` `op`.
1614 /// This is the case when the `permutation` rank matches the rank expected by
1615 /// `op` and `permutation` is itself a permutation vector.
1616 /// Return true if either `op` or `permutation` are empty to allow a simpler
1617 /// polymorphic implementation.
1618 template <typename RelayoutOpTy>
1619 static bool isValidPackingPermutation(
1620  RelayoutOpTy op, ArrayRef<int64_t> permutation,
1621  OuterOrInnerPerm outerOrInnerPerm = OuterOrInnerPerm::Outer) {
1622  static_assert(
1623  llvm::is_one_of<RelayoutOpTy, tensor::PackOp, tensor::UnPackOp>::value,
1624  "applies to only pack or unpack operations");
1625  if (!op || permutation.empty())
1626  return true;
1627  size_t innerRank = op.getInnerDimsPos().size();
1628  if (outerOrInnerPerm == OuterOrInnerPerm::Inner)
1629  return permutation.size() == innerRank && isPermutationVector(permutation);
1630  // op.getOuterDimsPerm() may be empty, in which case it is identity.
1631  // Don't rely on it.
1632  if (std::is_same<RelayoutOpTy, tensor::PackOp>::value) {
1633  return permutation.size() == op.getSourceRank() &&
1634  isPermutationVector(permutation);
1635  }
1636  return permutation.size() == op.getDestRank() &&
1637  isPermutationVector(permutation);
1638 }
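// Editorial illustration of the helper above (not upstream documentation):
// for a tensor.pack whose inner_dims_pos has size 2 and whose source has
// rank 3, an inner permutation must be a size-2 permutation such as {1, 0},
// an outer permutation must be a size-3 permutation such as {2, 0, 1}, and an
// empty permutation (or a null op) is accepted trivially.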
1639 
1640 DiagnosedSilenceableFailure
1641 transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter,
1642  transform::TransformResults &transformResults,
1643  transform::TransformState &state) {
1644  auto packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp());
1645  auto linalgOps = state.getPayloadOps(getTargetLinalgOp());
1646  // Step 1. If nothing to pack, propagate success.
1647  if (std::empty(packOrUnpackOps)) {
1648  transformResults.set(cast<OpResult>(getPackedOp()), {});
1649  transformResults.set(cast<OpResult>(getPackOp()), {});
1650  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1651  return DiagnosedSilenceableFailure::success();
1652  }
1653 
1654  // Step 2. Bunch of runtime sanity checks and error messages.
1655  // Step 2.1. Fail on multi-op handles.
1656  if (!llvm::hasSingleElement(packOrUnpackOps) ||
1657  !llvm::hasSingleElement(linalgOps)) {
1658  return emitSilenceableError()
1659  << "requires target to map to exactly 1 "
1660  "packing op and 1 packed op ("
1661  << "got " << llvm::range_size(packOrUnpackOps) << " and "
1662  << llvm::range_size(linalgOps) << ")";
1663  }
1664 
1665  // Step 2.2. Fail on wrong type.
1666  auto packOp = dyn_cast<tensor::PackOp>(*packOrUnpackOps.begin());
1667  auto unPackOp = dyn_cast<tensor::UnPackOp>(*packOrUnpackOps.begin());
1668  if ((!packOp && !unPackOp)) {
1669  return emitSilenceableError() << "requires target to map to a "
1670  "tensor.pack or tensor.unpack";
1671  }
1672  LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(*linalgOps.begin());
1673  if (!linalgOpTarget)
1674  return emitSilenceableError() << "requires a LinalgOp target";
1675 
1676  // Step 2.3. Fail if we can't get the producer / consumer Linalg op.
1677  LinalgOp linalgOp;
1678  if (packOp && packOp.getResult().hasOneUse())
1679  linalgOp = dyn_cast<LinalgOp>(*(packOp.getResult().getUsers().begin()));
1680  else if (unPackOp)
1681  linalgOp = unPackOp.getSource().getDefiningOp<LinalgOp>();
1682  if (linalgOp != linalgOpTarget) {
1683  auto errorMsg =
1684  packOp ? StringLiteral{"not a single use by the LinalgOp target"}
1685  : StringLiteral{"not produced by the LinalgOp target"};
1686  return emitSilenceableError() << errorMsg;
1687  }
1688 
1689  // Step 2.4. If we have an UnPackOp, we need to fetch the symmetrical
1690  // PackOp.
1691  if (unPackOp) {
1692  assert(!packOp && "packOp must be null on entry when unPackOp is not null");
1693  OpOperand *packUse = linalgOp.getDpsInitOperand(
1694  cast<OpResult>(unPackOp.getSource()).getResultNumber());
1695  packOp = dyn_cast_or_null<tensor::PackOp>(packUse->get().getDefiningOp());
1696  if (!packOp || !packOp.getResult().hasOneUse())
1697  return emitSilenceableError() << "could not find matching pack op";
1698  }
1699 
1700  // Step 2.5. Fail if any permutation does not validate.
1701  for (auto permType : {OuterOrInnerPerm::Outer, OuterOrInnerPerm::Inner}) {
1702  ArrayRef<int64_t> perm =
1703  (permType == OuterOrInnerPerm::Outer) ? getOuterPerm() : getInnerPerm();
1704  auto errorMsg = (permType == OuterOrInnerPerm::Outer)
1705  ? StringLiteral{"invalid outer_perm"}
1706  : StringLiteral{"invalid inner_perm"};
1707  if (!isValidPackingPermutation(packOp, perm, permType) ||
1708  !isValidPackingPermutation(unPackOp, perm, permType)) {
1709  Operation *packOrUnpackOp =
1710  unPackOp ? unPackOp.getOperation() : packOp.getOperation();
1711  return emitSilenceableError() << errorMsg << ": " << *packOrUnpackOp;
1712  }
1713  }
1714 
1715  // From here on, packOp and linalgOp are always present, unPackOp may or may
1716  // not be present.
1717  assert(packOp && linalgOp && "unexpected null op");
1718 
1719  // Step 3. Actually transpose the ops.
1720  FailureOr<PackTransposeResult> res = packTranspose(
1721  rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm());
1722  // Preconditions have been checked, it is an error to fail here.
1723  assert(succeeded(res) && "unexpected packTranspose failure");
1724 
1725  // Step 4. Return results.
1726  transformResults.set(cast<OpResult>(getPackOp()), {res->transposedPackOp});
1727  transformResults.set(cast<OpResult>(getPackedOp()),
1728  {res->transposedLinalgOp});
1729  if (unPackOp) {
1730  transformResults.set(cast<OpResult>(getUnPackOp()),
1731  {res->transposedUnPackOp});
1732  } else {
1733  transformResults.set(cast<OpResult>(getUnPackOp()), {});
1734  }
1735 
1736  return DiagnosedSilenceableFailure::success();
1737 }
1738 
1739 //===---------------------------------------------------------------------===//
1740 // PadOp
1741 //===---------------------------------------------------------------------===//
1742 
1743 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1744  ArrayRef<int64_t> paddingDimensions,
1745  ArrayRef<int64_t> padToMultipleOf,
1746  ArrayRef<int64_t> nofoldFlags,
1747  ArrayRef<Attribute> transposePaddings,
1748  StringRef copyBackOp) {
1749  auto resultType = transform::AnyOpType::get(b.getContext());
1750  return build(/*builder=*/b,
1751  /*result=*/result,
1752  /*types=*/TypeRange{resultType, resultType},
1753  /*target=*/target,
1754  /*paddingValues=*/ArrayAttr(), // let inference handle this
1755  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1756  /*padToMultipleOf=*/ValueRange{},
1757  /*padToMultipleOf=*/
1758  (padToMultipleOf.empty()
1759  ? DenseI64ArrayAttr()
1760  : b.getDenseI64ArrayAttr(padToMultipleOf)),
1761  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1762  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1763  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1764 }
1765 
1766 void transform::PadOp::build(OpBuilder &b, OperationState &result, Value target,
1767  ArrayRef<int64_t> paddingDimensions,
1768  ArrayRef<OpFoldResult> mixedPadToMultipleOf,
1769  ArrayRef<int64_t> nofoldFlags,
1770  ArrayRef<Attribute> transposePaddings,
1771  StringRef copyBackOp) {
1772  auto resultType = transform::AnyOpType::get(b.getContext());
1773  SmallVector<int64_t> staticPadToMultipleOf;
1774  SmallVector<Value> dynamicPadToMultipleOf;
1775  dispatchIndexOpFoldResults(mixedPadToMultipleOf, dynamicPadToMultipleOf,
1776  staticPadToMultipleOf);
1777  return build(/*builder=*/b,
1778  /*result=*/result,
1779  /*types=*/TypeRange{resultType, resultType},
1780  /*target=*/target,
1781  /*paddingValues=*/ArrayAttr(), // let inference handle this
1782  /*paddingDimensions=*/b.getI64ArrayAttr(paddingDimensions),
1783  /*padToMultipleOf=*/dynamicPadToMultipleOf,
1784  /*padToMultipleOf=*/staticPadToMultipleOf,
1785  /*nofoldFlags=*/b.getI64ArrayAttr(nofoldFlags),
1786  /*transposePaddings=*/b.getArrayAttr(transposePaddings),
1787  /*copyBackOp=*/b.getStringAttr(copyBackOp));
1788 }
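// Editorial usage sketch (not part of the upstream file): build a
// transform.structured.pad op with the static-multiple overload defined
// earlier. Every literal argument below is an assumption chosen for
// illustration.
static transform::PadOp createExamplePadOp(OpBuilder &b, Location loc,
                                           Value linalgHandle) {
  // Pad dimensions 0 and 2 to a multiple of 32 and copy back via linalg.copy.
  return b.create<transform::PadOp>(
      loc, linalgHandle,
      /*paddingDimensions=*/ArrayRef<int64_t>{0, 2},
      /*padToMultipleOf=*/ArrayRef<int64_t>{32, 32},
      /*nofoldFlags=*/ArrayRef<int64_t>{},
      /*transposePaddings=*/ArrayRef<Attribute>{},
      /*copyBackOp=*/linalg::CopyOp::getOperationName());
}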
1789 
1790 void PadOp::getEffects(
1791  SmallVector<MemoryEffects::EffectInstance> &effects) {
1792  consumesHandle(getTargetMutable(), effects);
1793  onlyReadsHandle(getPadToMultipleOfMutable(), effects);
1794  producesHandle(getOperation()->getOpResults(), effects);
1795  modifiesPayload(effects);
1796 }
1797 
1798 SmallVector<OpFoldResult> PadOp::getMixedPadToMultipleOf() {
1799  Builder b(getContext());
1800  return getMixedValues(getStaticPadToMultipleOf(), getPadToMultipleOf(), b);
1801 }
1802 
1803 DiagnosedSilenceableFailure
1804 transform::PadOp::apply(transform::TransformRewriter &rewriter,
1805  transform::TransformResults &results,
1806  transform::TransformState &state) {
1807  auto transformOp = cast<TransformOpInterface>(getOperation());
1808  SmallVector<Operation *> paddedOps, padOps, copyBackOps;
1809 
1810  for (Operation *target : state.getPayloadOps(getTarget())) {
1811  auto linalgTarget = dyn_cast<LinalgOp>(target);
1812  if (!linalgTarget) {
1813  auto diag = emitSilenceableError() << "expected LinalgOp target";
1814  diag.attachNote(target->getLoc()) << "target op";
1815  return diag;
1816  }
1817 
1818  // Convert the integer packing flags to booleans.
1819  SmallVector<bool> nofoldFlags;
1820  for (int64_t packPadding :
1821  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags()))
1822  nofoldFlags.push_back(static_cast<bool>(packPadding));
1823 
1824  // Convert the padding values to attributes.
1825  SmallVector<Attribute> paddingValues;
1826  for (auto const &it :
1827  llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
1828  auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
1829  if (!attr) {
1830  emitOpError("expects padding values to be typed attributes");
1831  return DiagnosedSilenceableFailure::definiteFailure();
1832  }
1833  Type elementType = getElementTypeOrSelf(std::get<1>(it));
1834  // Try to parse string attributes to obtain an attribute of element type.
1835  if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
1836  auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
1837  stringAttr, getContext(), elementType,
1838  /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
1839  if (!parsedAttr || parsedAttr.getType() != elementType) {
1840  auto diag = this->emitOpError("expects a padding that parses to ")
1841  << elementType << ", got " << std::get<0>(it);
1842  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1843  return DiagnosedSilenceableFailure::definiteFailure();
1844  }
1845  paddingValues.push_back(parsedAttr);
1846  continue;
1847  }
1848  // Otherwise, add the attribute directly.
1849  if (attr.getType() != elementType) {
1850  auto diag = this->emitOpError("expects a padding value of type ")
1851  << elementType << ", got " << attr;
1852  diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
1853  return DiagnosedSilenceableFailure::definiteFailure();
1854  }
1855  paddingValues.push_back(attr);
1856  }
1857 
1858  // Extract the transpose vectors.
1859  SmallVector<SmallVector<int64_t>> transposePaddings;
1860  for (Attribute transposeVector : cast<ArrayAttr>(getTransposePaddings()))
1861  transposePaddings.push_back(extractFromIntegerArrayAttr<int64_t>(
1862  cast<ArrayAttr>(transposeVector)));
1863 
1864  LinalgOp paddedOp;
1865  LinalgPaddingOptions options;
1866  options.paddingDimensions =
1867  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1868 
1869  SmallVector<int64_t> padToMultipleOf;
1870  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
1871  state, transformOp, getMixedPadToMultipleOf(), padToMultipleOf);
1872  if (!status.succeeded())
1873  return status;
1874  if (padToMultipleOf.empty())
1875  padToMultipleOf =
1876  SmallVector<int64_t>(options.paddingDimensions.size(), 1);
1877 
1878  options.padToMultipleOf = padToMultipleOf;
1879  options.paddingValues = paddingValues;
1880  options.nofoldFlags = nofoldFlags;
1881  if (getCopyBackOp() ==
1882  bufferization::MaterializeInDestinationOp::getOperationName()) {
1883  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::
1884  BufferizationMaterializeInDestination;
1885  } else if (getCopyBackOp() == linalg::CopyOp::getOperationName()) {
1886  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::LinalgCopy;
1887  } else if (getCopyBackOp() == kCopyOpNone) {
1888  options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
1889  } else {
1890  llvm_unreachable("unsupported copy_back op");
1891  }
1892 
1893  SmallVector<Value> replacements;
1894  SmallVector<tensor::PadOp> newPadOps;
1895  if (failed(rewriteAsPaddedOp(rewriter, linalgTarget, options, paddedOp,
1896  replacements, newPadOps))) {
1897  auto diag = emitSilenceableError() << "failed to pad op";
1898  diag.attachNote(target->getLoc()) << "target op";
1899  return diag;
1900  }
1901 
1902  // We need to perform our own replacement here because this API is still
1903  // used in patterns that "pad and hoist", for which the replacement values
1904  // need to be different.
1905  // TODO: clean this up and stop "pad and hoist" behavior more globally now
1906  // that we have more composable abstractions.
1907  rewriter.replaceOp(linalgTarget, replacements);
1908  paddedOps.push_back(paddedOp);
1909  padOps.append(newPadOps.begin(), newPadOps.end());
1910  if (options.copyBackOp != LinalgPaddingOptions::CopyBackOp::None) {
1911  for (Value v : replacements) {
1912  Operation *copyBackOp = v.getDefiningOp();
1913  if (!llvm::is_contained(copyBackOps, copyBackOp))
1914  copyBackOps.push_back(copyBackOp);
1915  }
1916  }
1917  }
1918 
1919  results.set(cast<OpResult>(getPadded()), paddedOps);
1920  results.set(cast<OpResult>(getPad()), padOps);
1921  results.set(cast<OpResult>(getCopy()), copyBackOps);
1922  return DiagnosedSilenceableFailure::success();
1923 }
1924 
1925 LogicalResult transform::PadOp::verify() {
1926  SmallVector<int64_t> nofoldFlags =
1927  extractFromIntegerArrayAttr<int64_t>(getNofoldFlags());
1928  if (any_of(nofoldFlags, [](int64_t packPadding) {
1929  return packPadding != 0 && packPadding != 1;
1930  })) {
1931  return emitOpError()
1932  << "expects nofold_flags to contain booleans (0/1), found "
1933  << getNofoldFlags();
1934  }
1935 
1936  SmallVector<int64_t> paddingDimensions =
1937  extractFromIntegerArrayAttr<int64_t>(getPaddingDimensions());
1938  if (any_of(paddingDimensions,
1939  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
1940  return emitOpError() << "expects padding_dimensions to contain "
1941  "non-negative integers, found "
1942  << getPaddingDimensions();
1943  }
1944  if (!getMixedPadToMultipleOf().empty()) {
1945  if (getMixedPadToMultipleOf().size() != paddingDimensions.size()) {
1946  return emitOpError() << "expects as many multiples as padding_dimensions";
1947  }
1948  }
1949  ArrayAttr transposes = getTransposePaddings();
1950  for (Attribute attr : transposes) {
1951  SmallVector<int64_t> transpose = extractFromIntegerArrayAttr<int64_t>(attr);
1952  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
1953  if (!std::is_permutation(sequence.begin(), sequence.end(),
1954  transpose.begin(), transpose.end())) {
1955  return emitOpError()
1956  << "expects transpose_paddings to be a permutation, found "
1957  << attr;
1958  }
1959  }
1960  if (getCopyBackOp() !=
1961  bufferization::MaterializeInDestinationOp::getOperationName() &&
1962  getCopyBackOp() != linalg::CopyOp::getOperationName() &&
1963  getCopyBackOp() != kCopyOpNone)
1964  return emitOpError() << "invalid copy_back_op";
1965  return success();
1966 }
1967 
1968 //===---------------------------------------------------------------------===//
1969 // HoistPadOp
1970 //===---------------------------------------------------------------------===//
1971 
1972 DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply(
1973  transform::TransformRewriter &rewriter,
1974  transform::TransformResults &transformResults,
1975  transform::TransformState &state) {
1976  auto targetOps = state.getPayloadOps(getTarget());
1977  auto loopOps = state.getPayloadOps(getLoop());
1978  if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(loopOps)) {
1979  return emitDefiniteFailure()
1980  << "requires exactly one target and one loop handle (got "
1981  << llvm::range_size(targetOps) << " and "
1982  << llvm::range_size(loopOps) << ")";
1983  }
1984 
1985  auto padOp = dyn_cast_or_null<tensor::PadOp>(*targetOps.begin());
1986  auto loopOp = dyn_cast_or_null<scf::ForOp>(*loopOps.begin());
1987  if (!padOp || !loopOp)
1988  return emitDefiniteFailure() << "requires exactly 2 non-null handles";
1989 
1990  FailureOr<linalg::detail::PackingResult> result =
1991  linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp,
1992  getTranspose());
1993  if (failed(result))
1994  return emitDefiniteFailure() << "could not build packing loop nest";
1995 
1996  if (result->clonedLoopIvs.empty()) {
1997  transformResults.set(cast<OpResult>(getPackingLoop()),
1998  {result->hoistedPadOp.getOperation()});
1999  return DiagnosedSilenceableFailure::success();
2000  }
2001  auto outerPackedLoop =
2002  scf::getForInductionVarOwner(result->clonedLoopIvs.front());
2003  transformResults.set(cast<OpResult>(getPackingLoop()),
2004  {outerPackedLoop.getOperation()});
2005  return DiagnosedSilenceableFailure::success();
2006 }
2007 
2008 LogicalResult transform::HoistPadBuildPackingLoopNestOp::verify() {
2009  ArrayRef<int64_t> transpose = getTranspose();
2010  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2011  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2012  transpose.end())) {
2013  return emitOpError() << "expects transpose to be a permutation, found "
2014  << getTranspose();
2015  }
2016  return success();
2017 }
2018 
2019 void transform::HoistPadBuildPackingLoopNestOp::getEffects(
2020  SmallVector<MemoryEffects::EffectInstance> &effects) {
2021  transform::onlyReadsHandle(getTargetMutable(), effects);
2022  transform::onlyReadsHandle(getLoopMutable(), effects);
2023  transform::producesHandle(getOperation()->getOpResults(), effects);
2024  transform::modifiesPayload(effects);
2025 }
2026 
2027 DiagnosedSilenceableFailure
2028 transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter,
2029  tensor::PadOp target,
2030  transform::ApplyToEachResultList &results,
2031  transform::TransformState &state) {
2032  tensor::PadOp hoistedPadOp;
2033  SmallVector<TransposeOp> transposeOps;
2034  FailureOr<Value> result =
2035  hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(),
2036  hoistedPadOp, transposeOps);
2037  if (succeeded(result)) {
2038  // We need to perform our own replacement here because this API is still
2039  // used in patterns that "pad and hoist", for which the replacement values
2040  // need to be different.
2041  // TODO: clean this up and stop "pad and hoist" behavior more globally now
2042  // that we have more composable abstractions.
2043  rewriter.replaceOp(target, *result);
2044  results.push_back(hoistedPadOp);
2045  return DiagnosedSilenceableFailure::success();
2046  }
2047  return emitDefaultSilenceableFailure(target);
2048 }
2049 
2050 LogicalResult transform::HoistPadOp::verify() {
2051  ArrayRef<int64_t> transpose = getTranspose();
2052  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
2053  if (!std::is_permutation(sequence.begin(), sequence.end(), transpose.begin(),
2054  transpose.end())) {
2055  return emitOpError() << "expects transpose to be a permutation, found "
2056  << getTranspose();
2057  }
2058  return success();
2059 }
2060 
2061 //===----------------------------------------------------------------------===//
2062 // PromoteOp
2063 //===----------------------------------------------------------------------===//
2064 
2065 DiagnosedSilenceableFailure
2066 transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter,
2067  LinalgOp target,
2068  transform::ApplyToEachResultList &results,
2069  transform::TransformState &state) {
2070  LinalgPromotionOptions promotionOptions;
2071  if (!getOperandsToPromote().empty())
2072  promotionOptions = promotionOptions.setOperandsToPromote(
2073  extractFromIntegerArrayAttr<int64_t>(getOperandsToPromote()));
2074  if (getUseFullTilesByDefault())
2075  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
2076  getUseFullTilesByDefault());
2077  if (getUseAlloca())
2078  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
2079  if (!getUseFullTileBuffers().empty())
2080  promotionOptions = promotionOptions.setUseFullTileBuffers(
2081  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
2082  if (getAlignment().has_value())
2083  promotionOptions = promotionOptions.setAlignment(*getAlignment());
2084  if (getMemorySpace().has_value())
2085  promotionOptions = promotionOptions.setMemorySpace(*getMemorySpace());
2086 
2087  if (getMapping().has_value()) {
2088  // The mapping should only contain one element.
2089  auto mapping = *getMapping();
2090  if (mapping.size() > 1)
2091  return emitDefaultDefiniteFailure(target);
2092 
2093  auto addressSpace = cast<mlir::gpu::GPUMemorySpaceMappingAttr>(mapping[0]);
2094 
2095  if (addressSpace.getAddressSpace() ==
2096  mlir::gpu::GPUDialect::getWorkgroupAddressSpace()) {
2097  promotionOptions =
2098  promotionOptions
2099  .setAllocationDeallocationFns(allocateWorkgroupMemory,
2100  deallocateWorkgroupMemory)
2101  .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
2102  .setUseFullTileBuffers({false, false});
2103  } else if (addressSpace.getAddressSpace() ==
2104  mlir::gpu::GPUDialect::getPrivateAddressSpace()) {
2105  promotionOptions =
2106  promotionOptions
2107  .setAllocationDeallocationFns(allocateGPUPrivateMemory,
2108  deallocateGPUPrivateMemory)
2109  .setCopyInOutFns(copyToGPUPrivateMemory, copyToGPUPrivateMemory)
2110  .setUseFullTileBuffers({false, false});
2111  } else {
2112  return emitDefaultDefiniteFailure(target);
2113  }
2114  }
2115 
2116  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
2117  return emitDefaultDefiniteFailure(target);
2118 
2119  rewriter.setInsertionPoint(target);
2120  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
2121  if (failed(res))
2122  return emitDefaultDefiniteFailure(target);
2123  results.push_back(target);
2124  return DiagnosedSilenceableFailure::success();
2125 }
2126 
2127 //===----------------------------------------------------------------------===//
2128 // ReplaceOp
2129 //===----------------------------------------------------------------------===//
2130 
2131 DiagnosedSilenceableFailure
2132 transform::ReplaceOp::apply(transform::TransformRewriter &rewriter,
2133  TransformResults &transformResults,
2134  TransformState &state) {
2135  auto payload = state.getPayloadOps(getTarget());
2136 
2137  // Check for invalid targets.
2138  for (Operation *target : payload) {
2139  if (target->getNumOperands() > 0)
2140  return emitDefiniteFailure() << "expected target without operands";
2141  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2142  target->getNumRegions() > 0)
2143  return emitDefiniteFailure()
2144  << "expected target that is isolated from above";
2145  }
2146 
2147  // Clone and replace.
2148  Operation *pattern = &getBodyRegion().front().front();
2149  SmallVector<Operation *> replacements;
2150  for (Operation *target : payload) {
2151  if (getOperation()->isAncestor(target))
2152  continue;
2153  rewriter.setInsertionPoint(target);
2154  Operation *replacement = rewriter.clone(*pattern);
2155  rewriter.replaceOp(target, replacement->getResults());
2156  replacements.push_back(replacement);
2157  }
2158  transformResults.set(cast<OpResult>(getReplacement()), replacements);
2159  return DiagnosedSilenceableFailure::success();
2160 }
2161 
2162 void transform::ReplaceOp::getEffects(
2163  SmallVector<MemoryEffects::EffectInstance> &effects) {
2164  consumesHandle(getTargetMutable(), effects);
2165  producesHandle(getOperation()->getOpResults(), effects);
2166  modifiesPayload(effects);
2167 }
2168 
2169 LogicalResult transform::ReplaceOp::verify() {
2170  if (!getBodyRegion().hasOneBlock())
2171  return emitOpError() << "expected one block";
2172  if (std::distance(getBodyRegion().front().begin(),
2173  getBodyRegion().front().end()) != 1)
2174  return emitOpError() << "expected one operation in block";
2175  Operation *replacement = &getBodyRegion().front().front();
2176  if (replacement->getNumOperands() > 0)
2177  return replacement->emitOpError()
2178  << "expected replacement without operands";
2179  if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
2180  replacement->getNumRegions() > 0)
2181  return replacement->emitOpError()
2182  << "expected op that is isolated from above";
2183  return success();
2184 }
2185 
2186 //===----------------------------------------------------------------------===//
2187 // ScalarizeOp
2188 //===----------------------------------------------------------------------===//
2189 
2190 DiagnosedSilenceableFailure
2191 transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
2192  LinalgOp target,
2193  transform::ApplyToEachResultList &results,
2194  transform::TransformState &state) {
2195  scf::SCFTilingOptions tilingOptions;
2196  tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
2197  SmallVector<OpFoldResult> tileSizes;
2198  Location loc = target.getLoc();
2199  SmallVector<OpFoldResult> allShapeSizes =
2200  target.createFlatListOfOperandDims(b, loc);
2201  AffineMap map = target.getShapesToLoopsMap();
2202  if (!map)
2203  return tileSizes;
2204  SmallVector<OpFoldResult> shapeSizes =
2205  affine::makeComposedFoldedMultiResultAffineApply(b, loc, map,
2206  allShapeSizes);
2207  // If the shape size is dynamic, tile by 1.
2208  // Otherwise, do not tile (i.e. tile size 0).
2209  for (OpFoldResult shapeSize : shapeSizes) {
2210  tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
2211  : b.getIndexAttr(1));
2212  }
2213  return tileSizes;
2214  });
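  // Editorial example: for a payload whose loop ranges are (?, 32), the
  // callback above returns tile sizes [1, 0], i.e. the dynamic loop is
  // reduced to size 1 and the static loop is left untiled.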
2215  SmallVector<int64_t> emptyTileSizes;
2216  rewriter.setInsertionPoint(target);
2217  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
2218  rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
2219  if (failed(maybeTilingResult))
2220  return emitDefaultDefiniteFailure(target);
2221 
2222  if (target->getNumResults())
2223  rewriter.replaceOp(target, maybeTilingResult->replacements);
2224  else
2225  rewriter.eraseOp(target);
2226 
2227  results.reserve(maybeTilingResult->tiledOps.size());
2228  for (Operation *tiled : maybeTilingResult->tiledOps)
2229  results.push_back(tiled);
2230  return DiagnosedSilenceableFailure::success();
2231 }
2232 
2233 //===----------------------------------------------------------------------===//
2234 // ConvertToLoopsOp
2235 //===----------------------------------------------------------------------===//
2236 
2237 DiagnosedSilenceableFailure
2238 transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2239  transform::TransformResults &results,
2240  transform::TransformState &state) {
2241  SmallVector<Operation *> loops;
2242  for (Operation *target : state.getPayloadOps(getTarget())) {
2243  auto tilingOp = dyn_cast<TilingInterface>(*target);
2244  if (!tilingOp) {
2245  DiagnosedSilenceableFailure diag =
2246  emitSilenceableError()
2247  << "expected the payload to implement TilingInterface";
2248  diag.attachNote(target->getLoc()) << "payload op";
2249  return diag;
2250  }
2251  rewriter.setInsertionPoint(target);
2252  FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2253  scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2254  if (failed(generatedLoops))
2255  return emitDefaultDefiniteFailure(target);
2256  for (scf::ForOp &loop : *generatedLoops) {
2257  loops.push_back(loop.getOperation());
2258  }
2259  rewriter.eraseOp(target);
2260  }
2261  results.set(cast<OpResult>(getResult()), loops);
2262  return DiagnosedSilenceableFailure::success();
2263 }
2264 
2265 //===----------------------------------------------------------------------===//
2266 // RewriteInDestinationPassingStyleOp
2267 //===----------------------------------------------------------------------===//
2268 
2269 DiagnosedSilenceableFailure
2270 transform::RewriteInDestinationPassingStyleOp::applyToOne(
2271  transform::TransformRewriter &rewriter, Operation *target,
2272  transform::ApplyToEachResultList &results,
2273  transform::TransformState &state) {
2275  rewriter.setInsertionPoint(target);
2276  FailureOr<Operation *> maybeResult =
2277  llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(target)
2278  .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
2279  [&rewriter](auto op) {
2280  return rewriteInDestinationPassingStyle(rewriter, op);
2281  });
2282  if (failed(maybeResult))
2283  return emitDefaultSilenceableFailure(target);
2284  results.push_back(*maybeResult);
2285  return DiagnosedSilenceableFailure::success();
2286 }
2287 
2288 //===----------------------------------------------------------------------===//
2289 // SplitOp
2290 //===----------------------------------------------------------------------===//
2291 
2292 DiagnosedSilenceableFailure
2293 SplitOp::apply(transform::TransformRewriter &rewriter,
2294  TransformResults &results, TransformState &state) {
2295  // Collect the dynamic split points if provided.
2296  SmallVector<Operation *> payload =
2297  llvm::to_vector(state.getPayloadOps(getTarget()));
2298 
2299  bool isMultiwaySplit = getMultiway();
2300 
2301  if (isMultiwaySplit && !llvm::hasSingleElement(payload)) {
2302  return mlir::emitSilenceableFailure(getLoc())
2303  << "requires exactly one target when "
2304  "multiway split is enabled (got "
2305  << llvm::range_size(payload) << ")";
2306  }
2307 
2308  SmallVector<OpFoldResult> chunkSizes;
2309 
2310  if (!isMultiwaySplit)
2311  chunkSizes.reserve(payload.size());
2312 
2313  if (getDynamicChunkSizes()) {
2314  auto diag = DiagnosedSilenceableFailure::success();
2315  if (isa<TransformHandleTypeInterface>(getDynamicChunkSizes().getType())) {
2316  chunkSizes = llvm::to_vector(llvm::map_range(
2317  state.getPayloadOps(getDynamicChunkSizes()), [&](Operation *op) {
2318  if (op->getNumResults() != 1 ||
2319  !op->getResult(0).getType().isIndex()) {
2320  diag = emitSilenceableError()
2321  << "expected dynamic split point handle to point to a "
2322  "single-result index-typed op";
2323  diag.attachNote(op->getLoc()) << "dynamic split point";
2324  }
2325  return OpFoldResult(op->getResult(0));
2326  }));
2327  } else {
2328  chunkSizes = llvm::to_vector(
2329  llvm::map_range(state.getParams(getDynamicChunkSizes()),
2330  [](Attribute attr) { return OpFoldResult(attr); }));
2331  }
2332  if (diag.isSilenceableFailure())
2333  return diag;
2334 
2335  // For multiway split, a single payload is expected to have multiple
2336  // split points.
2337  if (!isMultiwaySplit && chunkSizes.size() != payload.size()) {
2338  return emitDefiniteFailure()
2339  << "expected the dynamic split point handle to point to as "
2340  "many operations ("
2341  << chunkSizes.size() << ") as the target handle ("
2342  << payload.size() << ")";
2343  }
2344  } else {
2345  chunkSizes.resize(payload.size(),
2346  rewriter.getIndexAttr(getStaticChunkSizes()));
2347  }
2348 
2349  auto checkStructuredOpAndDimensions =
2350  [&](LinalgOp linalgOp, Location loc) -> DiagnosedSilenceableFailure {
2351  if (!linalgOp) {
2352  auto diag = emitSilenceableError() << "only applies to structured ops";
2353  diag.attachNote(loc) << "target op";
2354  return diag;
2355  }
2356 
2357  if (getDimension() >= linalgOp.getNumLoops()) {
2358  auto diag = emitSilenceableError() << "dimension " << getDimension()
2359  << " does not exist in target op";
2360  diag.attachNote(loc) << "target op";
2361  return diag;
2362  }
2363  return DiagnosedSilenceableFailure::success();
2364  };
2365 
2366  auto checkFailureInSplitting =
2367  [&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
2368  if (hasFailed) {
2369  auto diag = emitDefiniteFailure() << "internal failure in splitting";
2370  diag.attachNote(loc) << "target op";
2371  return diag;
2372  }
2373  return DiagnosedSilenceableFailure::success();
2374  };
2375 
2376  SmallVector<Operation *> opList;
2377  if (isMultiwaySplit) {
2378 
2379  // Split a single target operation at multiple points.
2380  TilingInterface head, tail;
2381  Operation *target = payload.front();
2382 
2383  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2384 
2385  // Check that the target is a valid LinalgOp with correct dimensions.
2386  DiagnosedSilenceableFailure diag =
2387  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2388  if (diag.isSilenceableFailure())
2389  return diag;
2390 
2391  for (auto &&[idx, chunkSize] : llvm::enumerate(chunkSizes)) {
2392 
2393  if (idx > 0)
2394  target = tail.getOperation();
2395 
2396  if (!target)
2397  break;
2398 
2399  linalgOp = cast<LinalgOp>(target);
2400  Location loc = target->getLoc();
2401 
2402  rewriter.setInsertionPoint(linalgOp);
2403  std::tie(head, tail) = linalg::splitOp(
2404  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2405  getDimension(), chunkSize);
2406 
2407  // Propagate errors.
2408  DiagnosedSilenceableFailure diag =
2409  checkFailureInSplitting(!head && !tail, loc);
2410  if (diag.isDefiniteFailure())
2411  return diag;
2412 
2413  opList.push_back(head.getOperation());
2414  }
2415 
2416  // Append any leftover parts to the end of the result list.
2417  if (tail)
2418  opList.push_back(tail.getOperation());
2419 
2420  } else {
2421  // Split each target operation.
2422  SmallVector<Operation *> first, second;
2423  Operation *noSecondPart = nullptr;
2424  for (const auto &pair : llvm::zip(payload, chunkSizes)) {
2425  Operation *target = std::get<0>(pair);
2426  Location loc = target->getLoc();
2427  LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
2428  DiagnosedSilenceableFailure diag =
2429  checkStructuredOpAndDimensions(linalgOp, target->getLoc());
2430 
2431  if (diag.isSilenceableFailure())
2432  return diag;
2433 
2434  rewriter.setInsertionPoint(linalgOp);
2435  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
2436  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
2437  getDimension(), std::get<1>(pair));
2438 
2439  // Propagate errors.
2440  DiagnosedSilenceableFailure diagSplit =
2441  checkFailureInSplitting(!first.back() && !second.back(), loc);
2442  if (diagSplit.isDefiniteFailure())
2443  return diagSplit;
2444 
2445  // Do not add null second parts.
2446  if (!second.back()) {
2447  noSecondPart = target;
2448  second.pop_back();
2449  }
2450  }
2451 
2452  if (second.size() != first.size() && !second.empty()) {
2453  auto diag = emitSilenceableError()
2454  << "splitting does not produce the second part for a subset "
2455  "of targets";
2456  diag.attachNote()
2457  << "expected splitting to produce the second part of all "
2458  "or none of the targets";
2459  diag.attachNote(noSecondPart->getLoc())
2460  << "first target with no second part";
2461  return diag;
2462  }
2463 
2464  opList.append(first);
2465  if (second.size())
2466  opList.append(second);
2467  }
2468  results.set(cast<OpResult>(getSplitList()), opList);
2469  return DiagnosedSilenceableFailure::success();
2470 }
2471 
2472 void SplitOp::getEffects(
2473  SmallVector<MemoryEffects::EffectInstance> &effects) {
2474  consumesHandle(getTargetMutable(), effects);
2475  if (getDynamicChunkSizes())
2476  onlyReadsHandle(getDynamicChunkSizesMutable(), effects);
2477  producesHandle(getOperation()->getOpResults(), effects);
2478  modifiesPayload(effects);
2479 }
2480 
2481 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
2482  OpAsmParser::UnresolvedOperand target, dynamicChunkSizes;
2483  IntegerAttr staticChunkSizes;
2484  if (parser.parseOperand(target) || parser.parseKeyword("after"))
2485  return failure();
2486 
2487  OptionalParseResult dynamicPointParseResult =
2488  parser.parseOptionalOperand(dynamicChunkSizes);
2489  if (!dynamicPointParseResult.has_value()) {
2490  int64_t staticChunkSizesValue;
2491  if (failed(parser.parseInteger(staticChunkSizesValue)))
2492  return failure();
2493 
2494  staticChunkSizes =
2495  parser.getBuilder().getI64IntegerAttr(staticChunkSizesValue);
2496  }
2497 
2498  Type targetType;
2499  if (parser.parseOptionalAttrDict(result.attributes) ||
2500  parser.parseColonType(targetType) ||
2501  parser.resolveOperand(target, targetType, result.operands)) {
2502  return failure();
2503  }
2504  if (dynamicPointParseResult.has_value()) {
2505  Type ChunkSizesType;
2506  if (failed(*dynamicPointParseResult) || parser.parseComma() ||
2507  parser.parseType(ChunkSizesType) ||
2508  parser.resolveOperand(dynamicChunkSizes, ChunkSizesType,
2509  result.operands)) {
2510  return failure();
2511  }
2512 
2513  staticChunkSizes =
2514  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
2515  }
2516 
2517  result.addAttribute(
2518  SplitOp::getStaticChunkSizesAttrName(result.name).getValue(),
2519  staticChunkSizes);
2520  result.addTypes(targetType);
2521  return success();
2522 }
2523 
2524 void SplitOp::print(OpAsmPrinter &printer) {
2525  printer << " " << getTarget() << " after ";
2526  int64_t staticChunkSize = static_cast<int64_t>(getStaticChunkSizes());
2527  if (staticChunkSize != ShapedType::kDynamic)
2528  printer << staticChunkSize;
2529  else
2530  printer << getDynamicChunkSizes();
2531  printer << " ";
2532  printer.printOptionalAttrDict(getOperation()->getAttrs(),
2533  {getStaticChunkSizesAttrName()});
2534  printer << " : " << getTarget().getType();
2535  if (staticChunkSize == ShapedType::kDynamic)
2536  printer << ", " << getDynamicChunkSizes().getType();
2537 }
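// For reference (editorial sketch, not upstream documentation), the custom
// assembly implemented by parse/print above reads roughly as:
//   %0 = transform.structured.split %target after 16 {dimension = 0}
//        : !transform.any_op
// and, with a dynamic chunk size:
//   %0 = transform.structured.split %target after %chunk {dimension = 0}
//        : !transform.any_op, !transform.param<i64>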
2538 
2539 LogicalResult SplitOp::verify() {
2540  if ((static_cast<int64_t>(getStaticChunkSizes()) != ShapedType::kDynamic) ^
2541  (getDynamicChunkSizes() == nullptr)) {
2542  return emitOpError() << "expects either a dynamic or a static split "
2543  "point to be provided";
2544  }
2545  return success();
2546 }
2547 
2548 //===----------------------------------------------------------------------===//
2549 // SplitReductionOp
2550 //===----------------------------------------------------------------------===//
2551 
2552 void transform::SplitReductionOp::build(
2553  OpBuilder &builder, OperationState &result, Value target,
2554  int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
2555  bool useScalingAlgorithm, bool useAlloc) {
2556  MLIRContext *ctx = builder.getContext();
2557  result.addOperands(target);
2558  result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
2559  builder.getI64IntegerAttr(splitFactor));
2560  result.addAttribute(
2561  SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
2562  builder.getI64IntegerAttr(insertSplitDimension));
2563  if (innerParallel) {
2564  result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
2565  builder.getUnitAttr());
2566  }
2567  if (useScalingAlgorithm) {
2568  result.addAttribute(
2569  SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
2570  builder.getUnitAttr());
2571  }
2572  if (useAlloc) {
2573  result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
2574  builder.getUnitAttr());
2575  }
2576  auto resultType = transform::AnyOpType::get(ctx);
2577  result.addTypes({resultType, resultType, resultType, resultType});
2578 }
2579 
2580 DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
2581  transform::TransformRewriter &rewriter, LinalgOp target,
2582  transform::ApplyToEachResultList &results,
2583  transform::TransformState &state) {
2584  ControlSplitReductionFn splitFn = [&](LinalgOp) {
2585  return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
2586  unsigned(getInsertSplitDimension()),
2587  bool(getInnerParallel())};
2588  };
2589  rewriter.setInsertionPoint(target);
2590  FailureOr<SplitReductionResult> splitResult =
2591  (getUseScalingAlgorithm())
2592  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
2593  : splitReduction(rewriter, target, splitFn, getUseAlloc());
2594  if (failed(splitResult))
2595  return emitDefaultDefiniteFailure(target);
2596 
2597  results.push_back(splitResult->initOrAlloc);
2598  results.push_back(splitResult->fillOp);
2599  results.push_back(splitResult->splitLinalgOp);
2600  results.push_back(splitResult->resultCombiningLinalgOp);
2601  return DiagnosedSilenceableFailure::success();
2602 }
2603 
2604 //===----------------------------------------------------------------------===//
2605 // TileReductionUsingForOp
2606 //===----------------------------------------------------------------------===//
2607 
2608 void transform::TileReductionUsingForOp::build(
2609  OpBuilder &builder, OperationState &result, Value target,
2610  ArrayRef<int64_t> staticTileSizes) {
2611  // Call the default builder.
2612  // This is future-proof re mixed static-dynamic and setting up the proper
2613  // operands segment sizes attributes for multiple variadic operands.
2614  // In the absence of this, horrible bugs ensue.
2615  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2616  MLIRContext *ctx = builder.getContext();
2617  auto opTy = transform::AnyOpType::get(ctx);
2618  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2619  build(builder, result,
2620  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2621  /*target=*/target,
2622  /*tile_sizes=*/staticTileSizesAttr);
2623 }
2624 
2625 DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2626  transform::TransformRewriter &rewriter, LinalgOp target,
2627  transform::ApplyToEachResultList &results,
2628  transform::TransformState &state) {
2629  rewriter.setInsertionPoint(target);
2630  FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
2631  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2632  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
2633 
2634  if (failed(result))
2635  return emitDefaultSilenceableFailure(target);
2636  for (Value initValue : result->initialValues)
2637  results.push_back(initValue.getDefiningOp());
2638  for (auto parallelTiledOp : result->parallelTiledOps)
2639  results.push_back(parallelTiledOp);
2640  for (auto mergeOp : result->mergeOps)
2641  results.push_back(mergeOp);
2642  results.push_back(result->loops.front());
2643  return DiagnosedSilenceableFailure::success();
2644 }
2645 
2646 //===----------------------------------------------------------------------===//
2647 // TileReductionUsingForallOp
2648 //===----------------------------------------------------------------------===//
2649 
2650 void transform::TileReductionUsingForallOp::build(
2651  OpBuilder &builder, OperationState &result, Value target,
2652  ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
2653  ArrayAttr mapping) {
2654  // Call the default builder.
2655  // This is future-proof re mixed static-dynamic and setting up the proper
2656  // operands segment sizes attributes for multiple variadic operands.
2657  // In the absence of this, horrible bugs ensue.
2658  // TODO: support mixed static-dynamic (see TileUsingForallOp).
2659  MLIRContext *ctx = builder.getContext();
2660  auto opTy = transform::AnyOpType::get(ctx);
2661  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
2662  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2663  build(builder, result,
2664  /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
2665  /*target=*/target,
2666  /*num_threads=*/staticNumThreadsAttr,
2667  /*tile_sizes=*/staticTileSizesAttr,
2668  /*mapping=*/mapping);
2669 }
2670 
2671 DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
2672  transform::TransformRewriter &rewriter, LinalgOp target,
2673  transform::ApplyToEachResultList &results,
2674  transform::TransformState &state) {
2675  rewriter.setInsertionPoint(target);
2676  SmallVector<OpFoldResult> numThreads =
2677  getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
2678  SmallVector<OpFoldResult> tileSizes =
2679  getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
2680  FailureOr<linalg::ForallReductionTilingResult> result =
2681  linalg::tileReductionUsingForall(
2682  rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2683  numThreads, tileSizes, getMapping());
2684 
2685  if (failed(result)) {
2686  auto diag = emitSilenceableError() << "could not tile reduction";
2687  diag.attachNote(target.getLoc()) << "target operation";
2688  return diag;
2689  }
2690  for (Value initValue : result->initialValues)
2691  results.push_back(initValue.getDefiningOp());
2692  for (auto parallelTiledOp : result->parallelTiledOps)
2693  results.push_back(parallelTiledOp);
2694  for (auto mergeOp : result->mergeOps)
2695  results.push_back(mergeOp);
2696  results.push_back(result->loops);
2697  return DiagnosedSilenceableFailure::success();
2698 }
2699 
2700 //===----------------------------------------------------------------------===//
2701 // ContinuousTileSizesOp
2702 //===----------------------------------------------------------------------===//
2703 
2704 DiagnosedSilenceableFailure
2705 transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2706  TransformResults &transformResults,
2707  TransformState &state) {
2708 
2709  SmallVector<Operation *> targetOps =
2710  llvm::to_vector(state.getPayloadOps(getTarget()));
2711 
2712  if (!llvm::hasSingleElement(targetOps)) {
2713  return mlir::emitSilenceableFailure(getLoc())
2714  << "requires exactly one target (got " << llvm::range_size(targetOps)
2715  << ")";
2716  }
2717 
2718  Operation *target = *targetOps.begin();
2719  auto linalgOp = dyn_cast<LinalgOp>(target);
2720  auto tileableOp = dyn_cast<TilingInterface>(target);
2721 
2722  if (!linalgOp)
2723  return emitDefiniteFailure() << "expected Linalg Op";
2724 
2725  OpBuilder builder(linalgOp.getContext());
2726 
2727  if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
2728  if (linalgOp.hasDynamicShape()) {
2729  auto diag = emitSilenceableError()
2730  << "cannot compute parametric tile sizes for dynamically "
2731  "shaped payload op";
2732  diag.attachNote(linalgOp->getLoc()) << "payload op";
2733  return diag;
2734  }
2735 
2736  FailureOr<StaticContinuousTileSizeSpecification> spec =
2737  computeStaticContinuousTileSizes(linalgOp, getDimension(),
2738  getTargetSize());
2739  if (failed(spec)) {
2740  return emitSilenceableError()
2741  << "failed to compute multi-size tiling sizes";
2742  }
2743 
2744  SmallVector<int64_t> chunkSizes;
2745 
2746  for (auto &&[tileSize, tripCount] :
2747  llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2748  chunkSizes.push_back(tileSize * tripCount);
2749 
2750  auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2751  return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2752  return builder.getI64IntegerAttr(value);
2753  });
2754  };
2755  transformResults.setParams(cast<OpResult>(getTileSizes()),
2756  getI64AttrsFromI64(spec->tileSizes));
2757  transformResults.setParams(cast<OpResult>(getChunkSizes()),
2758  getI64AttrsFromI64(chunkSizes));
2759 
2760  return DiagnosedSilenceableFailure::success();
2761  }
2762 
2763  builder.setInsertionPoint(linalgOp);
2764 
2765  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2766  unsigned dimension = getDimension();
2767 
2768  FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2769  builder, tileableOp, dimension, targetSize, true);
2770  if (failed(spec)) {
2771  return emitSilenceableError() << "could not generate tile size computation";
2772  }
2773 
2774  AffineExpr s0 = builder.getAffineSymbolExpr(0);
2775  AffineExpr s1 = builder.getAffineSymbolExpr(1);
2776  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2777  return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2778  ofrs);
2779  };
2780 
2781  SmallVector<Value> chunkSizes;
2782  Value splitPoint;
2783  for (auto &&[tileSize, tripCount] :
2784  llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2785  splitPoint = apply(s0 * s1, {tileSize, tripCount});
2786  chunkSizes.push_back(splitPoint);
2787  }
2788 
2789  auto getDefiningOps = [&](ArrayRef<Value> values) {
2790  return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2791  return value.getDefiningOp();
2792  });
2793  };
2794 
2795  transformResults.set(cast<OpResult>(getTileSizes()),
2796  getDefiningOps(spec->tileSizes));
2797  transformResults.set(cast<OpResult>(getChunkSizes()),
2798  getDefiningOps(chunkSizes));
2799 
2800  return DiagnosedSilenceableFailure::success();
2801 }
2802 
2803 LogicalResult transform::ContinuousTileSizesOp::verify() {
2804 
2805  if (getTileSizes().getType() != getChunkSizes().getType()) {
2806  return emitOpError() << "expects all results type to be the same";
2807  }
2808 
2809  return success();
2810 }
2811 
2812 void transform::ContinuousTileSizesOp::getEffects(
2813  SmallVector<MemoryEffects::EffectInstance> &effects) {
2814  if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2815  onlyReadsPayload(effects);
2816  else
2817  modifiesPayload(effects);
2818  onlyReadsHandle(getTargetMutable(), effects);
2819  producesHandle(getOperation()->getOpResults(), effects);
2820 }
2821 
2822 static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2823  Type targetType, Type tile_sizes,
2824  Type) {
2825  printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
2826 }
2827 
2828 static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2829  Type &targetType,
2830  Type &tileSizesType,
2831  Type &chunkSizesType) {
2832  FunctionType funcType;
2833  llvm::SMLoc typeLoc = parser.getCurrentLocation();
2834  if (failed(parser.parseType<FunctionType>(funcType)))
2835  return failure();
2836 
2837  if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2838  parser.emitError(typeLoc) << "expects a trailing functional type with one "
2839  "argument and one result";
2840  }
2841  targetType = funcType.getInput(0);
2842  tileSizesType = chunkSizesType = funcType.getResult(0);
2843 
2844  return success();
2845 }
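// For reference (editorial sketch): the trailing functional type consumed by
// the parser above looks like
//   (!transform.any_op) -> !transform.param<i64>
// where the single result type is reused for both the tile_sizes and the
// chunk_sizes results.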
2846 
2847 //===----------------------------------------------------------------------===//
2848 // TileUsingForOp
2849 //===----------------------------------------------------------------------===//
2850 
2851 void transform::TileUsingForOp::build(
2852  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2853  Value target, ArrayRef<int64_t> staticTileSizes,
2854  ArrayRef<int64_t> interchange,
2855  std::optional<ArrayRef<bool>> scalableSizes) {
2856  return build(builder, result, loopTypes,
2857  /*target=*/target,
2858  /*mixedTileSizes=*/
2859  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2860  interchange, scalableSizes);
2861 }
2862 
2863 void transform::TileUsingForOp::build(
2864  OpBuilder &builder, OperationState &result, Value target,
2865  ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2866  std::optional<ArrayRef<bool>> scalableSizes) {
2867  build(builder, result, target,
2868  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2869  interchange, scalableSizes);
2870 }
2871 
2872 void transform::TileUsingForOp::build(
2873  OpBuilder &builder, OperationState &result, Value target,
2874  ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2875  std::optional<ArrayRef<bool>> scalableSizes) {
2876  // Loop types are automatically splat by the callee; setting up one is
2877  // enough.
2878  SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2879  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2880  scalableSizes);
2881 }
2882 
2883 void transform::TileUsingForOp::build(
2884  OpBuilder &builder, OperationState &result, TypeRange loopTypes,
2885  Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2886  ArrayRef<int64_t> interchange,
2887  std::optional<ArrayRef<bool>> scalableSizes) {
2888  SmallVector<int64_t> staticTileSizes;
2889  SmallVector<Value> dynamicTileSizes;
2890  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
2891  // Call the default builder which sets up the proper operands segment sizes
2892  // attributes for multiple variadic operands. In the absence of this,
2893  // horrible bugs ensue.
2894  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2895  unsigned numExpectedLoops =
2896  staticTileSizes.size() - llvm::count(staticTileSizes, 0);
2897  SmallVector<Type> resultTypes;
2898  resultTypes.reserve(numExpectedLoops);
2899  assert((loopTypes.size() == 1 || loopTypes.size() == numExpectedLoops) &&
2900  "expected one loop type or as many as loops");
2901  if (loopTypes.size() == 1)
2902  resultTypes.append(numExpectedLoops, loopTypes[0]);
2903  else
2904  llvm::append_range(resultTypes, loopTypes);
2905  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
2906  if (scalableSizes.has_value())
2907  expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
2908  build(builder, result, /*tiled_linalg_op=*/target.getType(),
2909  /*loops=*/resultTypes,
2910  /*target=*/target,
2911  /*dynamic_sizes=*/dynamicTileSizes,
2912  /*static_sizes=*/staticTileSizesAttr,
2913  /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
2914  /*scalable_sizes=*/expandedScalableSizes);
2915 }
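// Editorial note: a zero entry in the tile sizes means "do not tile that
// loop", so with staticTileSizes = {8, 0, 16} the builder above computes
// numExpectedLoops = 2 and either splats a single provided loop type or
// expects exactly two loop types.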
2916 
2917 LogicalResult transform::TileUsingForOp::verify() {
2918  if (getMixedSizes().size() != getScalableSizes().size())
2919  return emitOpError("expected same number of sizes (")
2920  << getMixedSizes().size() << ") and scalable sizes ("
2921  << getScalableSizes().size() << ")";
2922  ArrayRef<int64_t> staticSizes = getStaticSizes();
2923  unsigned numExpectedLoops = staticSizes.size() - llvm::count(staticSizes, 0);
2924  if (getLoops().size() != numExpectedLoops)
2925  return emitOpError("expected number of loops to tile (")
2926  << numExpectedLoops << ") to match number of `loops` results ("
2927  << getLoops().size() << ")";
2928  return success();
2929 }
2930 
2931 DiagnosedSilenceableFailure
2932 transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
2933  TransformResults &transformResults,
2934  TransformState &state) {
2935  ArrayRef<int64_t> tileSizes = getStaticSizes();
2936 
2937  SmallVector<Operation *> targets =
2938  llvm::to_vector(state.getPayloadOps(getTarget()));
2939  SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2940  SmallVector<SmallVector<int64_t>> paramSizes;
2941  dynamicSizeProducers.reserve(getDynamicSizes().size());
2942  paramSizes.reserve(getDynamicSizes().size());
2943  for (Value transformValue : getDynamicSizes()) {
2944  if (isa<ParamType>(transformValue.getType())) {
2945  dynamicSizeProducers.push_back({});
2946  ArrayRef<Attribute> params = state.getParams(transformValue);
2947  paramSizes.push_back(
2948  llvm::to_vector(llvm::map_range(params, [](Attribute attr) {
2949  return cast<IntegerAttr>(attr).getValue().getSExtValue();
2950  })));
2951 
2952  if (paramSizes.back().size() != targets.size()) {
2953  DiagnosedSilenceableFailure diag =
2954  emitSilenceableError()
2955  << "expected as many parameter values ("
2956  << dynamicSizeProducers.back().size() << ") as target ops ("
2957  << targets.size() << ")";
2958  diag.attachNote(transformValue.getLoc()) << "for this parameter";
2959  return diag;
2960  }
2961 
2962  continue;
2963  }
2964  paramSizes.push_back({});
2965  dynamicSizeProducers.push_back(
2966  llvm::to_vector(state.getPayloadOps(transformValue)));
2967 
2968  if (dynamicSizeProducers.back().size() != targets.size()) {
2969  DiagnosedSilenceableFailure diag =
2970  emitSilenceableError()
2971  << "expected as many dynamic size-producing operations ("
2972  << dynamicSizeProducers.back().size() << ") as target ops ("
2973  << targets.size() << ")";
2974  diag.attachNote(transformValue.getLoc()) << "for this handle";
2975  return diag;
2976  }
2977 
2978  for (Operation *op : dynamicSizeProducers.back()) {
2979  if (op->getNumResults() == 1 &&
2980  isa<IndexType>(op->getResult(0).getType())) {
2981  continue;
2982  }
2983 
2984  DiagnosedSilenceableFailure diag =
2985  emitSilenceableError() << "expected sizes to be produced by ops "
2986  "with a single index-type result";
2987  diag.attachNote(op->getLoc()) << "size producer op";
2988  diag.attachNote(transformValue.getLoc()) << "for this handle";
2989  return diag;
2990  }
2991  }
2992 
2993  SmallVector<Operation *> tiled;
2994  SmallVector<SmallVector<Operation *, 4>, 4> loops;
2995  loops.resize(getLoops().size());
2996  auto scalableSizes = getScalableSizes();
2997  for (auto [i, op] : llvm::enumerate(targets)) {
2998  auto tilingInterface = dyn_cast<TilingInterface>(op);
2999  if (!tilingInterface) {
3000  DiagnosedSilenceableFailure diag =
3001  emitSilenceableError()
3002  << "only ops implementing TilingInterface are supported";
3003  diag.attachNote(op->getLoc()) << "target op";
3004  return diag;
3005  }
3006  if (tileSizes.size() > tilingInterface.getLoopIteratorTypes().size()) {
3007  DiagnosedSilenceableFailure diag =
3008  emitSilenceableError()
3009  << "too many tiles provided, expected at most "
3010  << tilingInterface.getLoopIteratorTypes().size() << " found "
3011  << tileSizes.size();
3012  diag.attachNote(op->getLoc()) << "target op";
3013  return diag;
3014  }
3015 
3016  scf::SCFTilingOptions tilingOptions;
3017  if (tileSizes.empty()) {
3018  tilingOptions.setTileSizeComputationFunction(
3019  [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
3020  return {};
3021  });
3022  } else {
3023  tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
3024  Operation *) {
3025  SmallVector<OpFoldResult> sizes;
3026  sizes.reserve(tileSizes.size());
3027  unsigned dynamicIdx = 0;
3028 
3029  for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
3030  if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
3031  if (scalableSizes[ofrIdx]) {
3032  auto val = b.create<arith::ConstantIndexOp>(
3033  getLoc(), cast<IntegerAttr>(attr).getInt());
3034  Value vscale =
3035  b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
3036  sizes.push_back(
3037  b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
3038  } else {
3039  sizes.push_back(attr);
3040  }
3041  continue;
3042  }
3043  ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
3044  ArrayRef<int64_t> params = paramSizes[dynamicIdx];
3045  ++dynamicIdx;
3046  assert((dynamicSizes.empty() ^ params.empty()) &&
3047  "expected either dynamic sizes or parameters");
3048  if (!params.empty()) {
3049  sizes.push_back(b.getIndexAttr(params[index]));
3050  } else {
3051  sizes.push_back(dynamicSizes[index]->getResult(0));
3052  }
3053  }
3054  return sizes;
3055  });
3056  }
3057 
3058  tilingOptions.setInterchange(getInterchange());
3059  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3060  tileUsingSCF(rewriter, tilingInterface, tilingOptions);
3061  if (failed(maybeTilingResult))
3062  return DiagnosedSilenceableFailure::definiteFailure();
3063 
3064  rewriter.replaceOp(op, maybeTilingResult->replacements);
3065 
3066  tiled.append(maybeTilingResult->tiledOps);
3067  for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
3068  loops[en2.index()].push_back(en2.value());
3069  }
3070 
3071  transformResults.set(cast<OpResult>(getTiledLinalgOp()), tiled);
3072  for (const auto &en : llvm::enumerate(loops))
3073  transformResults.set(cast<OpResult>(getLoops()[en.index()]), en.value());
3074 
3075  return DiagnosedSilenceableFailure::success();
3076 }
3077 
3078 SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
3079  ValueRange dynamic = getDynamicSizes();
3080  ArrayRef<int64_t> tileSizes = getStaticSizes();
3081  SmallVector<OpFoldResult> results;
3082  results.reserve(tileSizes.size());
3083  unsigned dynamicPos = 0;
3084  Builder builder(getContext());
3085  for (int64_t size : tileSizes) {
3086  if (size == ShapedType::kDynamic) {
3087  results.push_back(dynamic[dynamicPos++]);
3088  } else {
3089  results.push_back(builder.getIndexAttr(size));
3090  }
3091  }
3092  return results;
3093 }
3094 
3095 void transform::TileUsingForOp::getEffects(
3096  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3097  consumesHandle(getTargetMutable(), effects);
3098  onlyReadsHandle(getDynamicSizesMutable(), effects);
3099  producesHandle(getOperation()->getOpResults(), effects);
3100  modifiesPayload(effects);
3101 }
3102 
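// Illustrative sketch (not part of the original file): a transform IR use of
// this op, assuming its documented assembly format; the handles and tile sizes
// below are hypothetical.
//
//   %tiled, %loop0, %loop1 = transform.structured.tile_using_for %matmul
//       tile_sizes [8, 16]
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
//
// A zero entry in the sizes skips tiling of that loop, which is why verify()
// above expects one `loops` result per non-zero static size.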
3103 //===----------------------------------------------------------------------===//
3104 // TileUsingForallOp
3105 //===----------------------------------------------------------------------===//
3106 
3107 void transform::TileUsingForallOp::build(OpBuilder &builder,
3108  OperationState &result, Value target,
3109  ArrayRef<int64_t> staticTileSizes,
3110  transform::TileSizesSpec,
3111  ArrayAttr mapping) {
3112  return build(builder, result,
3113  /*target=*/target,
3114  /*mixedTileSizes=*/
3115  getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
3116  /*_=*/TileSizesSpec(),
3117  /*mapping=*/mapping);
3118 }
3119 
3120 void transform::TileUsingForallOp::build(OpBuilder &builder,
3121  OperationState &result, Value target,
3122  ArrayRef<OpFoldResult> mixedTileSizes,
3123  transform::TileSizesSpec,
3124  ArrayAttr mapping) {
3125  SmallVector<int64_t> staticTileSizes;
3126  SmallVector<Value> dynamicTileSizes;
3127  dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
3128  // Call the default builder which sets up the proper operands segment sizes
3129  // attributes for multiple variadic operands. In the absence of this,
3130  // horrible bugs ensue.
3131  MLIRContext *ctx = builder.getContext();
3132  auto operationType = transform::AnyOpType::get(ctx);
3133  auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
3134  build(builder, result,
3135  /*resultTypes=*/TypeRange{operationType, operationType},
3136  /*target=*/target,
3137  /*num_threads=*/ValueRange{},
3138  /*tile_sizes=*/dynamicTileSizes,
3139  /*packed_num_threads=*/Value(),
3140  /*packed_tile_sizes=*/Value(),
3141  /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
3142  /*static_tile_sizes=*/staticTileSizesAttr,
3143  /*mapping=*/mapping);
3144 }
3145 
3146 void transform::TileUsingForallOp::build(OpBuilder &builder,
3147  OperationState &result, Value target,
3148  ArrayRef<int64_t> staticNumThreads,
3149  transform::NumThreadsSpec,
3150  ArrayAttr mapping) {
3151  return build(builder, result, target,
3152  getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
3153  NumThreadsSpec(), mapping);
3154 }
3155 
3156 void transform::TileUsingForallOp::build(OpBuilder &builder,
3157  OperationState &result, Value target,
3158  ArrayRef<OpFoldResult> mixedNumThreads,
3159  transform::NumThreadsSpec,
3160  ArrayAttr mapping) {
3161  SmallVector<int64_t> staticNumThreads;
3162  SmallVector<Value> dynamicNumThreads;
3163  dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
3164  staticNumThreads);
3165  // Call the default builder which sets up the proper operands segment sizes
3166  // attributes for multiple variadic operands. In the absence of this,
3167  // horrible bugs ensue.
3168  MLIRContext *ctx = builder.getContext();
3169  auto operationType = transform::AnyOpType::get(ctx);
3170  auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
3171  build(builder, result,
3172  /*resultTypes=*/TypeRange{operationType, operationType},
3173  /*target=*/target,
3174  /*num_threads=*/dynamicNumThreads,
3175  /*tile_sizes=*/ValueRange{},
3176  /*packed_num_threads=*/Value(),
3177  /*packed_tile_sizes=*/Value(),
3178  /*static_num_threads=*/staticNumThreadsAttr,
3179  /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
3180  /*mapping=*/mapping);
3181 }
3182 
3183 /// Given `lbs`, `ubs` and `steps` of loops, return (for each loop), the
3184 /// normalized upper bound.
3185 static SmallVector<OpFoldResult>
3186 normalizeUpperBounds(RewriterBase &rewriter, Location loc,
3187  ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
3188  ArrayRef<OpFoldResult> steps) {
3189  AffineExpr s0, s1, s2;
3190  bindSymbols(rewriter.getContext(), s0, s1, s2);
3191  AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
3192  SmallVector<OpFoldResult> normalizedUbs;
3193  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
3194  OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
3195  rewriter, loc, normalizedUbExpr, {lb, ub, step});
3196  normalizedUbs.push_back(normalizedUb);
3197  }
3198  return normalizedUbs;
3199 }
3200 
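// Worked instance of the formula above (not part of the original file): with
// lb = 2, ub = 10 and step = 3, the normalized upper bound is
// ceildiv(10 - 2, 3) = 3, i.e. the normalized loop runs iterations 0, 1, 2
// once its lower bound is set to 0 and its step to 1.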
3201 /// When a loop is normalized, the uses of the induction variable within the
3202 /// loop need to be replaced with `original_lb + old_iv * original_step`.
3203 static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
3204  Location loc, ValueRange ivs,
3205  ArrayRef<OpFoldResult> lbs,
3206  ArrayRef<OpFoldResult> steps) {
3207  AffineExpr s0, s1;
3208  AffineExpr d0;
3209  bindSymbols(rewriter.getContext(), s0, s1);
3210  bindDims(rewriter.getContext(), d0);
3211  AffineExpr denormExpr = s0 + d0 * s1;
3212  SmallVector<Value> denormalizedIvs;
3213 
3214  for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
3215  OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
3216  rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
3217  denormalizedIvs.push_back(
3218  getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
3219  }
3220  return denormalizedIvs;
3221 }
3222 
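// Worked instance of the formula above (not part of the original file): with
// original lb = 2 and step = 3, the normalized induction-variable values
// 0, 1, 2 map back to 2 + 0 * 3 = 2, 2 + 1 * 3 = 5 and 2 + 2 * 3 = 8,
// reproducing the values of the original induction variable.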
3223 /// Given a `scf.forall` loop return a loop op with the loop bounds
3224 /// normalized.
3225 /// TODO: Replace this with a general utility to normalize `scf.forall`.
3226 /// At the time of writing, this wasn't done since adding this to the `scf`
3227 /// dialect would disallow use of `affine.apply` operations due
3228 /// to cyclic dependencies. To avoid churn in lit tests
3229 /// with the change this was added with, defer that to a follow up.
3230 static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
3231  scf::ForallOp loop) {
3232  SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
3233  SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
3234  SmallVector<OpFoldResult> steps = loop.getMixedStep();
3235 
3236  if (llvm::all_of(
3237  lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
3238  llvm::all_of(
3239  steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
3240  return loop;
3241  }
3242 
3243  Location loc = loop.getLoc();
3244  SmallVector<OpFoldResult> normalizedUbs =
3245  normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
3246  SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
3247  rewriter.getIndexAttr(0));
3248  SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
3249  rewriter.getIndexAttr(1));
3250 
3251  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
3252  loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
3253  loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
3254 
3255  auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
3256  OpBuilder::InsertionGuard g(rewriter);
3257  Block *normalizedLoopBlock = normalizedForallOp.getBody();
3258  rewriter.setInsertionPointToStart(normalizedLoopBlock);
3259 
3260  SmallVector<Value> argValues =
3261  denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
3262  argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
3263  normalizedForallOp.getRegionIterArgs().end());
3264  Block *origLoopBlock = loop.getBody();
3265  rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
3266 
3267  rewriter.replaceOp(loop, normalizedForallOp);
3268  return normalizedForallOp;
3269 }
3270 
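// Illustrative sketch (not part of the original file) of the rewrite performed
// by normalizeForallLoopOp; SSA names are hypothetical and the affine.apply is
// shown only schematically:
//
//   scf.forall (%i) = (%lb) to (%ub) step (%s) { ... uses of %i ... }
//
// becomes, roughly,
//
//   scf.forall (%j) in (ceildiv(%ub - %lb, %s)) {
//     // %i is rebuilt as %lb + %j * %s via an affine.apply (denormalizeIndVar).
//     ... uses of %i ...
//   }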
3271 DiagnosedSilenceableFailure transform::tileToForallOpImpl(
3272  RewriterBase &rewriter, transform::TransformState &state,
3273  TransformOpInterface transformOp, Operation *target,
3274  ArrayRef<OpFoldResult> mixedNumThreads,
3275  ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3276  scf::SCFTilingResult &tilingResult) {
3277  // Transform all targets one by one.
3278  auto tileableOp = dyn_cast<TilingInterface>(target);
3279  if (!tileableOp) {
3280  DiagnosedSilenceableFailure diag =
3281  transformOp.emitSilenceableError()
3282  << "only TilingInterface ops are supported";
3283  diag.attachNote(target->getLoc()) << "target op";
3284  return diag;
3285  }
3286  rewriter.setInsertionPoint(tileableOp);
3287  scf::SCFTilingOptions options;
3288  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3289  if (!mixedNumThreads.empty()) {
3290  options.setNumThreads(mixedNumThreads);
3291  } else {
3292  options.setTileSizes(mixedTileSizes);
3293  }
3294  if (mapping) {
3295  options.setMapping(mapping.value().getValue());
3296  }
3297  FailureOr<scf::SCFTilingResult> maybeTilingResult =
3298  scf::tileUsingSCF(rewriter, tileableOp, options);
3299 
3300  if (failed(maybeTilingResult))
3301  return transformOp.emitDefaultSilenceableFailure(tileableOp);
3302 
3303  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3304 
3305  tilingResult = *maybeTilingResult;
3306 
3307  if (mixedNumThreads.empty()) {
3308  auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
3309  OpBuilder::InsertionGuard g(rewriter);
3310  rewriter.setInsertionPoint(generatedForallOp);
3311  scf::ForallOp normalizedForallOp =
3312  normalizeForallLoopOp(rewriter, generatedForallOp);
3313  tilingResult.loops.front() = normalizedForallOp;
3314  }
3315 
3316  return DiagnosedSilenceableFailure::success();
3317 }
3318 
3319 DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
3320  transform::TransformRewriter &rewriter,
3321  transform::TransformResults &transformResults,
3322  transform::TransformState &state) {
3323  auto transformOp = cast<TransformOpInterface>(getOperation());
3324 
3325  // Result payload ops.
3326  SmallVector<Operation *> tileOps;
3327  SmallVector<Operation *> tiledOps;
3328 
3329  // Unpack handles.
3330  SmallVector<OpFoldResult> mixedNumThreads;
3331  DiagnosedSilenceableFailure status =
3332  getPackedNumThreads()
3333  ? unpackSingleIndexResultPayloadOperations(
3334  state, transformOp, mixedNumThreads, getPackedNumThreads())
3335  : unpackSingleIndexResultPayloadOperations(
3336  state, transformOp, mixedNumThreads, getMixedNumThreads());
3337  if (!status.succeeded())
3338  return status;
3339  SmallVector<OpFoldResult> mixedTileSizes;
3340  status = getPackedTileSizes()
3341  ? unpackSingleIndexResultPayloadOperations(
3342  state, transformOp, mixedTileSizes, getPackedTileSizes())
3343  : unpackSingleIndexResultPayloadOperations(
3344  state, transformOp, mixedTileSizes, getMixedTileSizes());
3345  if (!status.succeeded())
3346  return status;
3347 
3348  for (Operation *target : state.getPayloadOps(getTarget())) {
3349  scf::SCFTilingResult tilingResult;
3350  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3351  rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
3352  getMapping(), tilingResult);
3353  if (!diag.succeeded())
3354  return diag;
3355  tileOps.push_back(tilingResult.loops.front());
3356  tiledOps.append(tilingResult.tiledOps);
3357  }
3358 
3359  transformResults.set(cast<OpResult>(getForallOp()), tileOps);
3360  transformResults.set(cast<OpResult>(getTiledOp()), tiledOps);
3361 
3362  return DiagnosedSilenceableFailure::success();
3363 }
3364 
3365 void transform::TileUsingForallOp::getEffects(
3366  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3367  consumesHandle(getTargetMutable(), effects);
3368  onlyReadsHandle(getTileSizesMutable(), effects);
3369  onlyReadsHandle(getNumThreadsMutable(), effects);
3370  onlyReadsHandle(getPackedNumThreadsMutable(), effects);
3371  onlyReadsHandle(getPackedTileSizesMutable(), effects);
3372  producesHandle(getOperation()->getOpResults(), effects);
3373  modifiesPayload(effects);
3374 }
3375 
3376 SmallVector<OpFoldResult> TileUsingForallOp::getMixedNumThreads() {
3377  Builder b(getContext());
3378  return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
3379 }
3380 
3381 SmallVector<OpFoldResult> TileUsingForallOp::getMixedTileSizes() {
3382  Builder b(getContext());
3383  return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
3384 }
3385 
3386 LogicalResult TileUsingForallOp::verify() {
3387  int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
3388  static_cast<int>(getPackedNumThreads() != Value());
3389  if (numThreadsSpec > 1)
3390  return emitOpError(
3391  "num_threads and packed_num_threads are mutually exclusive");
3392  int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
3393  static_cast<int>(getPackedTileSizes() != Value());
3394  if (tileSizesSpec > 1)
3395  return emitOpError(
3396  "tile_sizes and packed_tile_sizes are mutually exclusive");
3397  if (numThreadsSpec == 0 && tileSizesSpec == 0)
3398  return emitOpError("either (packed_)num_threads or (packed_)tile_sizes "
3399  "must be specified");
3400  return success();
3401 }
3402 
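// Illustrative sketch (not part of the original file): a transform IR use of
// this op, assuming its documented assembly format; the handle, sizes and
// mapping attributes are hypothetical.
//
//   %0:2 = transform.structured.tile_using_forall %matmul tile_sizes [32, 32]
//       (mapping = [#gpu.block<y>, #gpu.block<x>])
//       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
//
// As verify() above enforces, exactly one of (packed_)num_threads or
// (packed_)tile_sizes must be provided.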
3403 //===----------------------------------------------------------------------===//
3404 // VectorizeChildrenAndApplyPatternsOp
3405 //===----------------------------------------------------------------------===//
3406 
3407 void transform::VectorizeChildrenAndApplyPatternsOp::build(
3408  OpBuilder &builder, OperationState &result, Value target,
3409  bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
3410  result.addOperands(target);
3411  if (vectorizePadding) {
3412  result.addAttribute(
3413  VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
3414  result.name),
3415  builder.getUnitAttr());
3416  }
3417  if (vectorizeExtract) {
3418  result.addAttribute(
3419  VectorizeChildrenAndApplyPatternsOp::getVectorizeNdExtractAttrName(
3420  result.name),
3421  builder.getUnitAttr());
3422  }
3423  if (flatten1DDepthwiseConv) {
3424  result.addAttribute(
3425  VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
3426  result.name),
3427  builder.getUnitAttr());
3428  }
3429  result.addTypes(transform::AnyOpType::get(builder.getContext()));
3430 }
3431 
3432 namespace {
3433 /// This is a helper only to call vectorize via a pattern inside of
3434 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
3435 struct VectorizationPattern : public RewritePattern {
3436  explicit VectorizationPattern(MLIRContext *context,
3437  bool vectorizeExtract = false,
3438  bool flattenConv = false)
3439  : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
3440  vectorizeNDExtract(vectorizeExtract),
3441  flatten1DDepthwiseConv(flattenConv) {}
3442  LogicalResult matchAndRewrite(Operation *op,
3443  PatternRewriter &rewriter) const override {
3444  if (!linalg::hasVectorizationImpl(op))
3445  return rewriter.notifyMatchFailure(op,
3446  "Unsupported Op, cannot vectorize");
3447  return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3448  /*inputScalableVecDims=*/{}, vectorizeNDExtract,
3449  flatten1DDepthwiseConv);
3450  }
3451 
3452 private:
3453  /// Controls whether to vectorize `tensor.extract` when the input tensor is
3454  /// rank >= 2.
3455  bool vectorizeNDExtract = false;
3456  /// Controls whether to "flatten" the channel dimension when vectorising 1D
3457  /// depthwise convolutions. This should lead to better vectorization for
3458  /// tensors with a low number of channel dimensions.
3459  bool flatten1DDepthwiseConv = false;
3460 };
3461 } // namespace
3462 
3463 DiagnosedSilenceableFailure
3464 transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
3467  transform::TransformState &state) {
3468  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
3469  auto diag = this->emitOpError("requires isolated-from-above targets");
3470  diag.attachNote(target->getLoc()) << "non-isolated target";
3471  return DiagnosedSilenceableFailure::definiteFailure();
3472  }
3473 
3474  MLIRContext *ctx = getContext();
3475  RewritePatternSet patterns(ctx);
3476  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
3477  getFlatten_1dDepthwiseConv());
3478 
3479  if (!getDisableTransferPermutationMapLoweringPatterns())
3480  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
3481 
3482  if (!getDisableMultiReductionToContractPatterns())
3483  vector::populateVectorReductionToContractPatterns(patterns);
3484 
3486 
3487  patterns.add<linalg::LinalgCopyVTRForwardingPattern,
3488  linalg::LinalgCopyVTWForwardingPattern>(ctx,
3489  /*benefit=*/2);
3490  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
3491  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
3492  tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
3493 
3494  patterns.add<CopyVectorizationPattern>(ctx);
3495 
3496  // Add misc. vectorization patterns (e.g. for tensor.insert_slice)
3497  linalg::populateInsertSliceVectorizationPatterns(patterns);
3498 
3499  if (getVectorizePadding()) {
3500  linalg::populatePadOpVectorizationPatterns(patterns);
3501  // This creates an alternative path for lowering tensor.pad - by
3502  // decomposing it into e.g. linalg.fill.
3503  linalg::populateDecomposePadPatterns(patterns);
3504  }
3506 
3507  TrackingListener listener(state, *this);
3508  GreedyRewriteConfig config;
3509  config.listener = &listener;
3510  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
3511  return emitDefaultDefiniteFailure(target);
3512 
3513  results.push_back(target);
3514  return DiagnosedSilenceableFailure::success();
3515 }
3516 
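// Illustrative sketch (not part of the original file): a transform IR use of
// this op, assuming its documented assembly format; the handle is hypothetical.
//
//   %vectorized = transform.structured.vectorize_children_and_apply_patterns
//       %func {vectorize_padding} : (!transform.any_op) -> !transform.any_op
//
// The target must be isolated from above (e.g. a func.func), as checked at the
// top of applyToOne above.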
3517 //===----------------------------------------------------------------------===//
3518 // VectorizeOp
3519 //===----------------------------------------------------------------------===//
3520 
3521 DiagnosedSilenceableFailure transform::VectorizeOp::apply(
3522  transform::TransformRewriter &rewriter,
3523  mlir::transform::TransformResults &transformResults,
3524  mlir::transform::TransformState &state) {
3525  auto targets = state.getPayloadOps(getTarget());
3526  if (std::empty(targets))
3527  return DiagnosedSilenceableFailure::success();
3528  auto transformOp = cast<TransformOpInterface>(getOperation());
3529  SmallVector<int64_t> vectorSizes;
3530  DiagnosedSilenceableFailure status = reifyMixedParamAndHandleResults(
3531  state, transformOp, getMixedVectorSizes(), vectorSizes);
3532  if (!status.succeeded())
3533  return status;
3534 
3535  // TODO: Check that the correct number of vectorSizes was provided.
3536  for (Operation *target : targets) {
3537  if (!linalg::hasVectorizationImpl(target)) {
3538  return mlir::emitSilenceableFailure(target->getLoc())
3539  << "Unsupported Op, cannot vectorize";
3540  }
3541 
3542  if (failed(linalg::vectorize(rewriter, target, vectorSizes,
3543  getScalableSizes(),
3544  getVectorizeNdExtract().value_or(false)))) {
3545  return mlir::emitSilenceableFailure(target->getLoc())
3546  << "Attempted to vectorize, but failed";
3547  }
3548  }
3549 
3550  return DiagnosedSilenceableFailure::success();
3551 }
3552 
3553 void transform::VectorizeOp::getEffects(
3554  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3555  consumesHandle(getTargetMutable(), effects);
3556  onlyReadsHandle(getVectorSizesMutable(), effects);
3557  modifiesPayload(effects);
3558 }
3559 
3560 SmallVector<OpFoldResult> VectorizeOp::getMixedVectorSizes() {
3561  OpBuilder b(getContext());
3562  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
3563 }
3564 
3565 LogicalResult transform::VectorizeOp::verify() {
3566  if (getStaticVectorSizes().size() != getScalableSizes().size())
3567  return emitOpError("expected same number of vector sizes (")
3568  << getStaticVectorSizes().size() << ") and scalable sizes ("
3569  << getScalableSizes().size() << ")";
3570  return success();
3571 }
3572 
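// Illustrative sketch (not part of the original file): a transform IR use of
// this op, assuming its documented assembly format; the handle and sizes are
// hypothetical. A bracketed size marks that dimension as scalable, matching
// the scalable_sizes checked in verify() above.
//
//   transform.structured.vectorize %matmul vector_sizes [8, [16]]
//       : !transform.any_op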
3573 //===----------------------------------------------------------------------===//
3574 // HoistRedundantVectorTransfersOp
3575 //===----------------------------------------------------------------------===//
3576 
3577 DiagnosedSilenceableFailure
3578 transform::HoistRedundantVectorTransfersOp::applyToOne(
3579  transform::TransformRewriter &rewriter, func::FuncOp target,
3580  transform::ApplyToEachResultList &results,
3581  transform::TransformState &state) {
3582  // WARNING: This hoisting does not model parallelism and is generally
3583  // incorrect when used on distributed loops with memref semantics!
3584  // TODO: obsolete and should be retired.
3585  linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
3586  results.push_back(target);
3587  return DiagnosedSilenceableFailure::success();
3588 }
3589 
3590 //===----------------------------------------------------------------------===//
3591 // HoistRedundantVectorBroadcastsOp
3592 //===----------------------------------------------------------------------===//
3593 
3594 DiagnosedSilenceableFailure
3595 transform::HoistRedundantVectorBroadcastsOp::applyToOne(
3596  transform::TransformRewriter &rewriter, mlir::Operation *target,
3597  transform::ApplyToEachResultList &results,
3598  transform::TransformState &state) {
3599  rewriter.setInsertionPoint(target);
3600  linalg::hoistRedundantVectorBroadcasts(rewriter, target);
3601  results.push_back(target);
3602  return DiagnosedSilenceableFailure::success();
3603 }
3604 
3605 //===----------------------------------------------------------------------===//
3606 // ConvertConv2DToImg2ColOp.
3607 //===----------------------------------------------------------------------===//
3608 
3609 DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
3610  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3611  transform::ApplyToEachResultList &results,
3612  transform::TransformState &state) {
3613  rewriter.setInsertionPoint(target);
3614  auto maybeTransformed =
3615  TypeSwitch<Operation *, FailureOr<std::pair<Operation *, Operation *>>>(
3616  target)
3617  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
3618  return rewriteInIm2Col(rewriter, op);
3619  })
3620  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3621  return rewriteInIm2Col(rewriter, op);
3622  })
3623  .Case([&](linalg::DepthwiseConv2DNhwcHwcOp op) {
3624  return rewriteInIm2Col(rewriter, op);
3625  })
3626  .Case([&](linalg::Conv2DNchwFchwOp op) {
3627  return rewriteInIm2Col(rewriter, op);
3628  })
3629  .Default([&](Operation *op) {
3630  return rewriter.notifyMatchFailure(op, "not supported");
3631  });
3632  if (failed(maybeTransformed))
3633  return emitDefaultSilenceableFailure(target);
3634  // Handle to the operation producing the img2col tensor.
3635  results.push_back(maybeTransformed->first);
3636  // Handle to the operation that replaces the original convolution.
3637  results.push_back(maybeTransformed->second);
3638  return DiagnosedSilenceableFailure::success();
3639 }
3640 
3641 //===----------------------------------------------------------------------===//
3642 // FlattenElementwiseLinalgOp.
3643 //===----------------------------------------------------------------------===//
3644 
3645 DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
3646  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3647  transform::ApplyToEachResultList &results,
3648  transform::TransformState &state) {
3649  rewriter.setInsertionPoint(target);
3650  if (!isElementwise(target))
3651  return mlir::emitSilenceableFailure(target->getLoc())
3652  << "only elementwise flattening is supported";
3653 
3654  // If rank <= 1, do nothing
3655  if (target.getNumLoops() <= 1) {
3656  results.push_back(target);
3657  return DiagnosedSilenceableFailure::success();
3658  }
3659 
3660  // Attempt to flatten all dims to one.
3661  ReassociationIndices reassociation(target.getNumLoops());
3662  std::iota(reassociation.begin(), reassociation.end(), 0);
3663  auto maybeFlattened =
3664  collapseOpIterationDims(target, reassociation, rewriter);
3665  if (failed(maybeFlattened))
3666  return mlir::emitSilenceableFailure(target->getLoc())
3667  << "attempted to flatten, but failed";
3668  results.push_back(maybeFlattened->collapsedOp);
3669  rewriter.replaceOp(target, maybeFlattened->results);
3670  return DiagnosedSilenceableFailure::success();
3671 }
3672 
3673 //===----------------------------------------------------------------------===//
3674 // TransposeConv2DOp
3675 //===----------------------------------------------------------------------===//
3676 
3677 DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
3678  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3679  transform::ApplyToEachResultList &results,
3680  transform::TransformState &state) {
3681  rewriter.setInsertionPoint(target);
3682  auto maybeTransformed =
3683  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3684  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3685  return transposeConv2D(rewriter, op);
3686  })
3687  .Case([&](linalg::Conv2DNhwcFhwcQOp op) {
3688  return transposeConv2D(rewriter, op);
3689  })
3690  .Default([&](Operation *op) {
3691  return rewriter.notifyMatchFailure(op, "not supported");
3692  });
3693  if (failed(maybeTransformed))
3694  return emitDefaultSilenceableFailure(target);
3695  // Handle to the new Conv2D operation with transposed filters
3696  results.push_back(*maybeTransformed);
3697  return DiagnosedSilenceableFailure::success();
3698 }
3699 
3700 //===----------------------------------------------------------------------===//
3701 // TransposeMatmulOp
3702 //===----------------------------------------------------------------------===//
3703 
3704 DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
3705  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3706  transform::ApplyToEachResultList &results,
3707  transform::TransformState &state) {
3708  rewriter.setInsertionPoint(target);
3709  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
3710  auto maybeTransformed =
3711  TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3712  .Case([&](linalg::MatmulOp op) {
3713  return transposeMatmul(rewriter, op, transposeLHS);
3714  })
3715  .Case([&](linalg::BatchMatmulOp op) {
3716  return transposeBatchMatmul(rewriter, op, transposeLHS);
3717  })
3718  .Default([&](Operation *op) { return failure(); });
3719  if (failed(maybeTransformed))
3720  return emitSilenceableFailure(target->getLoc()) << "not supported";
3721  // Handle to the new Matmul operation with transposed filters
3722  results.push_back(*maybeTransformed);
3723  return DiagnosedSilenceableFailure::success();
3724 }
3725 
3726 //===----------------------------------------------------------------------===//
3727 // InsertSliceToCopyOp
3728 //===----------------------------------------------------------------------===//
3729 template <typename OpTy>
3730 static DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
3731  transform::ApplyToEachResultList &results,
3732  transform::TransformState &state) {
3733  static_assert(llvm::is_one_of<OpTy, tensor::InsertSliceOp,
3734  tensor::ParallelInsertSliceOp>() &&
3735  "wrong op type");
3736 
3737  if (auto copySource =
3738  target.getSource().template getDefiningOp<linalg::CopyOp>()) {
3739  results.push_back(copySource);
3740  return DiagnosedSilenceableFailure::success();
3741  }
3742 
3743  // If we are inside an InParallel region, temporarily set the insertion point
3744  // outside: only tensor.parallel_insert_slice ops are allowed in there.
3745  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
3746  rewriter.setInsertionPoint(
3747  target->template getParentOfType<scf::InParallelOp>());
3748  }
3749 
3750  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
3751  target.getLoc(), target.getDest(), target.getMixedOffsets(),
3752  target.getMixedSizes(), target.getMixedStrides());
3753  Value copied = rewriter
3754  .create<linalg::CopyOp>(target.getLoc(),
3755  target.getSource(), extracted)
3756  .getResult(0);
3757  // Reset the insertion point.
3758  rewriter.setInsertionPoint(target);
3759  rewriter.replaceOpWithNewOp<OpTy>(
3760  target, copied, target.getDest(), target.getMixedOffsets(),
3761  target.getMixedSizes(), target.getMixedStrides());
3762 
3763  results.push_back(copied.getDefiningOp());
3764  return DiagnosedSilenceableFailure::success();
3765 }
3766 
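// Illustrative sketch (not part of the original file) of the rewrite performed
// by doit above for the tensor.insert_slice case; SSA names, offsets, sizes and
// strides are hypothetical and types are omitted:
//
//   %r = tensor.insert_slice %src into %dest[0, 0] [4, 4] [1, 1]
//
// becomes, roughly,
//
//   %slice = tensor.extract_slice %dest[0, 0] [4, 4] [1, 1]
//   %copy  = linalg.copy ins(%src) outs(%slice)
//   %r     = tensor.insert_slice %copy into %dest[0, 0] [4, 4] [1, 1]
//
// and the returned handle points at the linalg.copy.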
3767 DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
3768  transform::TransformRewriter &rewriter, Operation *targetOp,
3769  transform::ApplyToEachResultList &results,
3770  transform::TransformState &state) {
3771 
3772  rewriter.setInsertionPoint(targetOp);
3773  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
3774  return doit(rewriter, target, results, state);
3775  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
3776  return doit(rewriter, target, results, state);
3777 
3778  DiagnosedSilenceableFailure diag =
3779  emitSilenceableError()
3780  << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
3781  diag.attachNote(targetOp->getLoc()) << "target op";
3782  return diag;
3783 }
3784 
3785 //===----------------------------------------------------------------------===//
3786 // MapCopyToThreadsOp
3787 //===----------------------------------------------------------------------===//
3788 
3789 DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
3790  transform::TransformRewriter &rewriter, Operation *target,
3791  transform::ApplyToEachResultList &results,
3792  transform::TransformState &state) {
3793  // Check if the op is supported.
3794  if (!isa<linalg::CopyOp, tensor::PadOp>(target)) {
3795  DiagnosedSilenceableFailure diag =
3796  emitSilenceableError()
3797  << "only linalg.copy and tensor.pad target ops are supported";
3798  diag.attachNote(target->getLoc()) << "target op";
3799  return diag;
3800  }
3801  assert(target->getNumResults() == 1 && "expected single result");
3802  auto resultShapedType = cast<ShapedType>(target->getResult(0).getType());
3803  if (!resultShapedType.hasStaticShape()) {
3804  DiagnosedSilenceableFailure diag =
3805  emitSilenceableError()
3806  << "only statically sized ops of rank <= 3 are supported";
3807  diag.attachNote(target->getLoc()) << "target op";
3808  return diag;
3809  }
3810 
3811  // Conservatively set the minimum viable desired bitwidth alignment.
3812  int64_t desiredBitAlignment = getDesiredBitAlignment();
3813  int64_t eltBitwidth =
3814  resultShapedType.getElementType().getIntOrFloatBitWidth();
3815  if (desiredBitAlignment % eltBitwidth != 0) {
3816  desiredBitAlignment = eltBitwidth;
3817  }
3818 
3819  gpu::CopyMappingInfo mapping(
3820  /*ctx=*/getContext(),
3821  /*totalNumThreads=*/getTotalNumThreads(),
3822  /*alignment=*/desiredBitAlignment,
3823  /*sizes=*/resultShapedType.getShape(),
3824  /*favorPredication=*/false,
3825  /*elementalBitwidth=*/
3826  resultShapedType.getElementType().getIntOrFloatBitWidth());
3827  if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) {
3828  DiagnosedSilenceableFailure diag =
3829  emitSilenceableError()
3830  << "too few threads to map copy op to threads on the most minor "
3831  "dimension, given alignment and vector size constraints, try "
3832  "smaller tile size of mapping to more threads";
3833  diag.attachNote(target->getLoc()) << "target op";
3834  return diag;
3835  }
3836 
3837  // OpBuilder only used to compute attributes.
3838  OpBuilder b(getContext());
3839  scf::SCFTilingResult tilingResult;
3840  DiagnosedSilenceableFailure diag = tileToForallOpImpl(
3841  /*rewriter=*/rewriter,
3842  /*state=*/state,
3843  /*transformOp=*/*this,
3844  /*target=*/target,
3845  /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b),
3846  /*mixedTileSizes=*/ArrayRef<OpFoldResult>{},
3847  /*mapping=*/b.getArrayAttr(mapping.threadMapping),
3848  /*tilingResult=*/tilingResult);
3849  if (!diag.succeeded())
3850  return diag;
3851 
3852  results.push_back(tilingResult.loops.front());
3853  for (auto op : tilingResult.tiledOps)
3854  results.push_back(op);
3855  return DiagnosedSilenceableFailure::success();
3856 }
3857 
3858 //===----------------------------------------------------------------------===//
3859 // WinogradConv2DOp
3860 //===----------------------------------------------------------------------===//
3861 
3862 DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
3863  transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3864  transform::ApplyToEachResultList &results,
3865  transform::TransformState &state) {
3866  rewriter.setInsertionPoint(target);
3867  FailureOr<Operation *> maybeTransformed = failure();
3868  bool supported = TypeSwitch<Operation *, bool>(target)
3869  .Case([&](linalg::Conv2DNhwcFhwcOp op) {
3870  maybeTransformed =
3871  winogradConv2D(rewriter, op, getM(), getR());
3872  return true;
3873  })
3874  .Default([&](Operation *op) { return false; });
3875 
3876  if (!supported) {
3877  return emitSilenceableError()
3878  << "this operation is not supported to convert to Winograd Conv2D";
3879  }
3880 
3881  if (supported && failed(maybeTransformed)) {
3882  return emitSilenceableError() << "apply Winograd Conv2D failed";
3883  }
3884 
3885  results.push_back(*maybeTransformed);
3886  return DiagnosedSilenceableFailure::success();
3887 }
3888 
3889 DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne(
3890  transform::TransformRewriter &rewriter, Operation *target,
3891  transform::ApplyToEachResultList &results,
3892  transform::TransformState &state) {
3893  rewriter.setInsertionPoint(target);
3894  FailureOr<Operation *> maybeTransformed = failure();
3895  bool supported =
3896  TypeSwitch<Operation *, bool>(target)
3897  .Case([&](linalg::WinogradFilterTransformOp op) {
3898  maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op);
3899  return true;
3900  })
3901  .Case([&](linalg::WinogradInputTransformOp op) {
3902  maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op);
3903  return true;
3904  })
3905  .Case([&](linalg::WinogradOutputTransformOp op) {
3906  maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op);
3907  return true;
3908  })
3909  .Default([&](Operation *op) { return false; });
3910 
3911  if (!supported) {
3912  DiagnosedSilenceableFailure diag =
3913  emitSilenceableError()
3914  << "this operation is not supported to decompose into other operations";
3915  diag.attachNote(target->getLoc()) << "target op";
3916  return diag;
3917  }
3918 
3919  if (supported && failed(maybeTransformed)) {
3920  DiagnosedSilenceableFailure diag =
3921  emitSilenceableError() << "decompose Winograd operations failed";
3922  diag.attachNote(target->getLoc()) << "target op";
3923  return diag;
3924  }
3925 
3926  results.push_back(*maybeTransformed);
3927  return DiagnosedSilenceableFailure::success();
3928 }
3929 
3930 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
3931 
3932 #define GET_OP_CLASSES
3933 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, transform::ApplyToEachResultList &results, transform::TransformState &state)
static Operation * cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
#define DOWNSCALE(trans)
bool isValidPackingPermutation(RelayoutOpTy op, ArrayRef< int64_t > permutation, OuterOrInnerPerm outerOrInnerPerm=OuterOrInnerPerm::Outer)
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(TransformState &state, TransformOpInterface &transformOp, ArrayRef< OpFoldResult > mixedResults, SmallVectorImpl< int64_t > &reified)
When possible, converts each OpFoldResult in mixedResult to an integer if the value can be statically...
static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(transform::TransformState &state, TransformOpInterface transformOp, SmallVector< OpFoldResult > &result, ArrayRef< OpFoldResult > ofrs)
Assuming that ofr is an index attr or a param of index type or a transform dialect handle mapped to e...
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type tile_sizes, Type)
static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, scf::ForallOp loop)
Given a scf.forall loop return a loop op with the loop bounds normalized.
static SmallVector< Value > denormalizeIndVar(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > steps)
When a loop is normalized, the uses of the induction variable within the loop need to replaced with o...
#define DOWNSCALE_NORMAL(a, b)
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
static void printMultitileSizesTypes(OpAsmPrinter &printer, Operation *op, Type targetType, Type lowSizeType, Type, Type)
static Operation * replaceForAllWithNewSignature(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp, TilingResult &tileAndFuseResult, int64_t resultNumber, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Add new operands to the forall op for users of the producerOp that are dominated by the containing sc...
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser, Type &targetType, Type &tileSizesType, Type &chunkSizesType)
static SmallVector< Operation * > tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
First, find the first "scf::ForallOp" user of producerOp and ensure it is exactly the containingOp,...
static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, Type &targetType, Type &lowSizeType, Type &highSizeType, Type &splitPointType)
static SmallVector< OpFoldResult > normalizeUpperBounds(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > steps)
Given lbs, ubs and steps of loops, return (for each loop), the normalized upper bound.
#define DBGS()
static LogicalResult applyTilingToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< scf::SCFTileAndFuseResult >(TilingInterface)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
static std::tuple< SmallVector< Operation * >, Operation * > tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp)
Find the first "extract" user of producerOp and tile it right before its use.
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:408
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:100
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:321
IndexType getIndexType()
Definition: Builders.cpp:95
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:346
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
A class for computing basic dominance information.
Definition: Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition: Dominance.h:160
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
This class represents a saved insertion point.
Definition: Builders.h:336
bool isSet() const
Returns true if this insert point is set.
Definition: Builders.h:346
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:325
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:329
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getOpResult(unsigned idx)
Definition: Operation.h:416
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:555
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
State for analysis-enabled bufferization.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void assign(unsigned size, std::nullptr_t)
Sets the list of results to size null pointers.
void reserve(unsigned size)
Reserves space for size elements in the list.
size_t size() const
Returns the number of elements in the list.
void push_back(Operation *op)
Appends an element to the list.
A listener that updates a TransformState based on IR modifications.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
LogicalResult notifyPayloadOperationReplaced(Operation *op, Operation *replacement)
Notify the transform dialect interpreter that the given op has been replaced with another op and that...
The state maintained across applications of various ops implementing the TransformOpInterface.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1241
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1144
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< PackingResult > buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, scf::ForOp outermostEnclosingForOp, ArrayRef< int64_t > transposeVector)
Build the packing loop nest required to hoist opToHoist above outermostEnclosingForOp.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
Definition: Padding.cpp:153
FailureOr< std::pair< Operation *, Operation * > > rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp)
Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) and linalg....
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite pack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:354
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
std::optional< Value > allocateWorkgroupMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU workgroup memory.
Definition: Promotion.cpp:470
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< TransposeOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:260
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns)
Definition: Tiling.cpp:860
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value)
In case of GPU private memory there is no need to deallocate since the memory is freed when going out...
Definition: Promotion.cpp:511
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
std::optional< Value > allocateGPUPrivateMemory(OpBuilder &builder, memref::SubViewOp subview, ArrayRef< Value > sizeBounds, DataLayout &)
Allocate the subview in the GPU private memory.
Definition: Promotion.cpp:495
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > transposeConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op)
Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by materializing transpose.
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, ControlDropUnitDims &options)
Patterns to fold unit-extent dimensions in operands/results of linalg ops on tensors via reassociativ...
LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst)
Create Memref copy operations and add gpu barrier guards before and after the copy operation to ensur...
Definition: Promotion.cpp:486
LogicalResult linalgOpAnchoredEmptyTensorEliminationStep(RewriterBase &rewriter, Operation *op, bufferization::OneShotAnalysisState &state)
Try to eliminate tensor::EmptyOps inside op that are anchored on a LinalgOp.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Emit a suitable vector form for an operation.
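A minimal usage sketch (not part of this file) of the two vectorization entry points listed above; the helper name vectorizeIfPossible and the variables rewriter and candidate are assumptions for illustration.

// Hedged sketch: guard with hasVectorizationImpl, then call vectorize with its
// default (empty) vector sizes, as the signatures above allow.
static LogicalResult vectorizeIfPossible(RewriterBase &rewriter,
                                         Operation *candidate) {
  if (!linalg::hasVectorizationImpl(candidate))
    return failure();
  return linalg::vectorize(rewriter, candidate);
}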
FailureOr< Operation * > transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS=true)
Pattern to replace linalg.batch_matmul with a variant that operates on a transposed LHS or RHS operand, materializing the required transpose.
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:399
LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst)
Normal copy between src and dst.
Definition: Promotion.cpp:503
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and indexing_maps dimensions and adapt the index accesses of op.
Definition: Interchange.cpp:50
FailureOr< StaticMultiSizeSpecification > computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, int64_t divisor)
Definition: Tiling.cpp:242
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g.
FailureOr< ContinuousTileSizeSpecification > computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, unsigned dimension, OpFoldResult targetSize, bool emitAssertions)
Definition: Tiling.cpp:162
FailureOr< StaticContinuousTileSizeSpecification > computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, unsigned targetSize)
Definition: Tiling.cpp:111
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
void hoistRedundantVectorBroadcasts(RewriterBase &rewriter, Operation *root)
Hoist vector.extract/vector.broadcast pairs out of immediately enclosing scf::ForOp iteratively,...
Definition: Hoisting.cpp:97
FailureOr< ForallReductionTilingResult > tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes={}, std::optional< ArrayAttr > mapping=std::nullopt)
Method to tile a reduction to parallel iterations computing partial reductions.
Definition: Tiling.cpp:594
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:766
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:477
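A hedged sketch of driving linalg::pack with static packed sizes; rewriter and linalgOp are assumed to come from the caller, and the 16/16/32 sizes are illustrative only.

// Sketch: build index-typed OpFoldResults for the packed sizes (using the
// ArrayRef overload of getAsIndexOpFoldResult) and call the pack entry point
// listed above.
SmallVector<OpFoldResult> packedSizes = getAsIndexOpFoldResult(
    rewriter.getContext(), ArrayRef<int64_t>{16, 16, 32});
FailureOr<linalg::PackResult> packed =
    linalg::pack(rewriter, linalgOp, packedSizes);
if (failed(packed))
  return failure();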
void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns)
Patterns to promote inputs to outputs and remove unused inputs of linalg.generic ops.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:219
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:421
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:442
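A hedged sketch of a control function matching the ControlSplitReductionFn signature above, fed to splitReduction; the SplitReductionOptions field names (ratio, index, innerParallel) come from the upstream header rather than this listing, rewriter and linalgOp are assumed, and the constants are illustrative.

// Sketch: request a split of the reduction by a factor of 4, without an extra
// parallel dimension inserted inside the reduction.
linalg::ControlSplitReductionFn control = [](linalg::LinalgOp) {
  return linalg::SplitReductionOptions{/*ratio=*/4, /*index=*/0,
                                       /*innerParallel=*/false};
};
FailureOr<linalg::SplitReductionResult> split =
    linalg::splitReduction(rewriter, linalgOp, control, /*useAlloc=*/false);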
LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value)
In case of GPU group memory there is no need to deallocate.
Definition: Promotion.cpp:479
FailureOr< Operation * > transposeMatmul(RewriterBase &rewriter, linalg::MatmulOp op, bool transposeLHS=true)
Convert Linalg matmul ops to transposed variants.
FailureOr< CollapseResult > collapseOpIterationDims(LinalgOp op, ArrayRef< ReassociationIndices > foldedIterationDims, RewriterBase &rewriter)
Collapses dimensions of linalg.generic/linalg.copy operation.
void hoistRedundantVectorTransfers(Operation *root, bool verifyNonZeroTrip=false)
Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of immediately enclosing scf::F...
Definition: Hoisting.cpp:202
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns)
Pattern to replace linalg.add when destination passing on a contraction op suffices for achieving the...
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:675
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:67
void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns)
Populates patterns with vectorization patterns for tensor.insert_slice.
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:268
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition: MemRefOps.cpp:77
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
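A hedged sketch of tiling through the scf::tileUsingSCF entry point above, configured via the SCFTilingOptions::setTileSizes helper listed further down; rewriter and tilingTarget (a TilingInterface op) are assumed, and the 8x32 tile sizes are illustrative.

// Sketch: configure static tile sizes and tile; on success the result exposes
// the tiled ops and the generated loop nest (see SCFTilingResult below).
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(
    getAsIndexOpFoldResult(rewriter.getContext(), ArrayRef<int64_t>{8, 32}));
FailureOr<scf::SCFTilingResult> tiled =
    scf::tileUsingSCF(rewriter, tilingTarget, tilingOptions);
if (failed(tiled))
  return failure();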
ForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: SCF.cpp:597
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:110
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, Operation *target, ArrayRef< OpFoldResult > mixedNumThreads, ArrayRef< OpFoldResult > mixedTileSizes, std::optional< ArrayAttr > mapping, scf::SCFTilingResult &tilingResult)
Implementation of tiling operations using scf.forall.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e....
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:491
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
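A hedged sketch of how the constant-handling OpFoldResult helpers listed here (plus getValueOrCreateConstantIndexOp further below) round-trip a constant; builder and loc are assumed.

// Sketch: create an index-typed constant OpFoldResult, query it, and
// materialize it as a Value when an SSA value is required.
OpFoldResult four = getAsIndexOpFoldResult(builder.getContext(), 4);
bool isFour = isConstantIntValue(four, 4);            // true
std::optional<int64_t> cst = getConstantIntValue(four); // holds 4
Value v = getValueOrCreateConstantIndexOp(builder, loc, four);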
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
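A hedged sketch of the two diagnostic severities listed above, as a transform op implementation might use them; target is an assumed payload op and irWasModifiedAndBroken is a hypothetical condition for illustration.

// Silenceable: the payload is merely unsuitable and the IR is untouched.
if (!isa<linalg::LinalgOp>(target))
  return emitSilenceableFailure(target->getLoc(),
                                "expected a Linalg payload op");
// Definite: recovery is impossible because the payload may already be mutated.
if (irWasModifiedAndBroken) // hypothetical condition
  return emitDefiniteFailure(target->getLoc(),
                             "payload IR left in an inconsistent state");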
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:362
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
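A hedged sketch of assembling a few of the populate* entry points from this index into a pattern set and running the greedy driver; funcOp is an assumed isolated-from-above op.

// Sketch: collect patterns, implicitly freeze them, and apply greedily.
RewritePatternSet patterns(funcOp.getContext());
linalg::populatePadOpVectorizationPatterns(patterns);
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  return failure();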
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:294
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:463
ForwardingListener(OpBuilder::Listener *listener)
Definition: PatternMatch.h:464
Container for result values of tiling.
SmallVector< Value > tiledValues
Options for analysis-enabled bufferization.
Transformation to drop unit-extent dimensions from linalg.generic operations.
Definition: Transforms.h:473
Vectorization pattern for memref::CopyOp.
Definition: Transforms.h:1494
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1426
Match and rewrite for the pattern:
Definition: Transforms.h:1592
Match and rewrite for the pattern:
Definition: Transforms.h:1620
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:379
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:385
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:398
LinalgPromotionOptions & setCopyInOutFns(CopyCallbackFn const &copyIn, CopyCallbackFn const &copyOut)
Definition: Transforms.h:418
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:368
LinalgPromotionOptions & setMemorySpace(Attribute memorySpc)
Definition: Transforms.h:392
LinalgPromotionOptions & setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, DeallocBufferCallbackFn const &deallocFn)
Definition: Transforms.h:408
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:357
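A hedged sketch combining the LinalgPromotionOptions setters above with the GPU workgroup-memory callbacks listed earlier in this index; b and linalgOp are assumed, and the promoted operands and alignment are illustrative.

// Sketch: promote the first two operands into workgroup memory, 16-byte
// aligned, using the alloc/dealloc/copy callbacks from Promotion.cpp.
linalg::LinalgPromotionOptions promotionOptions;
promotionOptions.setOperandsToPromote({0, 1})
    .setAlignment(16)
    .setUseFullTileBuffersByDefault(true)
    .setAllocationDeallocationFns(linalg::allocateWorkgroupMemory,
                                  linalg::deallocateWorkgroupMemory)
    .setCopyInOutFns(linalg::copyToWorkgroupMemory,
                     linalg::copyToWorkgroupMemory);
if (succeeded(linalg::promoteSubviewsPrecondition(linalgOp, promotionOptions)))
  (void)linalg::promoteSubViews(b, linalgOp, promotionOptions);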
Split Reduction options.
Definition: Transforms.h:427
Options used to control tile + fuse.
SCFTilingOptions tilingOptions
The tiling options used to control the tiling of the consumer.
std::optional< FrozenRewritePatternSet > cleanupPatterns
An optional set of rewrite patterns to apply to the results of tiling before fusion.
Options to use to control tiling.
SCFTilingOptions & setTileSizeComputationFunction(SCFTileSizeComputationFunction fun)
SCFTilingOptions & setInterchange(ArrayRef< int64_t > interchange)
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
SmallVector< int64_t > interchangeVector
The interchange vector to reorder the tiled loops.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< LoopLikeOpInterface > loops
The scf.for operations that iterate over the tiles.